#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"

#include "llvm/ADT/STLExtras.h"

#include <string>

#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

// Spatial compute op whose region block arguments mirror its `inputs` operands.
// NOTE(review): the template argument of the original `dyn_cast` was lost when
// this file was flattened; `spatial::SpatWeightedCompute` is a best guess and
// must be confirmed against SpatialOps.hpp before merging.
using SpatialComputeOp = spatial::SpatWeightedCompute;

/// Replaces every used argument of a `func.func` with a private
/// `memref.global` named `arg_<index>`.
///
/// For each used argument the pattern:
///   1. creates a `memref.global` right before the function, with the shape
///      and element type of the argument's ranked tensor type;
///   2. for a use that is an operand of a spatial compute op, materialises
///      `memref.get_global` + `bufferization.to_tensor` at the start of the
///      op's body, redirects the matching body block argument to that tensor,
///      and erases both the operand and the block argument;
///   3. for any other use, inserts a `memref.get_global` right before the user
///      and rewires the operand to it.
///
/// The match fails (without touching the IR) when no argument has uses or when
/// a used argument is not a ranked tensor.
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override {
    // Nothing to do when there are no arguments or none of them is used.
    // (`all_of` over an empty range is true, so this covers both cases.)
    if (llvm::all_of(funcOp.getArguments(), [](BlockArgument argument) { return argument.use_empty(); }))
      return failure();

    // Validate every used argument up front: a pattern must not mutate the IR
    // and then report failure, and a bare assert would vanish in release
    // builds and leave a null type to be dereferenced below.
    const bool hasUnsupportedArgument = llvm::any_of(funcOp.getArguments(), [](BlockArgument argument) {
      return !argument.use_empty() && !isa<RankedTensorType>(argument.getType());
    });
    if (hasUnsupportedArgument)
      return rewriter.notifyMatchFailure(funcOp, "used function arguments must be ranked tensors");

    Location loc = funcOp.getLoc();

    for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) {
      if (arg.use_empty())
        continue;

      // Guaranteed to succeed by the validation above.
      auto tensorType = cast<RankedTensorType>(arg.getType());
      MemRefType memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());

      // NOTE(review): the symbol name depends only on the argument index, so
      // two functions rewritten in the same symbol table would both define
      // `arg_0`, `arg_1`, ... — confirm that only one function is processed.
      std::string argName = "arg_" + std::to_string(index);

      // The insertion point is moved by the use rewrites below, so it has to
      // be reset before emitting each global.
      rewriter.setInsertionPoint(funcOp.getOperation());
      memref::GlobalOp::create(rewriter,
          loc,
          rewriter.getStringAttr(argName),
          rewriter.getStringAttr("private"),
          TypeAttr::get(memRefType),
          {},
          {},
          {});

      // Uses are removed or redirected while iterating, hence early-inc.
      // NOTE(review): erasing an operand shifts the later operands of the same
      // op; this presumably relies on `arg` feeding a given compute op at most
      // once — confirm.
      for (OpOperand& use : llvm::make_early_inc_range(arg.getUses())) {
        Operation* user = use.getOwner();

        if (auto computeOp = dyn_cast<SpatialComputeOp>(user)) {
          // Position of this operand inside the `inputs` segment, which is
          // also the index of the matching body block argument.
          // NOTE(review): assumes the use belongs to `inputs`; an operand
          // placed before that segment would make this subtraction wrap.
          unsigned inputIndex = use.getOperandNumber() - computeOp.getInputs().getBeginOperandIndex();
          Block& body = computeOp.getBody().front();
          BlockArgument bodyArgument = body.getArgument(inputIndex);

          // Load the global at the top of the body so it dominates every use
          // of the block argument it replaces.
          rewriter.setInsertionPointToStart(&body);
          auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
          auto toTensorOp = bufferization::ToTensorOp::create(
              rewriter, loc, tensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());

          // Go through the rewriter so the users inside the body are reported
          // to the driver as modified.
          rewriter.replaceAllUsesWith(bodyArgument, toTensorOp.getResult());

          // Drop the now-dead operand and its block argument together to keep
          // operands and block arguments in one-to-one correspondence.
          rewriter.modifyOpInPlace(computeOp.getOperation(), [&] {
            computeOp.getInputsMutable().erase(inputIndex);
            body.eraseArgument(inputIndex);
          });
        }
        else {
          rewriter.setInsertionPoint(user);
          auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);

          // The operand now carries the memref-typed global instead of the
          // tensor-typed argument.
          // NOTE(review): presumably a later step reconciles the operand type
          // of `user` — confirm.
          rewriter.modifyOpInPlace(user, [&] { use.set(getGlobalOp); });
        }
      }
    }

    return success();
  }
};

} // namespace

/// Registers the pattern that moves `func.func` arguments into memref globals.
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
  patterns.add<FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
}

} // namespace onnx_mlir