#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"

#include <optional> // required for std::optional below; was missing

#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

// NOTE(review): every template-argument list in this file was stripped by a
// text-extraction step. The arguments below are reconstructed from context
// (variable names `spatCompute` / `spatComputeBatch`, the SpatialOps.hpp
// include, and standard MLIR idioms). In particular, confirm the exact op
// class names `spatial::ComputeOp` / `spatial::ComputeBatchOp` and the
// `func::FuncOp` scope checks against the Spatial dialect definitions.

/// Returns the position of operand `operandNumber` within the trailing
/// "inputs" operand group of a spatial compute op, or std::nullopt when the
/// operand is not one of the inputs (or `owner` is not a compute op).
/// The inputs are assumed to be the last `getInputs().size()` operands, so
/// the input group begins at `getNumOperands() - inputCount`.
static std::optional<unsigned> getDirectComputeInputIndex(
    Operation* owner, unsigned operandNumber) {
  if (auto compute = dyn_cast<spatial::ComputeOp>(owner)) {
    unsigned inputCount = compute.getInputs().size();
    if (inputCount == 0)
      return std::nullopt;
    unsigned inputBegin = compute->getNumOperands() - inputCount;
    if (operandNumber < inputBegin)
      return std::nullopt;
    return operandNumber - inputBegin;
  }
  if (auto computeBatch = dyn_cast<spatial::ComputeBatchOp>(owner)) {
    unsigned inputCount = computeBatch.getInputs().size();
    if (inputCount == 0)
      return std::nullopt;
    unsigned inputBegin = computeBatch->getNumOperands() - inputCount;
    if (operandNumber < inputBegin)
      return std::nullopt;
    return operandNumber - inputBegin;
  }
  return std::nullopt;
}

