diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 1c47aba..862fa72 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -1,6 +1,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -13,6 +14,8 @@ #include "llvm/Support/raw_os_ostream.h" #include +#include +#include #include "Common.hpp" #include "Common/PimCommon.hpp" @@ -48,6 +51,7 @@ struct ONNXToSpatialPass : PassWrapper computeSingleChild; + Location loc = funcOp.getLoc(); + IRRewriter rewriter(&getContext()); + for (auto compute : funcOp.getOps()) + if (std::distance(compute->getUses().begin(), compute->getUses().end()) == 1) { + auto user = *compute->getUsers().begin(); + if (user->getNumOperands() == 1) + if (llvm::isa(user)) + computeSingleChild.push_back(compute); + } + + IRMapping mapper; + while (!computeSingleChild.empty()) { + auto compute = computeSingleChild.front(); + auto child = dyn_cast_if_present(*compute->getUsers().begin()); + assert(child && "Child required!"); + rewriter.setInsertionPointAfter(compute.getOperation()); + auto newCompute = + spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); + newCompute.getProperties().setOperandSegmentSizes( + {(int) compute.getWeights().size(), (int) compute.getInputs().size()}); + llvm::dbgs() << "After Creation\n"; + newCompute.dump(); + + compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper); + llvm::dbgs() << "After Clone\n"; + newCompute.dump(); + auto newTerminator = newCompute.getBody().front().getTerminator(); + mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0)); + newTerminator->erase(); + llvm::dbgs() << "After terminator\n"; + newCompute.dump(); + rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end()); + + for (auto& op : child.getBody().front()) + rewriter.clone(op, mapper); + + child.replaceAllUsesWith(newCompute); + assert(child->getUses().empty() && "It's not obvius"); + llvm::dbgs() << "Node\n"; + newCompute.dump(); + + llvm::dbgs() << "Parent\n"; + compute.dump(); + + llvm::dbgs() << "Child\n"; + child.dump(); + + child.erase(); + compute.erase(); + + if (std::distance(newCompute->getUses().begin(), newCompute->getUses().end()) == 1) { + auto user = *newCompute->getUsers().begin(); + if (user->getNumOperands() == 1) + if (llvm::isa(user)) + computeSingleChild.push_back(newCompute); + } + std::swap(computeSingleChild.front(), computeSingleChild.back()); + computeSingleChild.pop_back(); + } +} + void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { funcOp.walk([&](arith::ConstantOp constantOp) { bool isAlwaysWeight =