Not functioning merging
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
@@ -13,6 +14,8 @@
|
|||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <iterator>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "Common.hpp"
|
#include "Common.hpp"
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
@@ -48,6 +51,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
|||||||
private:
|
private:
|
||||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||||
|
void mergeSingleChildCompute(func::FuncOp funcOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -144,6 +148,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
encapsulateGlobalInstruction(*entryFunc);
|
encapsulateGlobalInstruction(*entryFunc);
|
||||||
|
mergeSingleChildCompute(*entryFunc);
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
dumpModule(moduleOp, "spatial");
|
dumpModule(moduleOp, "spatial");
|
||||||
@@ -225,6 +230,69 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ONNXToSpatialPass::mergeSingleChildCompute(func::FuncOp funcOp) {
|
||||||
|
llvm::SmallVector<spatial::SpatWeightedCompute> computeSingleChild;
|
||||||
|
Location loc = funcOp.getLoc();
|
||||||
|
IRRewriter rewriter(&getContext());
|
||||||
|
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
||||||
|
if (std::distance(compute->getUses().begin(), compute->getUses().end()) == 1) {
|
||||||
|
auto user = *compute->getUsers().begin();
|
||||||
|
if (user->getNumOperands() == 1)
|
||||||
|
if (llvm::isa<spatial::SpatWeightedCompute>(user))
|
||||||
|
computeSingleChild.push_back(compute);
|
||||||
|
}
|
||||||
|
|
||||||
|
IRMapping mapper;
|
||||||
|
while (!computeSingleChild.empty()) {
|
||||||
|
auto compute = computeSingleChild.front();
|
||||||
|
auto child = dyn_cast_if_present<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
||||||
|
assert(child && "Child required!");
|
||||||
|
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||||
|
auto newCompute =
|
||||||
|
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
|
{(int) compute.getWeights().size(), (int) compute.getInputs().size()});
|
||||||
|
llvm::dbgs() << "After Creation\n";
|
||||||
|
newCompute.dump();
|
||||||
|
|
||||||
|
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||||
|
llvm::dbgs() << "After Clone\n";
|
||||||
|
newCompute.dump();
|
||||||
|
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||||
|
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
||||||
|
newTerminator->erase();
|
||||||
|
llvm::dbgs() << "After terminator\n";
|
||||||
|
newCompute.dump();
|
||||||
|
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||||
|
|
||||||
|
for (auto& op : child.getBody().front())
|
||||||
|
rewriter.clone(op, mapper);
|
||||||
|
|
||||||
|
child.replaceAllUsesWith(newCompute);
|
||||||
|
assert(child->getUses().empty() && "It's not obvius");
|
||||||
|
llvm::dbgs() << "Node\n";
|
||||||
|
newCompute.dump();
|
||||||
|
|
||||||
|
llvm::dbgs() << "Parent\n";
|
||||||
|
compute.dump();
|
||||||
|
|
||||||
|
llvm::dbgs() << "Child\n";
|
||||||
|
child.dump();
|
||||||
|
|
||||||
|
child.erase();
|
||||||
|
compute.erase();
|
||||||
|
|
||||||
|
if (std::distance(newCompute->getUses().begin(), newCompute->getUses().end()) == 1) {
|
||||||
|
auto user = *newCompute->getUsers().begin();
|
||||||
|
if (user->getNumOperands() == 1)
|
||||||
|
if (llvm::isa<spatial::SpatWeightedCompute>(user))
|
||||||
|
computeSingleChild.push_back(newCompute);
|
||||||
|
}
|
||||||
|
std::swap(computeSingleChild.front(), computeSingleChild.back());
|
||||||
|
computeSingleChild.pop_back();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||||
bool isAlwaysWeight =
|
bool isAlwaysWeight =
|
||||||
|
|||||||
Reference in New Issue
Block a user