MergeDCP pass all test
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h55m12s
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h55m12s
This commit is contained in:
@@ -5,11 +5,15 @@
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
@@ -42,6 +46,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
||||
|
||||
private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -126,11 +131,88 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
encapsulateGlobalInstruction(*entryFunc);
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "spatial");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
|
||||
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
||||
Value source = funcSource(toRemoveOp);
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
if (isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp())) {
|
||||
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
mapper.map(source, BB->getArgument(0));
|
||||
auto newInst = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0));
|
||||
inst->replaceAllUsesWith(newCompute);
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
|
||||
auto sources = toRemoveOp.getInputs();
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
if (llvm::any_of(
|
||||
sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) {
|
||||
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
|
||||
llvm::SmallVector<Type> sourceTypes;
|
||||
llvm::SmallVector<Location> sourceLoc;
|
||||
for (auto source : sources){
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLoc.push_back(loc);
|
||||
}
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
for(auto [source,bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||
mapper.map(source, bbArg);
|
||||
auto newConcat = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
|
||||
inst->replaceAllUsesWith(newCompute);
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO what we want to keep in global?
|
||||
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
Location loc = funcOp.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
bool keep = true;
|
||||
while (keep) {
|
||||
keep = false;
|
||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||
keep |= encapsulator<tensor::ExtractSliceOp>(
|
||||
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
|
||||
|
||||
keep |= encapsulator<tensor::ExpandShapeOp>(
|
||||
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
||||
|
||||
keep |= encapsulator<ONNXTransposeOp>(
|
||||
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
|
||||
|
||||
keep |= encapsulator<tensor::CollapseShapeOp>(
|
||||
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
|
||||
|
||||
keep |= encapsulateConcat(rewriter, loc, &instruction);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
bool isAlwaysWeight =
|
||||
|
||||
Reference in New Issue
Block a user