/// Sinks a function-scope tensor.extract_slice into every spatial compute
/// region that consumes it. For each consuming compute op the slice is
/// cloned at the top of the region (once per op, memoized), the block
/// argument that carried the slice is replaced by the clone, and the
/// now-dead input operand / block argument pair is erased. Uses nested
/// inside a compute region are redirected to the per-region clone. The
/// original extract_slice is erased at the end.
struct MoveExtractSliceIntoCompute final
    : OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp,
      PatternRewriter& rewriter) const override {
    // Only handle slices created directly at function scope.
    if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
      return failure();

    // Pre-scan: bail out if any consumer cannot be rewritten — a compute op
    // using the slice outside of its input operand group, or a plain op
    // that itself sits at function scope (nowhere to sink the slice to).
    for (auto& uses : extractSliceOp->getUses()) {
      if (isa<spatial::ComputeOp, spatial::ComputeBatchOp>(uses.getOwner())) {
        if (!getDirectComputeInputIndex(
                uses.getOwner(), uses.getOperandNumber()))
          return failure();
      } else if (isa_and_present<func::FuncOp>(
                     uses.getOwner()->getParentOp())) {
        return failure();
      }
    }

    // One clone of the extract_slice per consuming compute op.
    llvm::DenseMap<Operation*, Value> mapSpatToExtract;
    for (auto& uses :
        llvm::make_early_inc_range(extractSliceOp->getUses())) {
      if (auto spatCompute = dyn_cast<spatial::ComputeOp>(uses.getOwner())) {
        auto inputIndex =
            getDirectComputeInputIndex(spatCompute, uses.getOperandNumber());
        if (!inputIndex)
          return failure();
        auto BBArgIndex = *inputIndex;
        auto BBArgValue =
            spatCompute.getBody().front().getArgument(BBArgIndex);
        // Dead block argument: nothing to rewire inside the region.
        if (BBArgValue.use_empty())
          continue;
        rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
        if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
          auto newExtractSlice =
              rewriter.clone(*extractSliceOp.getOperation());
          mapSpatToExtract.insert(
              {spatCompute.getOperation(), newExtractSlice->getResult(0)});
        }
        rewriter.startOpModification(spatCompute.getOperation());
        BBArgValue.replaceAllUsesWith(
            mapSpatToExtract[spatCompute.getOperation()]);
        spatCompute.getInputsMutable().erase(BBArgIndex);
        spatCompute.getBody().front().eraseArgument(BBArgIndex);
        rewriter.finalizeOpModification(spatCompute.getOperation());
      } else if (auto spatComputeBatch =
                     dyn_cast<spatial::ComputeBatchOp>(uses.getOwner())) {
        auto inputIndex = getDirectComputeInputIndex(
            spatComputeBatch, uses.getOperandNumber());
        if (!inputIndex)
          return failure();
        auto BBArgIndex = *inputIndex;
        auto BBArgValue =
            spatComputeBatch.getBody().front().getArgument(BBArgIndex);
        if (BBArgValue.use_empty())
          continue;
        rewriter.setInsertionPoint(
            &spatComputeBatch.getBody().front().front());
        if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
          auto newExtractSlice =
              rewriter.clone(*extractSliceOp.getOperation());
          mapSpatToExtract.insert({spatComputeBatch.getOperation(),
              newExtractSlice->getResult(0)});
        }
        rewriter.startOpModification(spatComputeBatch.getOperation());
        BBArgValue.replaceAllUsesWith(
            mapSpatToExtract[spatComputeBatch.getOperation()]);
        spatComputeBatch.getInputsMutable().erase(BBArgIndex);
        spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
        rewriter.finalizeOpModification(spatComputeBatch.getOperation());
      } else {
        // Use nested somewhere inside a compute region: redirect the use
        // itself to the per-region clone; no operand/argument bookkeeping.
        if (auto spatCompute =
                uses.getOwner()->getParentOfType<spatial::ComputeOp>()) {
          rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
          if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
            auto newExtractSlice =
                rewriter.clone(*extractSliceOp.getOperation());
            mapSpatToExtract.insert(
                {spatCompute.getOperation(), newExtractSlice->getResult(0)});
          }
          rewriter.startOpModification(spatCompute.getOperation());
          uses.set(mapSpatToExtract[spatCompute.getOperation()]);
          rewriter.finalizeOpModification(spatCompute.getOperation());
        } else if (auto spatComputeBatch =
                       uses.getOwner()
                           ->getParentOfType<spatial::ComputeBatchOp>()) {
          rewriter.setInsertionPoint(
              &spatComputeBatch.getBody().front().front());
          if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
            auto newExtractSlice =
                rewriter.clone(*extractSliceOp.getOperation());
            mapSpatToExtract.insert({spatComputeBatch.getOperation(),
                newExtractSlice->getResult(0)});
          }
          rewriter.startOpModification(spatComputeBatch.getOperation());
          uses.set(mapSpatToExtract[spatComputeBatch.getOperation()]);
          rewriter.finalizeOpModification(spatComputeBatch.getOperation());
        }
      }
    }
    rewriter.eraseOp(extractSliceOp);
    return success();
  }
};

