Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
@@ -13,6 +14,8 @@
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
@@ -48,6 +51,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
||||
private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||
void mergeSingleChildCompute(func::FuncOp funcOp);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -144,6 +148,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
encapsulateGlobalInstruction(*entryFunc);
|
||||
mergeSingleChildCompute(*entryFunc);
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "spatial0");
|
||||
@@ -225,6 +230,69 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::mergeSingleChildCompute(func::FuncOp funcOp) {
|
||||
llvm::SmallVector<spatial::SpatWeightedCompute> computeSingleChild;
|
||||
Location loc = funcOp.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
||||
if (std::distance(compute->getUses().begin(), compute->getUses().end()) == 1) {
|
||||
auto user = *compute->getUsers().begin();
|
||||
if (user->getNumOperands() == 1)
|
||||
if (llvm::isa<spatial::SpatWeightedCompute>(user))
|
||||
computeSingleChild.push_back(compute);
|
||||
}
|
||||
|
||||
IRMapping mapper;
|
||||
while (!computeSingleChild.empty()) {
|
||||
auto compute = computeSingleChild.front();
|
||||
auto child = dyn_cast_if_present<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
||||
assert(child && "Child required!");
|
||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||
auto newCompute =
|
||||
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{(int) compute.getWeights().size(), (int) compute.getInputs().size()});
|
||||
llvm::dbgs() << "After Creation\n";
|
||||
newCompute.dump();
|
||||
|
||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||
llvm::dbgs() << "After Clone\n";
|
||||
newCompute.dump();
|
||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
||||
newTerminator->erase();
|
||||
llvm::dbgs() << "After terminator\n";
|
||||
newCompute.dump();
|
||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||
|
||||
for (auto& op : child.getBody().front())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
child.replaceAllUsesWith(newCompute);
|
||||
assert(child->getUses().empty() && "It's not obvius");
|
||||
llvm::dbgs() << "Node\n";
|
||||
newCompute.dump();
|
||||
|
||||
llvm::dbgs() << "Parent\n";
|
||||
compute.dump();
|
||||
|
||||
llvm::dbgs() << "Child\n";
|
||||
child.dump();
|
||||
|
||||
child.erase();
|
||||
compute.erase();
|
||||
|
||||
if (std::distance(newCompute->getUses().begin(), newCompute->getUses().end()) == 1) {
|
||||
auto user = *newCompute->getUsers().begin();
|
||||
if (user->getNumOperands() == 1)
|
||||
if (llvm::isa<spatial::SpatWeightedCompute>(user))
|
||||
computeSingleChild.push_back(newCompute);
|
||||
}
|
||||
std::swap(computeSingleChild.front(), computeSingleChild.back());
|
||||
computeSingleChild.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
bool isAlwaysWeight =
|
||||
|
||||
Reference in New Issue
Block a user