MergeDCP pass all test
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h55m12s

This commit is contained in:
ilgeco
2026-04-08 20:39:01 +02:00
parent 813368f625
commit 3f870fb74b
3 changed files with 157 additions and 44 deletions

View File

@@ -5,11 +5,15 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include "Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -42,6 +46,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp);
};
} // namespace
@@ -126,11 +131,88 @@ void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc);
// Dump to file for debug
dumpModule(moduleOp, "spatial");
}
template <typename T>
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
Value source = funcSource(toRemoveOp);
rewriter.setInsertionPointAfter(toRemoveOp);
if (isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp())) {
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0));
inst->replaceAllUsesWith(newCompute);
inst->erase();
return true;
}
}
return false;
}
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
auto sources = toRemoveOp.getInputs();
rewriter.setInsertionPointAfter(toRemoveOp);
if (llvm::any_of(
sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) {
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
llvm::SmallVector<Type> sourceTypes;
llvm::SmallVector<Location> sourceLoc;
for (auto source : sources){
sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc);
}
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for(auto [source,bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
inst->replaceAllUsesWith(newCompute);
inst->erase();
return true;
}
}
return false;
}
// TODO what we want to keep in global?
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
bool keep = true;
while (keep) {
keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
keep |= encapsulator<tensor::ExtractSliceOp>(
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
keep |= encapsulator<tensor::ExpandShapeOp>(
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
keep |= encapsulator<ONNXTransposeOp>(
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
keep |= encapsulator<tensor::CollapseShapeOp>(
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
keep |= encapsulateConcat(rewriter, loc, &instruction);
}
}
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight =