From 5c839e62c1a68371a34016c16ac255469a1d8796 Mon Sep 17 00:00:00 2001 From: ilgeco Date: Mon, 27 Apr 2026 13:48:03 +0200 Subject: [PATCH] Func Input converted to symbol --- src/PIM/Common/PimCommon.cpp | 1 + src/PIM/Compiler/PimCodeGen.cpp | 5 + src/PIM/Compiler/PimCodeGen.hpp | 1 + .../Conversion/SpatialToPim/CMakeLists.txt | 1 + src/PIM/Conversion/SpatialToPim/Patterns.cpp | 88 ++++++++++++++ src/PIM/Conversion/SpatialToPim/Patterns.hpp | 10 ++ .../SpatialToPim/SpatialToPimPass.cpp | 110 ++++++------------ .../Bufferization/PimBufferizationPass.cpp | 10 ++ 8 files changed, 151 insertions(+), 75 deletions(-) create mode 100644 src/PIM/Conversion/SpatialToPim/Patterns.cpp create mode 100644 src/PIM/Conversion/SpatialToPim/Patterns.hpp diff --git a/src/PIM/Common/PimCommon.cpp b/src/PIM/Common/PimCommon.cpp index 3ca1839..2333e99 100644 --- a/src/PIM/Common/PimCommon.cpp +++ b/src/PIM/Common/PimCommon.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index a2deda2..080a335 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -412,6 +412,9 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa emitInstruction(std::move(json)); } +void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const { +} + void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const { auto srcAddr = addressOf(transposeOp.getInput(), knowledge); auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge); @@ -581,6 +584,8 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge); else if (auto vsoftmaxOp = dyn_cast(op)) coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge); + else if (auto getGlobalOp = dyn_cast(op)) + coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge); else { op.emitError("Unsupported codegen for this operation"); op.dump(); diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index bad6b55..42e2fb1 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -106,6 +106,7 @@ public: void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const; void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const; void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const; + void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const; void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const; }; diff --git a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt index d8222c8..351d96d 100644 --- a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt +++ b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt @@ -5,6 +5,7 @@ add_public_tablegen_target(SpatialToPimIncGen) add_pim_library(OMSpatialToPim SpatialToPimPass.cpp Common.cpp + Patterns.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp new file mode 100644 index 0000000..5dbecd2 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -0,0 +1,88 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" + +#include "llvm/ADT/STLExtras.h" + +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override { + + if (funcOp.getArguments().empty()) + return failure(); + + if (llvm::all_of(funcOp.getArguments(), + [](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); })) + return failure(); + + Location loc = funcOp.getLoc(); + + for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) { + if (arg.getUses().empty()) + continue; + + rewriter.setInsertionPoint(funcOp.getOperation()); + + assert(isa(arg.getType())); + + auto argRankedTensorType = llvm::dyn_cast(arg.getType()); + mlir::MemRefType memRefType = + mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType()); + + std::string argName = "arg_" + std::to_string(index); + + memref::GlobalOp::create(rewriter, + loc, + rewriter.getStringAttr(argName), + rewriter.getStringAttr("private"), + TypeAttr::get(memRefType), + {}, + {}, + {}); + + for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) { + auto argUser = argUses.getOwner(); + if (auto spatCompute = dyn_cast(argUser)) { + auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); + rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + auto toTensor = bufferization::ToTensorOp::create(rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + + + rewriter.startOpModification(spatCompute.getOperation()); + BBArgValue.replaceAllUsesWith(toTensor); + spatCompute.getInputsMutable().erase(BBArgIndex); + spatCompute.getBody().front().eraseArgument(BBArgIndex); + rewriter.finalizeOpModification(spatCompute.getOperation()); + } + else { + rewriter.setInsertionPoint(argUser); + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); + rewriter.startOpModification(argUser); + argUses.set(getGlobalOp); + rewriter.finalizeOpModification(argUser); + } + } + } + + return success(); + } +}; + +} // namespace +void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.hpp b/src/PIM/Conversion/SpatialToPim/Patterns.hpp new file mode 100644 index 0000000..e34f6ab --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/Patterns.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "mlir/IR/PatternMatch.h" + + +namespace onnx_mlir { + +void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns); + +} diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index ebdb429..ba7beac 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -1,20 +1,24 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_os_ostream.h" #include @@ -23,6 +27,7 @@ #include #include "Conversion/ONNXToSpatial/Common.hpp" +#include "Patterns.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -146,12 +151,24 @@ void SpatialToPimPass::runOnOperation() { scf::SCFDialect, BuiltinDialect>(); - RewritePatternSet patterns(ctx); - populateWithGenerated(patterns); + { + RewritePatternSet patterns(ctx); + populateWithGenerated(patterns); - if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { - signalPassFailure(); - return; + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + return; + } + } + + { + RewritePatternSet patterns(ctx); + populateGlobalTensorToMemrefPatterns(patterns); + + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { + signalPassFailure(); + return; + } } auto entryFunc = getPimEntryFunc(moduleOp); @@ -466,11 +483,11 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { Location loc = funcOp.getLoc(); - auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) { - auto tensorType = cast(valueToReplace.getType()); + auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) { + auto tensorType = cast(inputTensor.getType()); Type elementType = tensorType.getElementType(); size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; - rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace)); + rewriter.setInsertionPointAfter(inputTensor.getDefiningOp()); auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType); @@ -479,85 +496,28 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu loc, tensorType, deviceTensor, - hostTensor, + inputTensor, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(static_cast(elementsOffset * elementByteSize)), rewriter.getI32IntegerAttr(static_cast(tensorType.getNumElements() * elementByteSize))); - rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult()); + rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp}); }; - // Replace input tensors with memRefs - SmallVector inputTensors; - for (size_t i = 0; i < funcOp.getNumArguments(); i++) { - BlockArgument tensorArg = funcOp.getArgument(i); - DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i); - ShapedType tensorArgType = cast(tensorArg.getType()); - MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType()); - - if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc))) - return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering"); - BlockArgument memRefArg = funcOp.getArgument(i + 1); - - Block& block = funcOp.getBody().front(); - rewriter.setInsertionPoint(&block.front()); - auto toTensorOp = - bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr()); - inputTensors.push_back(toTensorOp); - - tensorArg.replaceAllUsesWith(toTensorOp); - if (failed(funcOp.eraseArgument(i))) - return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering"); - } - - llvm::SmallSet sliceOpsToRemove; for (auto& op : funcOp.getBody().getOps()) if (auto computeOp = dyn_cast(op)) { - unsigned numComputeWeights = computeOp.getWeights().size(); - for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) { - TypedValue tensorSource; - int64_t elementsOffset = 0; - - if (auto sliceOp = dyn_cast(computeOpInput.getDefiningOp())) { - tensorSource = cast>(sliceOp.getSource()); - - if (isa(tensorSource.getDefiningOp())) - continue; - - ArrayRef sourceShape = tensorSource.getType().getShape(); - ArrayRef sliceOffsets = sliceOp.getStaticOffsets(); - ArrayRef sliceSizes = sliceOp.getStaticSizes(); - ArrayRef sliceStrides = sliceOp.getStaticStrides(); - assert("Extracting slice non-contiguous in memory" - && isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides)); - - for (size_t i = 0; i < sliceOffsets.size(); i++) { - int64_t partialOffset = sliceOffsets[i]; - if (partialOffset != 0) - for (size_t j = i + 1; j < sourceShape.size(); j++) - partialOffset *= sourceShape[j]; - elementsOffset += partialOffset; - } - - computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource); - sliceOpsToRemove.insert(sliceOp); + assert(computeOp.getInputs().size() == 0 && "Already removed from mergeNode and global input handle"); + assert(computeOp.getBody().front().getNumArguments() == 0 + && "Already removed from mergeNode and global input handle"); + for (auto getGlobal : computeOp.getOps()) { + if (getGlobal.getName().starts_with("arg")) { + assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute"); + auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin(); + insertMemCopyHostToDev(toTensorOpValue, 0); } - else - tensorSource = cast>(computeOpInput); - - // Compute results must be transferred through channels via send/receive - if (isa(tensorSource.getDefiningOp())) - continue; - - BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx); - insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset); } } - for (auto sliceOp : sliceOpsToRemove) - if (sliceOp->getUses().empty()) - rewriter.eraseOp(sliceOp); - return success(); } diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index 0cd8482..cf42e6e 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -45,6 +45,16 @@ void PimBufferizationPass::runOnOperation() { bufferization::OneShotBufferizationOptions options; options.allowUnknownOps = true; bufferization::BufferizationState state; + + /*for (auto funcOp : moduleOp.getOps()) {*/ + /* for (auto pimCoreOp : funcOp.getOps()) {*/ + /* if (failed(bufferization::runOneShotBufferize(pimCoreOp, options, state))) {*/ + /* moduleOp.emitError("Failed to bufferize PIM and Spatial ops");*/ + /* signalPassFailure();*/ + /* }*/ + /* }*/ + /*}*/ + if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) { moduleOp.emitError("Failed to bufferize PIM and Spatial ops"); signalPassFailure();