Refactor + ReduceMean batched
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 15:57:13 +02:00
parent 832bd7f1f7
commit 819d8af0f7
27 changed files with 929 additions and 568 deletions
@@ -36,6 +36,14 @@ static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
}
struct PromotedOperands {
SmallVector<bool> promoteInput;
SmallVector<Value> newWeights;
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
};
template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
@@ -48,60 +56,91 @@ static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
return false;
}
template <typename ComputeOpTy>
static FailureOr<PromotedOperands> computePromotedOperands(ComputeOpTy compute) {
PromotedOperands promoted;
promoted.promoteInput.assign(compute.getInputs().size(), false);
promoted.newWeights.append(compute.getWeights().begin(), compute.getWeights().end());
promoted.newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
promoted.newInputs.reserve(compute.getInputs().size());
promoted.newInputTypes.reserve(compute.getInputs().size());
promoted.newInputLocs.reserve(compute.getInputs().size());
bool needsRewrite = false;
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
goto keep_input;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
goto keep_input;
promoted.promoteInput[inputIdx] = true;
promoted.newWeights.push_back(input);
needsRewrite = true;
continue;
keep_input:
promoted.newInputs.push_back(input);
promoted.newInputTypes.push_back(input.getType());
promoted.newInputLocs.push_back(input.getLoc());
}
if (!needsRewrite)
return failure();
return promoted;
}
template <typename ComputeOpTy>
static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
const PromotedOperands& promoted,
IRRewriter& bodyRewriter,
IRMapping& mapper,
std::function<std::optional<BlockArgument>(size_t)> getNewInputArg,
PatternRewriter& rewriter) {
size_t newInputIdx = 0;
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
auto oldArg = compute.getInputArgument(oldInputIdx);
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite");
if (!promoted.promoteInput[oldInputIdx]) {
auto newInputArg = getNewInputArg(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten input block argument");
mapper.map(*oldArg, *newInputArg);
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(*oldArg, *clonedValue);
}
return success();
}
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
auto promoted = computePromotedOperands(compute);
if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
for (Value weight : newWeights) {
for (Value weight : promoted->newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
}
llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs);
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
@@ -115,24 +154,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg);
}
size_t newInputIdx = 0;
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
auto oldArg = compute.getInputArgument(oldInputIdx);
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing compute input block argument during rewrite");
if (!promoteInput[oldInputIdx]) {
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute input block argument");
mapper.map(*oldArg, *newInputArg);
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(*oldArg, *clonedValue);
}
if (failed(mapPromotedInputArguments(
compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter)))
return failure();
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
@@ -156,63 +180,35 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
auto promoted = computePromotedOperands(compute);
if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatComputeBatch::create(rewriter,
compute.getLoc(),
compute.getResultTypes(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
newWeights,
newInputs);
promoted->newWeights,
promoted->newInputs);
auto laneArg = compute.getLaneArgument();
if (!laneArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size() + compute.getNumResults());
newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults());
newBlockArgTypes.push_back(laneArg->getType());
newBlockArgLocs.push_back(laneArg->getLoc());
for (Value weight : newWeights) {
for (Value weight : promoted->newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
}
llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs);
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg)
@@ -224,7 +220,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
@@ -242,29 +238,15 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg);
}
size_t newInputIdx = 0;
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
auto oldArg = compute.getInputArgument(oldInputIdx);
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch input block argument during rewrite");
if (!promoteInput[oldInputIdx]) {
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch input block argument");
mapper.map(*oldArg, *newInputArg);
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
mapper.map(*oldArg, *clonedValue);
}
if (failed(mapPromotedInputArguments(
compute, *promoted, bodyRewriter, mapper, [&](size_t index) { return newCompute.getInputArgument(index); }, rewriter)))
return failure();
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
mapper.map(*outputArg,
newBlock->getArgument(1 + promoted->newWeights.size() + promoted->newInputs.size() + resultIndex));
}
for (Operation& op : oldBlock)