Not functioning merging

This commit is contained in:
ilgeco
2026-04-14 13:43:30 +02:00
parent ab8aff5bac
commit 6727785ab7

View File

@@ -1,6 +1,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -13,6 +14,8 @@
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include <iterator>
#include <utility>
#include "Common.hpp"
#include "Common/PimCommon.hpp"
@@ -48,6 +51,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp);
void mergeSingleChildCompute(func::FuncOp funcOp);
};
} // namespace
@@ -144,6 +148,7 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc);
mergeSingleChildCompute(*entryFunc);
// Dump to file for debug
dumpModule(moduleOp, "spatial");
@@ -225,6 +230,69 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
}
}
void ONNXToSpatialPass::mergeSingleChildCompute(func::FuncOp funcOp) {
llvm::SmallVector<spatial::SpatWeightedCompute> computeSingleChild;
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
if (std::distance(compute->getUses().begin(), compute->getUses().end()) == 1) {
auto user = *compute->getUsers().begin();
if (user->getNumOperands() == 1)
if (llvm::isa<spatial::SpatWeightedCompute>(user))
computeSingleChild.push_back(compute);
}
IRMapping mapper;
while (!computeSingleChild.empty()) {
auto compute = computeSingleChild.front();
auto child = dyn_cast_if_present<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
assert(child && "Child required!");
rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute =
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes(
{(int) compute.getWeights().size(), (int) compute.getInputs().size()});
llvm::dbgs() << "After Creation\n";
newCompute.dump();
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
llvm::dbgs() << "After Clone\n";
newCompute.dump();
auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
newTerminator->erase();
llvm::dbgs() << "After terminator\n";
newCompute.dump();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
for (auto& op : child.getBody().front())
rewriter.clone(op, mapper);
child.replaceAllUsesWith(newCompute);
assert(child->getUses().empty() && "It's not obvius");
llvm::dbgs() << "Node\n";
newCompute.dump();
llvm::dbgs() << "Parent\n";
compute.dump();
llvm::dbgs() << "Child\n";
child.dump();
child.erase();
compute.erase();
if (std::distance(newCompute->getUses().begin(), newCompute->getUses().end()) == 1) {
auto user = *newCompute->getUsers().begin();
if (user->getNumOperands() == 1)
if (llvm::isa<spatial::SpatWeightedCompute>(user))
computeSingleChild.push_back(newCompute);
}
std::swap(computeSingleChild.front(), computeSingleChild.back());
computeSingleChild.pop_back();
}
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight =