This commit is contained in:
ilgeco
2026-06-24 15:52:07 +02:00
parent 2b4115699a
commit 62dd40ee89
47 changed files with 7993 additions and 1100 deletions
File diff suppressed because it is too large Load Diff
@@ -16,12 +16,9 @@ struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
Location loc = reluOp.getLoc();
Type resultType = reluOp.getResult().getType();
constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
});
rewriter.replaceOp(reluOp, computeOp);
auto reluPlan = spatial::SpatReluPlanOp::create(
rewriter, loc, resultType, adaptor.getX(), rewriter.getStringAttr("nchw"));
rewriter.replaceOp(reluOp, reluPlan.getResult());
return success();
}
};
@@ -118,17 +118,17 @@ static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
}
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatGraphCompute> {
using OpRewritePattern<spatial::SpatGraphCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(spatial::SpatGraphCompute compute, PatternRewriter& rewriter) const override {
auto promoted = computePromotedOperands(compute);
if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute);
auto newCompute = spatial::SpatCompute::create(
auto newCompute = spatial::SpatGraphCompute::create(
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
@@ -182,10 +182,10 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
};
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatGraphComputeBatch> {
using OpRewritePattern<spatial::SpatGraphComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
LogicalResult matchAndRewrite(spatial::SpatGraphComputeBatch compute, PatternRewriter& rewriter) const override {
auto promoted = computePromotedOperands(compute);
if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
@@ -197,7 +197,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count");
if (failed(laneCountAttr))
return failure();
auto newCompute = spatial::SpatComputeBatch::create(
auto newCompute = spatial::SpatGraphComputeBatch::create(
rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs);
auto laneArg = compute.getLaneArgument();
if (!laneArg)
@@ -281,8 +281,8 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
});
}
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatGraphCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatGraphComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
} // namespace onnx_mlir