// Files
// Raptor/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp
// T
// NiccoloN 2d5b03c08f
// Validate Operations / validate-operations (push) Has been cancelled
// automatic code reformat
// 2026-05-29 19:21:37 +02:00
//
// 289 lines
// 12 KiB
// C++

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
// Returns true when `op` is one of the helper operations accepted as a user of a
// promotable input block argument: a tensor slice/reshape or a linalg transpose.
static bool isWeightMaterializationHelperUser(Operation* op) {
  const bool isTensorHelper = isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp>(op);
  return isTensorHelper || isa<linalg::TransposeOp>(op);
}
// A block argument can be promoted only if it is actually used and every one of its
// users is a weight-materialization helper operation.
static bool canPromoteInputBlockArgument(BlockArgument arg) {
  if (arg.use_empty())
    return false;
  return llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
}
// Optional-accepting overload: a missing block argument is never promotable;
// otherwise the decision is delegated to the BlockArgument overload.
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
  if (!arg)
    return false;
  return canPromoteInputBlockArgument(*arg);
}
// Returns true when `value` is produced directly by an arith.constant or an ONNX
// constant operation. Values without a defining op (e.g. block arguments) yield false.
static bool isDirectConstantValue(Value value) {
  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return false;
  return isa<arith::ConstantOp, ONNXConstantOp>(definingOp);
}
// Result of splitting a compute op's inputs into those promoted to weights and
// those kept as inputs. Filled in by computePromotedOperands.
struct PromotedOperands {
// One flag per original input; true when that input is promoted to a weight.
SmallVector<bool> promoteInput;
// The original weights followed by every promoted input, in input order.
SmallVector<Value> newWeights;
// The original inputs that are NOT promoted, in their original relative order.
SmallVector<Value> newInputs;
// Type of each entry in newInputs (same order).
SmallVector<Type> newInputTypes;
// Location of each entry in newInputs (same order).
SmallVector<Location> newInputLocs;
};
// Returns true as soon as one input of `compute` qualifies for promotion to a weight.
//
// An input qualifies when it is weight-like and either is not a direct constant, or
// is a direct constant whose corresponding block argument can be promoted.
template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
  for (auto [position, operand] : llvm::enumerate(compute.getInputs())) {
    const bool promotable =
        isWeightLikeComputeOperand(operand)
        && (!isDirectConstantValue(operand) || canPromoteInputBlockArgument(compute.getInputArgument(position)));
    if (promotable)
      return true;
  }
  return false;
}
// Splits the inputs of `compute` into inputs that are promoted to weights and inputs
// that stay inputs.
//
// An input is promoted when it is weight-like and either is not a direct constant, or
// is a direct constant whose block argument can be promoted (same rule as
// hasPromotableWeightLikeInputs).
//
// Returns failure() when no input is promoted, i.e. when there is nothing to rewrite.
// Otherwise returns a PromotedOperands whose `newWeights` holds the original weights
// followed by the promoted inputs, and whose `newInputs`/`newInputTypes`/`newInputLocs`
// describe the inputs that were kept.
template <typename ComputeOpTy>
static FailureOr<PromotedOperands> computePromotedOperands(ComputeOpTy compute) {
  auto inputs = compute.getInputs();
  auto weights = compute.getWeights();
  const size_t numInputs = inputs.size();
  const size_t numWeights = weights.size();

  PromotedOperands promoted;
  promoted.promoteInput.assign(numInputs, false);
  // Reserve for the worst case (every input promoted) before appending, so the
  // weight list is allocated at most once.
  promoted.newWeights.reserve(numWeights + numInputs);
  promoted.newWeights.append(weights.begin(), weights.end());
  promoted.newInputs.reserve(numInputs);
  promoted.newInputTypes.reserve(numInputs);
  promoted.newInputLocs.reserve(numInputs);

  for (auto [inputIdx, input] : llvm::enumerate(inputs)) {
    const bool promote =
        isWeightLikeComputeOperand(input)
        && (!isDirectConstantValue(input) || canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)));
    if (promote) {
      promoted.promoteInput[inputIdx] = true;
      promoted.newWeights.push_back(input);
      continue;
    }
    promoted.newInputs.push_back(input);
    promoted.newInputTypes.push_back(input.getType());
    promoted.newInputLocs.push_back(input.getLoc());
  }

  // The weight list only grows when an input was promoted.
  if (promoted.newWeights.size() == numWeights)
    return failure();
  return promoted;
}
// Fills `mapper` with a replacement for every input block argument of `compute`.
//
// For each original input, in order:
//   - if the input is kept (promoted.promoteInput is false), its old block argument is
//     mapped to the next rewritten input block argument obtained from `getNewInputArg`
//     (called with consecutive indices starting at 0);
//   - if the input is promoted, its old block argument is mapped to the value returned
//     by materializeWeightLikeValueInBlock(input, bodyRewriter, mapper).
//
// `getNewInputArg` is a non-owning callback that is only invoked during this call.
//
// Returns failure (after reporting a match failure on `rewriter`) when an old or
// rewritten block argument is missing or when materialization fails. Entries already
// added to `mapper`, and anything materialized through `bodyRewriter`, are not undone
// here; cleaning up is the caller's responsibility.
template <typename ComputeOpTy>
static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
                                               const PromotedOperands& promoted,
                                               IRRewriter& bodyRewriter,
                                               IRMapping& mapper,
                                               llvm::function_ref<std::optional<BlockArgument>(size_t)> getNewInputArg,
                                               PatternRewriter& rewriter) {
  size_t newInputIdx = 0;
  for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
    auto oldArg = compute.getInputArgument(oldInputIdx);
    if (!oldArg)
      return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite");

    if (!promoted.promoteInput[oldInputIdx]) {
      // Kept inputs keep their relative order, so they are consumed sequentially.
      auto newInputArg = getNewInputArg(newInputIdx++);
      if (!newInputArg)
        return rewriter.notifyMatchFailure(compute, "missing rewritten input block argument");
      mapper.map(*oldArg, *newInputArg);
      continue;
    }

    auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
    if (failed(clonedValue))
      return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
    mapper.map(*oldArg, *clonedValue);
  }
  return success();
}
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
//
// The pattern replaces a spatial::SpatCompute with a new one whose weight list is the
// old weights followed by the promoted inputs, and whose input list contains only the
// inputs that were kept. The body is cloned into the new op with block arguments
// remapped accordingly.
//
// A pattern must not leave modified IR behind when it reports failure. Checks that only
// depend on the original op therefore run before anything is created, and the partially
// built replacement is erased if a later step fails.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
  using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
    auto promoted = computePromotedOperands(compute);
    if (failed(promoted))
      return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");

    // Validate the original op's block arguments up front, while the IR is untouched.
    const size_t numOldWeights = compute.getWeights().size();
    for (size_t weightIndex : llvm::seq<size_t>(0, numOldWeights))
      if (!compute.getWeightArgument(weightIndex))
        return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
    const size_t numOldInputs = compute.getInputs().size();
    for (size_t inputIndex : llvm::seq<size_t>(0, numOldInputs))
      if (!compute.getInputArgument(inputIndex))
        return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite");

    // New block signature: all (old + promoted) weights first, then the kept inputs.
    SmallVector<Type> newBlockArgTypes;
    SmallVector<Location> newBlockArgLocs;
    newBlockArgTypes.reserve(promoted->newWeights.size() + promoted->newInputTypes.size());
    newBlockArgLocs.reserve(promoted->newWeights.size() + promoted->newInputLocs.size());
    for (Value weight : promoted->newWeights) {
      newBlockArgTypes.push_back(weight.getType());
      newBlockArgLocs.push_back(weight.getLoc());
    }
    llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
    llvm::append_range(newBlockArgLocs, promoted->newInputLocs);

    Block& oldBlock = compute.getBody().front();
    rewriter.setInsertionPointAfter(compute);
    auto newCompute = spatial::SpatCompute::create(
        rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
    auto* newBlock = rewriter.createBlock(
        &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
    newCompute.getProperties().setOperandSegmentSizes(
        {static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
    rewriter.setInsertionPointToStart(newBlock);

    // Removes the partially built replacement so a failed rewrite leaves no stray op.
    // The new op has no users yet, so erasing it is safe.
    auto discardNewCompute = [&] {
      rewriter.setInsertionPointAfter(compute);
      rewriter.eraseOp(newCompute);
    };

    // NOTE(review): this rewriter is built from the context only, so it has no listener
    // and ops created through it are not reported to the pattern driver — confirm that
    // is intended.
    IRRewriter bodyRewriter(rewriter.getContext());
    bodyRewriter.setInsertionPointToStart(newBlock);

    IRMapping mapper;
    for (size_t weightIndex : llvm::seq<size_t>(0, numOldWeights)) {
      auto oldWeightArg = compute.getWeightArgument(weightIndex); // validated above
      auto newWeightArg = newCompute.getWeightArgument(weightIndex);
      if (!newWeightArg) {
        discardNewCompute();
        return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
      }
      mapper.map(*oldWeightArg, *newWeightArg);
    }

    if (failed(mapPromotedInputArguments(
            compute,
            *promoted,
            bodyRewriter,
            mapper,
            [&](size_t index) { return newCompute.getInputArgument(index); },
            rewriter))) {
      discardNewCompute();
      return failure();
    }

    for (Operation& op : oldBlock.without_terminator())
      rewriter.clone(op, mapper);

    // Rebuild the terminator with remapped operands; operands that have no mapping
    // are forwarded unchanged.
    auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
    SmallVector<Value> newYieldOperands;
    newYieldOperands.reserve(oldYield.getOutputs().size());
    for (Value operand : oldYield.getOutputs())
      newYieldOperands.push_back(mapper.lookupOrDefault(operand));
    spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);

    rewriter.replaceOp(compute, newCompute.getResults());
    return success();
  }
};
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
//
// The pattern replaces a spatial::SpatComputeBatch with a new one (same lane count and
// result types) whose weight list is the old weights followed by the promoted inputs.
// The new block signature is: lane argument, weights, kept inputs, one argument per
// result. The whole old body, terminator included, is cloned with remapped arguments.
//
// A pattern must not leave modified IR behind when it reports failure. Checks that only
// depend on the original op therefore run before anything is created, and the partially
// built replacement is erased if a later step fails.
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
  using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
    auto promoted = computePromotedOperands(compute);
    if (failed(promoted))
      return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");

    // Validate the original op's block arguments up front, while the IR is untouched.
    auto laneArg = compute.getLaneArgument();
    if (!laneArg)
      return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
    const size_t numOldWeights = compute.getWeights().size();
    for (size_t weightIndex : llvm::seq<size_t>(0, numOldWeights))
      if (!compute.getWeightArgument(weightIndex))
        return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
    const size_t numOldInputs = compute.getInputs().size();
    for (size_t inputIndex : llvm::seq<size_t>(0, numOldInputs))
      if (!compute.getInputArgument(inputIndex))
        return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite");

    const size_t numNewWeights = promoted->newWeights.size();
    const size_t numNewInputs = promoted->newInputs.size();
    const size_t numResults = compute.getNumResults();

    // New block signature: lane, weights (old + promoted), kept inputs, outputs.
    SmallVector<Type> newBlockArgTypes;
    SmallVector<Location> newBlockArgLocs;
    newBlockArgTypes.reserve(1 + numNewWeights + numNewInputs + numResults);
    newBlockArgLocs.reserve(1 + numNewWeights + numNewInputs + numResults);
    newBlockArgTypes.push_back(laneArg->getType());
    newBlockArgLocs.push_back(laneArg->getLoc());
    for (Value weight : promoted->newWeights) {
      newBlockArgTypes.push_back(weight.getType());
      newBlockArgLocs.push_back(weight.getLoc());
    }
    llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
    llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
    for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
      auto outputArg = compute.getOutputArgument(resultIndex);
      if (!outputArg)
        return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
      newBlockArgTypes.push_back(resultType);
      newBlockArgLocs.push_back(outputArg->getLoc());
    }

    Block& oldBlock = compute.getBody().front();
    rewriter.setInsertionPointAfter(compute);
    auto newCompute =
        spatial::SpatComputeBatch::create(rewriter,
                                          compute.getLoc(),
                                          compute.getResultTypes(),
                                          rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
                                          promoted->newWeights,
                                          promoted->newInputs);
    auto* newBlock = rewriter.createBlock(
        &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
    newCompute.getProperties().setOperandSegmentSizes(
        {static_cast<int>(numNewWeights), static_cast<int>(numNewInputs)});
    rewriter.setInsertionPointToStart(newBlock);

    // Removes the partially built replacement so a failed rewrite leaves no stray op.
    // The new op has no users yet, so erasing it is safe.
    auto discardNewCompute = [&] {
      rewriter.setInsertionPointAfter(compute);
      rewriter.eraseOp(newCompute);
    };

    // NOTE(review): this rewriter is built from the context only, so it has no listener
    // and ops created through it are not reported to the pattern driver — confirm that
    // is intended.
    IRRewriter bodyRewriter(rewriter.getContext());
    bodyRewriter.setInsertionPointToStart(newBlock);

    IRMapping mapper;
    auto newLaneArg = newCompute.getLaneArgument();
    if (!newLaneArg) {
      discardNewCompute();
      return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
    }
    mapper.map(*laneArg, *newLaneArg);

    for (size_t weightIndex : llvm::seq<size_t>(0, numOldWeights)) {
      auto oldWeightArg = compute.getWeightArgument(weightIndex); // validated above
      auto newWeightArg = newCompute.getWeightArgument(weightIndex);
      if (!newWeightArg) {
        discardNewCompute();
        return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
      }
      mapper.map(*oldWeightArg, *newWeightArg);
    }

    if (failed(mapPromotedInputArguments(
            compute,
            *promoted,
            bodyRewriter,
            mapper,
            [&](size_t index) { return newCompute.getInputArgument(index); },
            rewriter))) {
      discardNewCompute();
      return failure();
    }

    // Output arguments sit after the lane, weight and kept-input arguments.
    const size_t firstOutputArgIndex = 1 + numNewWeights + numNewInputs;
    for (size_t resultIndex : llvm::seq<size_t>(0, numResults)) {
      auto outputArg = compute.getOutputArgument(resultIndex); // validated above
      mapper.map(*outputArg, newBlock->getArgument(firstOutputArgIndex + resultIndex));
    }

    // Clone the entire body, terminator included.
    for (Operation& op : oldBlock)
      rewriter.clone(op, mapper);

    rewriter.replaceOp(compute, newCompute.getResults());
    return success();
  }
};
} // namespace
// Registers the two weight-promotion rewrite patterns (plain compute first, then
// compute_batch) into `patterns`.
void populateWeightPromotionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<PromoteWeightLikeComputeInputsPattern>(ctx);
  patterns.add<PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
// Walks every arith.constant in `funcOp` and calls markWeightAlways on those for which
// hasOnlySpatialMvmVmmWeightUses reports true on the constant's result.
void annotateWeightsConstants(func::FuncOp funcOp) {
  funcOp.walk([](arith::ConstantOp constantOp) {
    if (!hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
      return;
    markWeightAlways(constantOp);
  });
}
// Reports whether `computeOp` has at least one input that qualifies for promotion to a
// weight.
bool requiresPostRewrite(spatial::SpatCompute computeOp) {
  const bool hasPromotableInput = hasPromotableWeightLikeInputs<spatial::SpatCompute>(computeOp);
  return hasPromotableInput;
}
// Reports whether the batch compute `computeOp` has at least one input that qualifies
// for promotion to a weight.
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) {
  const bool hasPromotableInput = hasPromotableWeightLikeInputs<spatial::SpatComputeBatch>(computeOp);
  return hasPromotableInput;
}
} // namespace onnx_mlir