/// Rewrites a function-scope arith.constant:
///  - Ranked-tensor constants become a private constant memref.global at
///    module level; each consuming compute region gets one
///    memref.get_global + bufferization.to_tensor at its top (memoized per
///    region), replacing the corresponding block argument or in-region use.
///  - Scalar (int/index/float) constants are simply re-cloned inside each
///    consuming region.
/// In both cases dead input operand / block argument pairs are erased and
/// the original constant is removed.
struct ArithConstToGlobalMemoryPattern final
    : OpRewritePattern<arith::ConstantOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp,
      PatternRewriter& rewriter) const override {
    // Counter used to generate unique global symbol names ("const_N").
    // NOTE(review): function-local static — not reentrant across parallel
    // pattern application; confirm the driver runs single-threaded.
    static int i = 0;
    Location loc = constantOp.getLoc();
    if (hasWeightAlways(constantOp))
      return failure();
    if (!isa<func::FuncOp>(constantOp->getParentOp()))
      return failure();
    // If every user is an ordinary op at function scope (neither a compute
    // op nor nested inside one), there is nothing to localize.
    if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
          if (isa<spatial::ComputeOp, spatial::ComputeBatchOp>(op))
            return false;
          if (isa<func::FuncOp>(op->getParentOp()))
            return true;
          return false;
        }))
      return failure();
    // Globals are created just before the enclosing function.
    rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
    auto constRankedTensorType =
        llvm::dyn_cast<RankedTensorType>(constantOp.getType());
    if (constRankedTensorType) {
      mlir::MemRefType memRefType =
          mlir::MemRefType::get(constRankedTensorType.getShape(),
              constRankedTensorType.getElementType());
      std::string argName = "const_" + std::to_string(i++);
      // Private constant global initialized with the constant's value.
      memref::GlobalOp::create(rewriter, loc, rewriter.getStringAttr(argName),
          rewriter.getStringAttr("private"), TypeAttr::get(memRefType),
          constantOp.getValueAttr(), rewriter.getUnitAttr(), {});
      // One get_global/to_tensor pair per consuming compute op.
      llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
      for (auto& constUses :
          llvm::make_early_inc_range(constantOp->getUses())) {
        auto constUsers = constUses.getOwner();
        if (auto spatCompute =
                llvm::dyn_cast<spatial::ComputeOp>(constUsers)) {
          auto inputIndex = getDirectComputeInputIndex(
              spatCompute, constUses.getOperandNumber());
          if (!inputIndex)
            return failure();
          auto BBArgIndex = *inputIndex;
          auto BBArgValue =
              spatCompute.getBody().front().getArgument(BBArgIndex);
          rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
          if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
            auto getGlobalOp =
                memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
            auto toTensor = bufferization::ToTensorOp::create(
                rewriter, loc, constRankedTensorType, getGlobalOp,
                rewriter.getUnitAttr(), rewriter.getUnitAttr());
            mapSpatComputeToConst.insert(
                {spatCompute.getOperation(), toTensor.getResult()});
          }
          rewriter.startOpModification(spatCompute.getOperation());
          BBArgValue.replaceAllUsesWith(
              mapSpatComputeToConst[spatCompute.getOperation()]);
          spatCompute.getInputsMutable().erase(BBArgIndex);
          spatCompute.getBody().front().eraseArgument(BBArgIndex);
          rewriter.finalizeOpModification(spatCompute.getOperation());
        } else if (auto spatComputeBatch =
                       llvm::dyn_cast<spatial::ComputeBatchOp>(constUsers)) {
          auto inputIndex = getDirectComputeInputIndex(
              spatComputeBatch, constUses.getOperandNumber());
          if (!inputIndex)
            return failure();
          auto BBArgIndex = *inputIndex;
          auto BBArgValue =
              spatComputeBatch.getBody().front().getArgument(BBArgIndex);
          rewriter.setInsertionPoint(
              &spatComputeBatch.getBody().front().front());
          if (!mapSpatComputeToConst.contains(
                  spatComputeBatch.getOperation())) {
            auto getGlobalOp =
                memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
            auto toTensor = bufferization::ToTensorOp::create(
                rewriter, loc, constRankedTensorType, getGlobalOp,
                rewriter.getUnitAttr(), rewriter.getUnitAttr());
            mapSpatComputeToConst.insert(
                {spatComputeBatch.getOperation(), toTensor.getResult()});
          }
          rewriter.startOpModification(spatComputeBatch.getOperation());
          BBArgValue.replaceAllUsesWith(
              mapSpatComputeToConst[spatComputeBatch.getOperation()]);
          spatComputeBatch.getInputsMutable().erase(BBArgIndex);
          spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
          rewriter.finalizeOpModification(spatComputeBatch.getOperation());
        } else {
          // Use nested inside a compute region: rewire the use itself.
          if (auto spatCompute =
                  constUses.getOwner()->getParentOfType<spatial::ComputeOp>()) {
            rewriter.setInsertionPoint(
                &spatCompute.getBody().front().front());
            if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
              auto getGlobalOp = memref::GetGlobalOp::create(
                  rewriter, loc, memRefType, argName);
              auto toTensor = bufferization::ToTensorOp::create(
                  rewriter, loc, constRankedTensorType, getGlobalOp,
                  rewriter.getUnitAttr(), rewriter.getUnitAttr());
              mapSpatComputeToConst.insert(
                  {spatCompute.getOperation(), toTensor.getResult()});
            }
            rewriter.startOpModification(spatCompute.getOperation());
            constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
            rewriter.finalizeOpModification(spatCompute.getOperation());
          } else if (auto spatComputeBatch =
                         constUses.getOwner()
                             ->getParentOfType<spatial::ComputeBatchOp>()) {
            rewriter.setInsertionPoint(
                &spatComputeBatch.getBody().front().front());
            if (!mapSpatComputeToConst.contains(
                    spatComputeBatch.getOperation())) {
              auto getGlobalOp = memref::GetGlobalOp::create(
                  rewriter, loc, memRefType, argName);
              auto toTensor = bufferization::ToTensorOp::create(
                  rewriter, loc, constRankedTensorType, getGlobalOp,
                  rewriter.getUnitAttr(), rewriter.getUnitAttr());
              mapSpatComputeToConst.insert(
                  {spatComputeBatch.getOperation(), toTensor.getResult()});
            }
            rewriter.startOpModification(spatComputeBatch.getOperation());
            constUses.set(
                mapSpatComputeToConst[spatComputeBatch.getOperation()]);
            rewriter.finalizeOpModification(spatComputeBatch.getOperation());
          }
        }
      }
    } else if (constantOp.getType().isIntOrIndexOrFloat()) {
      // Scalar constants are cheap: clone them into each consuming region
      // instead of going through a global.
      llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
      for (auto& constUses :
          llvm::make_early_inc_range(constantOp->getUses())) {
        auto constUsers = constUses.getOwner();
        if (auto spatCompute =
                llvm::dyn_cast<spatial::ComputeOp>(constUsers)) {
          auto inputIndex = getDirectComputeInputIndex(
              spatCompute, constUses.getOperandNumber());
          if (!inputIndex)
            return failure();
          auto BBArgIndex = *inputIndex;
          auto BBArgValue =
              spatCompute.getBody().front().getArgument(BBArgIndex);
          rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
          auto newConst = rewriter.clone(*constantOp);
          rewriter.startOpModification(spatCompute.getOperation());
          BBArgValue.replaceAllUsesWith(newConst->getResult(0));
          spatCompute.getInputsMutable().erase(BBArgIndex);
          spatCompute.getBody().front().eraseArgument(BBArgIndex);
          rewriter.finalizeOpModification(spatCompute.getOperation());
        } else if (auto spatComputeBatch =
                       llvm::dyn_cast<spatial::ComputeBatchOp>(constUsers)) {
          auto inputIndex = getDirectComputeInputIndex(
              spatComputeBatch, constUses.getOperandNumber());
          if (!inputIndex)
            return failure();
          auto BBArgIndex = *inputIndex;
          auto BBArgValue =
              spatComputeBatch.getBody().front().getArgument(BBArgIndex);
          rewriter.setInsertionPoint(
              &spatComputeBatch.getBody().front().front());
          auto newConst = rewriter.clone(*constantOp);
          rewriter.startOpModification(spatComputeBatch.getOperation());
          BBArgValue.replaceAllUsesWith(newConst->getResult(0));
          spatComputeBatch.getInputsMutable().erase(BBArgIndex);
          spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
          rewriter.finalizeOpModification(spatComputeBatch.getOperation());
        } else if (auto parent =
                       constUsers->getParentOfType<spatial::ComputeOp>()) {
          // Nested use: memoize one clone per enclosing compute op.
          if (!mapSpatComputeToConst.contains(parent)) {
            rewriter.setInsertionPoint(&parent.getBody().front().front());
            auto newConst = rewriter.clone(*constantOp);
            mapSpatComputeToConst.insert(
                {parent.getOperation(), newConst->getResult(0)});
          }
          constUses.set(mapSpatComputeToConst[parent.getOperation()]);
        } else {
          auto batchParent =
              constUsers->getParentOfType<spatial::ComputeBatchOp>();
          assert(batchParent &&
                 "Global Constant used directly not within a compute");
          if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
            rewriter.setInsertionPoint(
                &batchParent.getBody().front().front());
            auto newConst = rewriter.clone(*constantOp);
            mapSpatComputeToConst.insert(
                {batchParent.getOperation(), newConst->getResult(0)});
          }
          constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
        }
      }
    }
    // NOTE(review): if the constant is neither a ranked tensor nor a
    // scalar, it is erased here without its uses being rewritten — confirm
    // such constants cannot reach this pattern.
    rewriter.eraseOp(constantOp);
    return success();
  }
};

