From 5c839e62c1a68371a34016c16ac255469a1d8796 Mon Sep 17 00:00:00 2001 From: ilgeco Date: Mon, 27 Apr 2026 13:48:03 +0200 Subject: [PATCH 1/2] Func Input converted to symbol --- src/PIM/Common/PimCommon.cpp | 1 + src/PIM/Compiler/PimCodeGen.cpp | 5 + src/PIM/Compiler/PimCodeGen.hpp | 1 + .../Conversion/SpatialToPim/CMakeLists.txt | 1 + src/PIM/Conversion/SpatialToPim/Patterns.cpp | 88 ++++++++++++++ src/PIM/Conversion/SpatialToPim/Patterns.hpp | 10 ++ .../SpatialToPim/SpatialToPimPass.cpp | 110 ++++++------------ .../Bufferization/PimBufferizationPass.cpp | 10 ++ 8 files changed, 151 insertions(+), 75 deletions(-) create mode 100644 src/PIM/Conversion/SpatialToPim/Patterns.cpp create mode 100644 src/PIM/Conversion/SpatialToPim/Patterns.hpp diff --git a/src/PIM/Common/PimCommon.cpp b/src/PIM/Common/PimCommon.cpp index 3ca1839..2333e99 100644 --- a/src/PIM/Common/PimCommon.cpp +++ b/src/PIM/Common/PimCommon.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index a2deda2..080a335 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -412,6 +412,9 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa emitInstruction(std::move(json)); } +void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const { +} + void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const { auto srcAddr = addressOf(transposeOp.getInput(), knowledge); auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge); @@ -581,6 +584,8 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge); else if (auto vsoftmaxOp = 
dyn_cast(op)) coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge); + else if (auto getGlobalOp = dyn_cast(op)) + coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge); else { op.emitError("Unsupported codegen for this operation"); op.dump(); diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index bad6b55..42e2fb1 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -106,6 +106,7 @@ public: void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const; void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const; void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const; + void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const; void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const; }; diff --git a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt index d8222c8..351d96d 100644 --- a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt +++ b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt @@ -5,6 +5,7 @@ add_public_tablegen_target(SpatialToPimIncGen) add_pim_library(OMSpatialToPim SpatialToPimPass.cpp Common.cpp + Patterns.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp new file mode 100644 index 0000000..5dbecd2 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -0,0 +1,88 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" + +#include "llvm/ADT/STLExtras.h" + +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +struct 
FuncOpArgToGlobalMemoryPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override { + + if (funcOp.getArguments().empty()) + return failure(); + + if (llvm::all_of(funcOp.getArguments(), + [](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); })) + return failure(); + + Location loc = funcOp.getLoc(); + + for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) { + if (arg.getUses().empty()) + continue; + + rewriter.setInsertionPoint(funcOp.getOperation()); + + assert(isa(arg.getType())); + + auto argRankedTensorType = llvm::dyn_cast(arg.getType()); + mlir::MemRefType memRefType = + mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType()); + + std::string argName = "arg_" + std::to_string(index); + + memref::GlobalOp::create(rewriter, + loc, + rewriter.getStringAttr(argName), + rewriter.getStringAttr("private"), + TypeAttr::get(memRefType), + {}, + {}, + {}); + + for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) { + auto argUser = argUses.getOwner(); + if (auto spatCompute = dyn_cast(argUser)) { + auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); + rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + auto toTensor = bufferization::ToTensorOp::create(rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + + + rewriter.startOpModification(spatCompute.getOperation()); + BBArgValue.replaceAllUsesWith(toTensor); + spatCompute.getInputsMutable().erase(BBArgIndex); + spatCompute.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatCompute.getOperation()); + } + else { + 
rewriter.setInsertionPoint(argUser); + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + rewriter.startOpModification(argUser); + argUses.set(getGlobalOp); + rewriter.finalizeOpModification(argUser); + } + } + } + + return success(); + } +}; + +} // namespace +void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.hpp b/src/PIM/Conversion/SpatialToPim/Patterns.hpp new file mode 100644 index 0000000..e34f6ab --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/Patterns.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "mlir/IR/PatternMatch.h" + + +namespace onnx_mlir { + +void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns); + +} diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index ebdb429..ba7beac 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -1,20 +1,24 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_os_ostream.h" #include @@ -23,6 +27,7 @@ #include #include 
"Conversion/ONNXToSpatial/Common.hpp" +#include "Patterns.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -146,12 +151,24 @@ void SpatialToPimPass::runOnOperation() { scf::SCFDialect, BuiltinDialect>(); - RewritePatternSet patterns(ctx); - populateWithGenerated(patterns); + { + RewritePatternSet patterns(ctx); + populateWithGenerated(patterns); - if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { - signalPassFailure(); - return; + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + return; + } + } + + { + RewritePatternSet patterns(ctx); + populateGlobalTensorToMemrefPatterns(patterns); + + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + signalPassFailure(); + return; + } } auto entryFunc = getPimEntryFunc(moduleOp); @@ -466,11 +483,11 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { Location loc = funcOp.getLoc(); - auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) { - auto tensorType = cast(valueToReplace.getType()); + auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) { + auto tensorType = cast(inputTensor.getType()); Type elementType = tensorType.getElementType(); size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; - rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace)); + rewriter.setInsertionPointAfter(inputTensor.getDefiningOp()); auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType); @@ -479,85 +496,28 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu loc, tensorType, deviceTensor, - hostTensor, + 
inputTensor, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(static_cast(elementsOffset * elementByteSize)), rewriter.getI32IntegerAttr(static_cast(tensorType.getNumElements() * elementByteSize))); - rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult()); + rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp}); }; - // Replace input tensors with memRefs - SmallVector inputTensors; - for (size_t i = 0; i < funcOp.getNumArguments(); i++) { - BlockArgument tensorArg = funcOp.getArgument(i); - DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i); - ShapedType tensorArgType = cast(tensorArg.getType()); - MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType()); - - if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc))) - return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering"); - BlockArgument memRefArg = funcOp.getArgument(i + 1); - - Block& block = funcOp.getBody().front(); - rewriter.setInsertionPoint(&block.front()); - auto toTensorOp = - bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr()); - inputTensors.push_back(toTensorOp); - - tensorArg.replaceAllUsesWith(toTensorOp); - if (failed(funcOp.eraseArgument(i))) - return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering"); - } - - llvm::SmallSet sliceOpsToRemove; for (auto& op : funcOp.getBody().getOps()) if (auto computeOp = dyn_cast(op)) { - unsigned numComputeWeights = computeOp.getWeights().size(); - for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) { - TypedValue tensorSource; - int64_t elementsOffset = 0; - - if (auto sliceOp = dyn_cast(computeOpInput.getDefiningOp())) { - tensorSource = cast>(sliceOp.getSource()); - - if (isa(tensorSource.getDefiningOp())) - continue; - - ArrayRef sourceShape = 
tensorSource.getType().getShape(); - ArrayRef sliceOffsets = sliceOp.getStaticOffsets(); - ArrayRef sliceSizes = sliceOp.getStaticSizes(); - ArrayRef sliceStrides = sliceOp.getStaticStrides(); - assert("Extracting slice non-contiguous in memory" - && isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides)); - - for (size_t i = 0; i < sliceOffsets.size(); i++) { - int64_t partialOffset = sliceOffsets[i]; - if (partialOffset != 0) - for (size_t j = i + 1; j < sourceShape.size(); j++) - partialOffset *= sourceShape[j]; - elementsOffset += partialOffset; - } - - computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource); - sliceOpsToRemove.insert(sliceOp); + assert(computeOp.getInputs().size() == 0 && "Already removed from mergeNode and global input handle"); + assert(computeOp.getBody().front().getNumArguments() == 0 + && "Already removed from mergeNode and global input handle"); + for (auto getGlobal : computeOp.getOps()) { + if (getGlobal.getName().starts_with("arg")) { + assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute"); + auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin(); + insertMemCopyHostToDev(toTensorOpValue, 0); } - else - tensorSource = cast>(computeOpInput); - - // Compute results must be transferred through channels via send/receive - if (isa(tensorSource.getDefiningOp())) - continue; - - BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx); - insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset); } } - for (auto sliceOp : sliceOpsToRemove) - if (sliceOp->getUses().empty()) - rewriter.eraseOp(sliceOp); - return success(); } diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index 0cd8482..cf42e6e 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ 
b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -45,6 +45,16 @@ void PimBufferizationPass::runOnOperation() { bufferization::OneShotBufferizationOptions options; options.allowUnknownOps = true; bufferization::BufferizationState state; + + /*for (auto funcOp : moduleOp.getOps()) {*/ + /* for (auto pimCoreOp : funcOp.getOps()) {*/ + /* if (failed(bufferization::runOneShotBufferize(pimCoreOp, options, state))) {*/ + /* moduleOp.emitError("Failed to bufferize PIM and Spatial ops");*/ + /* signalPassFailure();*/ + /* }*/ + /* }*/ + /*}*/ + if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) { moduleOp.emitError("Failed to bufferize PIM and Spatial ops"); signalPassFailure(); From 9dccc2c701cec657089f72ebc6c0c5051905cba4 Mon Sep 17 00:00:00 2001 From: ilgeco Date: Tue, 28 Apr 2026 12:42:01 +0200 Subject: [PATCH 2/2] Translate global constant to symbol --- src/PIM/Compiler/PimCodeGen.cpp | 18 ++- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 62 ++++++--- src/PIM/Conversion/SpatialToPim/Patterns.cpp | 120 ++++++++++++++++-- .../SpatialToPim/SpatialToPimPass.cpp | 7 +- .../MergeComputeNodesPass.cpp | 50 ++++++-- 5 files changed, 212 insertions(+), 45 deletions(-) diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 080a335..102e2ae 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -3,9 +3,11 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" #include "llvm/Support/raw_ostream.h" @@ -53,9 +55,23 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) { void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { SmallDenseMap globalConstants; SmallVector, 16> globalAliases;
SmallVector args; + + + for (mlir::Value arg : funcOp.getArguments()){ + gatherMemEntry(arg); + args.push_back(arg); + } + funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { if (!hasWeightAlways(getGlobalOp)) { auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (globalMemrefOp.getName().starts_with("arg")){ + StringRef indexStr = globalMemrefOp.getName().substr(4); + int index = 0; + llvm::to_integer(indexStr,index, 10); + globalAliases.push_back({getGlobalOp.getResult(), args[index]}); + } auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult()); if (inserted) gatherMemEntry(getGlobalOp.getResult()); @@ -64,8 +80,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { } }); - for (mlir::Value arg : funcOp.getArguments()) - gatherMemEntry(arg); funcOp.walk([&](memref::AllocOp allocOp) { if (!allocOp->getParentOfType()) diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 28e7078..eaee28c 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -148,6 +148,7 @@ void ONNXToSpatialPass::runOnOperation() { llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n"; annotateWeightsConstants(*entryFunc); + encapsulateGlobalInstruction(*entryFunc); if (failed(promoteConstantInputsToWeights(*entryFunc))) { @@ -166,19 +167,46 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func if (T toRemoveOp = llvm::dyn_cast_if_present(inst)) { Value source = funcSource(toRemoveOp); rewriter.setInsertionPointAfter(toRemoveOp); - if (isa_and_present(source.getDefiningOp())) { - auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); - auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); - 
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); - rewriter.setInsertionPointToEnd(BB); - IRMapping mapper; - mapper.map(source, BB->getArgument(0)); - auto newInst = rewriter.clone(*inst, mapper); - spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults()); - inst->replaceAllUsesWith(newCompute->getResults()); - inst->erase(); - return true; + auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); + auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); + newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); + rewriter.setInsertionPointToEnd(BB); + IRMapping mapper; + mapper.map(source, BB->getArgument(0)); + auto newInst = rewriter.clone(*inst, mapper); + spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults()); + inst->replaceAllUsesWith(newCompute->getResults()); + inst->erase(); + return true; + } + return false; +} + +bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) { + if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present(inst)) { + for (auto& use : toRemoveOp->getUses()) { + auto users = use.getOwner(); + if (auto spatCompUser = dyn_cast(users)) { + unsigned int poistionUses = use.getOperandNumber(); + if (poistionUses < spatCompUser.getInputs().getBeginOperandIndex()) + return false; + }else { + return false; + } } + auto source = toRemoveOp.getSource(); + rewriter.setInsertionPointAfter(toRemoveOp); + auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); + auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); + newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); + rewriter.setInsertionPointToEnd(BB); + IRMapping mapper; + mapper.map(source, BB->getArgument(0)); + auto newInst = rewriter.clone(*inst, mapper); + spatial::SpatYieldOp::create(rewriter, loc, 
newInst->getResults()); + inst->replaceAllUsesWith(newCompute->getResults()); + inst->erase(); + return true; } return false; } @@ -187,8 +215,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) { if (auto toRemoveOp = llvm::dyn_cast_if_present(inst)) { auto sources = toRemoveOp.getInputs(); rewriter.setInsertionPointAfter(toRemoveOp); - if (llvm::any_of( - sources, [](auto source) { return isa_and_present(source.getDefiningOp()); })) { + if (llvm::any_of(sources, + [](auto source) { return isa_and_present(source.getDefiningOp()); })) { auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources); SmallVector sourceTypes; SmallVector sourceLoc; @@ -277,8 +305,7 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) { while (keep) { keep = false; for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) { - keep |= encapsulator( - rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); }); + keep |= encapsulateSlice(rewriter, loc, &instruction); keep |= encapsulator( rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); }); @@ -324,8 +351,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { rewriter.setInsertionPointAfter(compute.getOperation()); - auto newCompute = - spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); + auto newCompute = spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); newCompute.getProperties().setOperandSegmentSizes( {static_cast(compute.getWeights().size()), static_cast(compute.getInputs().size())}); diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp index 5dbecd2..a84a2ef 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -1,11 +1,19 @@ +#include 
"mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" +#include "Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -13,6 +21,96 @@ using namespace mlir; namespace onnx_mlir { namespace { +struct ArithConstToGlobalMemoryPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override { + static int i = 0; + Location loc = constantOp.getLoc(); + + if (hasWeightAlways(constantOp)) + return failure(); + + if (!isa(constantOp->getParentOp())) + return failure(); + + rewriter.setInsertionPoint(constantOp->getParentOfType()); + + auto constRankedTensorType = llvm::dyn_cast(constantOp.getType()); + + if (constRankedTensorType) { + mlir::MemRefType memRefType = + mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType()); + std::string argName = "const_" + std::to_string(i++); + memref::GlobalOp::create(rewriter, + loc, + rewriter.getStringAttr(argName), + rewriter.getStringAttr("private"), + TypeAttr::get(memRefType), + constantOp.getValueAttr(), + rewriter.getUnitAttr(), + {}); + + for (auto& constUses : constantOp->getUses()) { + auto constUsers = constUses.getOwner(); + + if (auto spatCompute = llvm::dyn_cast(constUsers)) { + + auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto BBArgValue = 
spatCompute.getBody().front().getArgument(BBArgIndex); + rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + auto toTensor = bufferization::ToTensorOp::create( + rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + + rewriter.startOpModification(spatCompute.getOperation()); + BBArgValue.replaceAllUsesWith(toTensor); + spatCompute.getInputsMutable().erase(BBArgIndex); + spatCompute.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatCompute.getOperation()); + } + else { + llvm_unreachable("Who is using const globally"); + } + } + } + else if (constantOp.getType().isIntOrIndexOrFloat()) { + llvm::DenseMap mapSpatComputeToConst; + + for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) { + auto constUsers = constUses.getOwner(); + + if (auto spatCompute = llvm::dyn_cast(constUsers)) { + + auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); + rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); + auto newConst = rewriter.clone(*constantOp); + + rewriter.startOpModification(spatCompute.getOperation()); + BBArgValue.replaceAllUsesWith(newConst->getResult(0)); + spatCompute.getInputsMutable().erase(BBArgIndex); + spatCompute.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatCompute.getOperation()); + } + else { + auto parent = constUsers->getParentOfType(); + assert(parent && "Global Constant used directly not within a compute"); + if (!mapSpatComputeToConst.contains(parent)) { + rewriter.setInsertionPoint(&parent.getBody().front().front()); + auto newConst = rewriter.clone(*constantOp); + mapSpatComputeToConst.insert({parent, newConst->getResult(0)}); + } + constUses.set(mapSpatComputeToConst[parent]); + }
+ } + } + auto parent = constantOp->getParentOp(); + rewriter.eraseOp(constantOp); + return success(); + } +}; + struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -42,13 +140,13 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern(patterns.getContext()); + patterns.add(patterns.getContext()); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index ba7beac..d9c7100 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" @@ -165,10 +166,7 @@ void SpatialToPimPass::runOnOperation() { RewritePatternSet patterns(ctx); populateGlobalTensorToMemrefPatterns(patterns); - if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { - signalPassFailure(); - return; - } + walkAndApplyPatterns(moduleOp, std::move(patterns)); } auto entryFunc = getPimEntryFunc(moduleOp); @@ -504,6 +502,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp}); }; + for (auto& op : funcOp.getBody().getOps()) if (auto computeOp = dyn_cast(op)) { assert(computeOp.getInputs().size() == 0 && "Already removed from mergeNode and global input handle"); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 976f418..6672d3b 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ 
b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -95,23 +95,22 @@ void generateReport(func::FuncOp funcOp, const std::string& name) { auto expectedPrintedValue = currentComputeId + 1; bool rangePrinted = false; cI++; - for (; cI <= lastIndex; ++cI){ + for (; cI <= lastIndex; ++cI) { auto candidateToPrint = std::get<0>(collectedData[cI]); - if (candidateToPrint == expectedPrintedValue){ + if (candidateToPrint == expectedPrintedValue) { expectedPrintedValue = candidateToPrint + 1; rangePrinted = true; - } else { - if (rangePrinted) { + } + else { + if (rangePrinted) os << " - " << expectedPrintedValue - 1; - } os << " , " << candidateToPrint; rangePrinted = false; expectedPrintedValue = candidateToPrint + 1; } } - if (rangePrinted && currentComputeId != expectedPrintedValue - 1){ - os << " - " << expectedPrintedValue - 1; - } + if (rangePrinted && currentComputeId != expectedPrintedValue - 1) + os << " - " << expectedPrintedValue - 1; os << " :\n"; os << "\tNumber of instructions " << currentNumInst << "\n"; @@ -193,11 +192,34 @@ public: LogicalResult initialize(MLIRContext* context) override { return success(); } + void verifyOrderAssumption(std::vector& dominanceOrderCompute) { + uint64_t computeNumber = 0; + llvm::DenseSet visited; + mlir::func::FuncOp funcOp = getOperation(); + for (auto spatCompute : funcOp.getOps()) + computeNumber++; + + assert(computeNumber == dominanceOrderCompute.size()); + + for(auto domCompute : dominanceOrderCompute){ + visited.insert(domCompute); + for(auto domInput : domCompute.getInputs() ){ + if(auto domImputAsCompute = dyn_cast_if_present(domInput.getDefiningOp())){ + assert(visited.contains(domImputAsCompute) && "Dominance order violated\n"); + } + } + } + + } + void runOnOperation() override { DCPAnalysisResult& analysisResult = getAnalysis().getResult(); auto& lastComputeOfCpu = analysisResult.isLastComputeOfCpu; auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap; + func::FuncOp func 
= getOperation(); + verifyOrderAssumption(analysisResult.dominanceOrderCompute); + for (auto currentComputeNode : analysisResult.dominanceOrderCompute) { size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode); if (!cpuToNewComputeMap.contains(cpu)) { @@ -219,11 +241,19 @@ public: } for (auto computeNodeToRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) { - for (auto users : computeNodeToRemove->getUsers()) + if (!computeNodeToRemove->use_empty()) { + llvm::dbgs() << "Full module\n"; + computeNodeToRemove->getParentOfType()->dump(); + + llvm::dbgs() << "Compute with uses:\n"; + computeNodeToRemove.dump(); + } + for (auto users : computeNodeToRemove->getUsers()) { + llvm::dbgs() << "Users:\n"; users->dump(); + } computeNodeToRemove.erase(); } - func::FuncOp func = getOperation(); dumpModule(cast(func->getParentOp()), "spatial1_dcp_merged"); generateReport(func, "spatial1_dcp_merged_report"); }