fix mergeTriviallyConnectedComputes
All checks were successful
Validate Operations / validate-operations (push) Successful in 16m48s

This commit is contained in:
NiccoloN
2026-04-14 15:13:45 +02:00
parent 77fe293062
commit a7dee5b840

View File

@@ -8,6 +8,7 @@
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
@@ -22,8 +23,8 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -51,7 +52,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
private: private:
void annotateWeightsConstants(func::FuncOp funcOp) const; void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp); void encapsulateGlobalInstruction(func::FuncOp funcOp);
void mergeSingleChildCompute(func::FuncOp funcOp); void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
}; };
} // namespace } // namespace
@@ -148,7 +149,7 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc); annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc); encapsulateGlobalInstruction(*entryFunc);
mergeSingleChildCompute(*entryFunc); mergeTriviallyConnectedComputes(*entryFunc);
// Dump to file for debug // Dump to file for debug
dumpModule(moduleOp, "spatial0"); dumpModule(moduleOp, "spatial0");
@@ -230,66 +231,61 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
} }
} }
void ONNXToSpatialPass::mergeSingleChildCompute(func::FuncOp funcOp) { void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
llvm::SmallVector<spatial::SpatWeightedCompute> computeSingleChild;
Location loc = funcOp.getLoc(); Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatWeightedCompute> trivialComputes;
llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase;
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>()) for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
if (std::distance(compute->getUses().begin(), compute->getUses().end()) == 1) { if (compute->hasOneUse()) {
auto user = *compute->getUsers().begin(); auto user = *compute->getUsers().begin();
if (user->getNumOperands() == 1) if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
if (llvm::isa<spatial::SpatWeightedCompute>(user)) trivialComputes.push_back(compute);
computeSingleChild.push_back(compute);
} }
IRMapping mapper; while (!trivialComputes.empty()) {
while (!computeSingleChild.empty()) { auto compute = trivialComputes.front();
auto compute = computeSingleChild.front();
auto child = dyn_cast_if_present<spatial::SpatWeightedCompute>(*compute->getUsers().begin()); if (compute.use_empty()) {
assert(child && "Child required!"); std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
continue;
}
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
rewriter.setInsertionPointAfter(compute.getOperation()); rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute = auto newCompute =
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes( newCompute.getProperties().setOperandSegmentSizes(
{(int) compute.getWeights().size(), (int) compute.getInputs().size()}); {static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
llvm::dbgs() << "After Creation\n";
newCompute.dump();
IRMapping mapper;
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper); compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
llvm::dbgs() << "After Clone\n";
newCompute.dump();
auto newTerminator = newCompute.getBody().front().getTerminator(); auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0)); mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
newTerminator->erase(); newTerminator->erase();
llvm::dbgs() << "After terminator\n";
newCompute.dump();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end()); rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
for (auto& op : child.getBody().front()) for (auto& op : child.getBody().front())
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
child.replaceAllUsesWith(newCompute); child.replaceAllUsesWith(newCompute);
assert(child->getUses().empty() && "It's not obvius"); toErase.insert(child);
llvm::dbgs() << "Node\n";
newCompute.dump();
llvm::dbgs() << "Parent\n"; std::swap(trivialComputes.front(), trivialComputes.back());
compute.dump(); trivialComputes.pop_back();
toErase.insert(compute);
llvm::dbgs() << "Child\n"; if (newCompute->hasOneUse()) {
child.dump();
child.erase();
compute.erase();
if (std::distance(newCompute->getUses().begin(), newCompute->getUses().end()) == 1) {
auto user = *newCompute->getUsers().begin(); auto user = *newCompute->getUsers().begin();
if (user->getNumOperands() == 1) if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
if (llvm::isa<spatial::SpatWeightedCompute>(user)) trivialComputes.push_back(newCompute);
computeSingleChild.push_back(newCompute);
} }
std::swap(computeSingleChild.front(), computeSingleChild.back()); }
computeSingleChild.pop_back();
for (auto compute : toErase) {
compute.getResult(0).dropAllUses();
compute.erase();
} }
} }