/// Materializes each used func.func argument as a private, uninitialized
/// memref.global. Every use is replaced by a memref.get_global +
/// bufferization.to_tensor pair: placed at the top of the compute region
/// (erasing the corresponding input operand / block argument) for direct
/// compute-op uses, or immediately before the user otherwise. Once all
/// arguments are dead the pattern stops matching, giving a fixed point.
struct FuncOpArgToGlobalMemoryPattern final
    : OpRewritePattern<mlir::func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp,
      PatternRewriter& rewriter) const override {
    if (funcOp.getArguments().empty())
      return failure();
    // All arguments already dead -> nothing left to do (termination).
    if (llvm::all_of(funcOp.getArguments(),
            [](mlir::BlockArgument blockArgument) {
              return blockArgument.use_empty();
            }))
      return failure();
    Location loc = funcOp.getLoc();
    for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) {
      if (arg.getUses().empty())
        continue;
      rewriter.setInsertionPoint(funcOp.getOperation());
      // Arguments are expected to be ranked tensors; release builds would
      // dereference a null dyn_cast result otherwise.
      assert(isa<RankedTensorType>(arg.getType()));
      auto argRankedTensorType =
          llvm::dyn_cast<RankedTensorType>(arg.getType());
      mlir::MemRefType memRefType =
          mlir::MemRefType::get(argRankedTensorType.getShape(),
              argRankedTensorType.getElementType());
      std::string argName = "arg_" + std::to_string(index);
      // Uninitialized, non-constant private global standing in for the arg.
      memref::GlobalOp::create(rewriter, loc, rewriter.getStringAttr(argName),
          rewriter.getStringAttr("private"), TypeAttr::get(memRefType), {},
          {}, {});
      for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
        auto argUser = argUses.getOwner();
        if (auto spatCompute = dyn_cast<spatial::ComputeOp>(argUser)) {
          auto inputIndex = getDirectComputeInputIndex(
              spatCompute, argUses.getOperandNumber());
          if (!inputIndex)
            return failure();
          auto BBArgIndex = *inputIndex;
          auto BBArgValue =
              spatCompute.getBody().front().getArgument(BBArgIndex);
          rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
          auto getGlobalOp =
              memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
          auto toTensor = bufferization::ToTensorOp::create(
              rewriter, loc, argRankedTensorType, getGlobalOp,
              rewriter.getUnitAttr(), rewriter.getUnitAttr());
          rewriter.startOpModification(spatCompute.getOperation());
          BBArgValue.replaceAllUsesWith(toTensor);
          spatCompute.getInputsMutable().erase(BBArgIndex);
          spatCompute.getBody().front().eraseArgument(BBArgIndex);
          rewriter.finalizeOpModification(spatCompute.getOperation());
        } else if (auto spatComputeBatch =
                       dyn_cast<spatial::ComputeBatchOp>(argUser)) {
          auto inputIndex = getDirectComputeInputIndex(
              spatComputeBatch, argUses.getOperandNumber());
          if (!inputIndex)
            return failure();
          auto BBArgIndex = *inputIndex;
          auto BBArgValue =
              spatComputeBatch.getBody().front().getArgument(BBArgIndex);
          rewriter.setInsertionPoint(
              &spatComputeBatch.getBody().front().front());
          auto getGlobalOp =
              memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
          auto toTensor = bufferization::ToTensorOp::create(
              rewriter, loc, argRankedTensorType, getGlobalOp,
              rewriter.getUnitAttr(), rewriter.getUnitAttr());
          rewriter.startOpModification(spatComputeBatch.getOperation());
          BBArgValue.replaceAllUsesWith(toTensor);
          spatComputeBatch.getInputsMutable().erase(BBArgIndex);
          spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
          rewriter.finalizeOpModification(spatComputeBatch.getOperation());
        } else {
          // Ordinary user: materialize the global right before it.
          rewriter.setInsertionPoint(argUser);
          auto getGlobalOp =
              memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
          auto toTensor = bufferization::ToTensorOp::create(
              rewriter, loc, argRankedTensorType, getGlobalOp,
              rewriter.getUnitAttr(), rewriter.getUnitAttr());
          rewriter.startOpModification(argUser);
          argUses.set(toTensor);
          rewriter.finalizeOpModification(argUser);
        }
      }
    }
    return success();
  }
};

} // namespace

/// Registers the three patterns above in `patterns`.
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
  patterns.add<MoveExtractSliceIntoCompute, ArithConstToGlobalMemoryPattern,
      FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
}

} // namespace onnx_mlir