#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

// NOTE(review): in this copy every template argument list appears to have been
// stripped (e.g. `isa(op)`, `SmallVector promoteInput;`, `template static`,
// `FailureOr computePromotedOperands`, `std::function(size_t)>`), which is not
// valid C++. Most can be read off their uses: presumably
// `std::optional<BlockArgument>`, `SmallVector<bool>` (promoteInput),
// `SmallVector<Value>` (newWeights/newInputs), `SmallVector<Type>`
// (newInputTypes), `SmallVector<Location>` (newInputLocs),
// `template <typename ComputeOpTy>`, `FailureOr<PromotedOperands>` and
// `std::function<std::optional<BlockArgument>(size_t)>` — TODO confirm.
// The op lists given to `isa` and `isa_and_nonnull` cannot be recovered from
// here and must be restored from version control.

/// Per-user predicate: true if `op` is one of the op kinds named in the `isa`
/// list. Used to decide whether an input block argument may be promoted.
static bool isWeightMaterializationHelperUser(Operation* op) { return isa(op); }

/// True if `arg` has at least one use and every user of it satisfies
/// isWeightMaterializationHelperUser.
static bool canPromoteInputBlockArgument(BlockArgument arg) {
  return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
}

/// Overload for an optional block argument: false when `arg` is empty,
/// otherwise forwards to the BlockArgument overload.
static bool canPromoteInputBlockArgument(std::optional arg) { return arg && canPromoteInputBlockArgument(*arg); }

/// True if `value` is produced by an op of the kind(s) named in the
/// `isa_and_nonnull` list. Always false for block arguments, whose defining op
/// is null.
static bool isDirectConstantValue(Value value) { return isa_and_nonnull(value.getDefiningOp()); }

/// Result of partitioning a compute op's inputs into promoted and kept ones.
struct PromotedOperands {
  // One flag per original input; true when that input is moved to the weights.
  SmallVector promoteInput;
  // The op's original weights, followed by the promoted inputs in input order.
  SmallVector newWeights;
  // Inputs that remain inputs, plus their types and locations (parallel
  // arrays) used to create the rewritten block's input arguments.
  SmallVector newInputs;
  SmallVector newInputTypes;
  SmallVector newInputLocs;
};

/// Returns true if at least one input of `compute` would be promoted by
/// computePromotedOperands. A weight-like input (per isWeightLikeComputeOperand)
/// is promotable unless it is a direct constant whose block argument fails
/// canPromoteInputBlockArgument.
/// NOTE(review): this duplicates the selection logic of computePromotedOperands;
/// the two must be kept in sync.
template static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
  for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
    if (!isWeightLikeComputeOperand(input))
      continue;
    // Direct constants are only promoted when their block argument has uses and
    // all of them are helper ops.
    if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
      continue;
    return true;
  }
  return false;
}

/// Splits the inputs of `compute` into those promoted to weights and those kept
/// as inputs, using the same criteria as hasPromotableWeightLikeInputs.
/// Returns failure() when no input is promoted, i.e. no rewrite is needed.
template static FailureOr computePromotedOperands(ComputeOpTy compute) {
  PromotedOperands promoted;
  promoted.promoteInput.assign(compute.getInputs().size(), false);
  // NOTE(review): reserve() runs after append(), so it may trigger a second
  // allocation; reserving before the append would avoid that.
  promoted.newWeights.append(compute.getWeights().begin(), compute.getWeights().end());
  promoted.newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
  promoted.newInputs.reserve(compute.getInputs().size());
  promoted.newInputTypes.reserve(compute.getInputs().size());
  promoted.newInputLocs.reserve(compute.getInputs().size());
  bool needsRewrite = false;
  for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
    if (!isWeightLikeComputeOperand(input))
      goto keep_input;
    if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
      goto keep_input;
    // Promote: the input value is appended after the existing weights.
    promoted.promoteInput[inputIdx] = true;
    promoted.newWeights.push_back(input);
    needsRewrite = true;
    continue;
  keep_input:
    // Keep: record the value plus the type/location of its new block argument.
    promoted.newInputs.push_back(input);
    promoted.newInputTypes.push_back(input.getType());
    promoted.newInputLocs.push_back(input.getLoc());
  }
  if (!needsRewrite)
    return failure();
  return promoted;
}

/// Populates `mapper` for the old input block arguments of `compute` while its
/// body is being rebuilt:
///  - a kept input's old argument maps to getNewInputArg(k), where k counts the
///    kept inputs in order;
///  - a promoted input's old argument maps to the value returned by
///    materializeWeightLikeValueInBlock, which is given `bodyRewriter`.
/// On any failure a reason is reported via `rewriter.notifyMatchFailure` and
/// failure() is returned.
/// NOTE(review): on failure, mappings added for earlier inputs (and whatever
/// earlier materialization calls emitted) are not undone here; the caller has
/// to discard the partially built body.
template static LogicalResult mapPromotedInputArguments(ComputeOpTy compute, const PromotedOperands& promoted, IRRewriter& bodyRewriter, IRMapping& mapper, std::function(size_t)> getNewInputArg, PatternRewriter& rewriter) {
  size_t newInputIdx = 0;
  for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
    auto oldArg = compute.getInputArgument(oldInputIdx);
    if (!oldArg)
      return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite");
    if (!promoted.promoteInput[oldInputIdx]) {
      // Kept input: forward to the next input argument of the new block.
      auto newInputArg = getNewInputArg(newInputIdx++);
      if (!newInputArg)
        return rewriter.notifyMatchFailure(compute, "missing rewritten input block argument");
      mapper.map(*oldArg, *newInputArg);
      continue;
    }
    // Promoted input: rebuild its value inside the new body.
    auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
    if (failed(clonedValue))
      return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
    mapper.map(*oldArg, *clonedValue);
  }
  return success();
}

// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override { auto promoted = computePromotedOperands(compute); if (failed(promoted)) return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote"); Block& oldBlock = compute.getBody().front(); rewriter.setInsertionPointAfter(compute); auto newCompute = spatial::SpatCompute::create( rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs); SmallVector newBlockArgTypes; SmallVector newBlockArgLocs; for (Value weight : promoted->newWeights) { newBlockArgTypes.push_back(weight.getType()); newBlockArgLocs.push_back(weight.getLoc()); } llvm::append_range(newBlockArgTypes, promoted->newInputTypes); llvm::append_range(newBlockArgLocs, promoted->newInputLocs); auto* newBlock = rewriter.createBlock( &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs); newCompute.getProperties().setOperandSegmentSizes( {static_cast(promoted->newWeights.size()), static_cast(promoted->newInputs.size())}); rewriter.setInsertionPointToStart(newBlock); IRRewriter bodyRewriter(rewriter.getContext()); bodyRewriter.setInsertionPointToStart(newBlock); IRMapping mapper; for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) { auto oldWeightArg = compute.getWeightArgument(weightIndex); auto newWeightArg = newCompute.getWeightArgument(weightIndex); if (!oldWeightArg || !newWeightArg) return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite"); mapper.map(*oldWeightArg, *newWeightArg); } if (failed(mapPromotedInputArguments( compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter))) return failure(); for (Operation& op : oldBlock.without_terminator()) rewriter.clone(op, mapper); auto oldYield 
= cast(oldBlock.getTerminator()); SmallVector newYieldOperands; newYieldOperands.reserve(oldYield.getOutputs().size()); for (Value operand : oldYield.getOutputs()) { auto mapped = mapper.lookupOrNull(operand); newYieldOperands.push_back(mapped ? cast(mapped) : operand); } spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands); rewriter.replaceOp(compute, newCompute.getResults()); return success(); } }; // Promotes foldable batch helper chains to weights while preserving compact compute_batch IR. struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override { auto promoted = computePromotedOperands(compute); if (failed(promoted)) return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote"); Block& oldBlock = compute.getBody().front(); rewriter.setInsertionPointAfter(compute); auto newCompute = spatial::SpatComputeBatch::create(rewriter, compute.getLoc(), compute.getResultTypes(), rewriter.getI32IntegerAttr(static_cast(compute.getLaneCount())), promoted->newWeights, promoted->newInputs); auto laneArg = compute.getLaneArgument(); if (!laneArg) return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument"); SmallVector newBlockArgTypes; SmallVector newBlockArgLocs; newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size() + compute.getNumResults()); newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults()); newBlockArgTypes.push_back(laneArg->getType()); newBlockArgLocs.push_back(laneArg->getLoc()); for (Value weight : promoted->newWeights) { newBlockArgTypes.push_back(weight.getType()); newBlockArgLocs.push_back(weight.getLoc()); } llvm::append_range(newBlockArgTypes, promoted->newInputTypes); llvm::append_range(newBlockArgLocs, promoted->newInputLocs); 
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) { auto outputArg = compute.getOutputArgument(resultIndex); if (!outputArg) return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument"); newBlockArgTypes.push_back(resultType); newBlockArgLocs.push_back(outputArg->getLoc()); } auto* newBlock = rewriter.createBlock( &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs); newCompute.getProperties().setOperandSegmentSizes( {static_cast(promoted->newWeights.size()), static_cast(promoted->newInputs.size())}); rewriter.setInsertionPointToStart(newBlock); IRRewriter bodyRewriter(rewriter.getContext()); bodyRewriter.setInsertionPointToStart(newBlock); IRMapping mapper; auto newLaneArg = newCompute.getLaneArgument(); if (!newLaneArg) return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument"); mapper.map(*laneArg, *newLaneArg); for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) { auto oldWeightArg = compute.getWeightArgument(weightIndex); auto newWeightArg = newCompute.getWeightArgument(weightIndex); if (!oldWeightArg || !newWeightArg) return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite"); mapper.map(*oldWeightArg, *newWeightArg); } if (failed(mapPromotedInputArguments( compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter))) return failure(); for (auto resultIndex : llvm::seq(0, compute.getNumResults())) { auto outputArg = compute.getOutputArgument(resultIndex); if (!outputArg) return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite"); mapper.map(*outputArg, newBlock->getArgument(1 + promoted->newWeights.size() + promoted->newInputs.size() + resultIndex)); } for (Operation& op : oldBlock) rewriter.clone(op, mapper); rewriter.replaceOp(compute, 
newCompute.getResults()); return success(); } }; } // namespace void populateWeightPromotionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add(ctx); } void annotateWeightsConstants(func::FuncOp funcOp) { funcOp.walk([&](arith::ConstantOp constantOp) { if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult())) markWeightAlways(constantOp); }); } bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); } bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); } } // namespace onnx_mlir