fix much stuff

This commit is contained in:
NiccoloN
2026-05-22 18:53:38 +02:00
parent 8337a11ce9
commit 2c1da813b5
18 changed files with 502 additions and 191 deletions
@@ -422,12 +422,17 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlices[coreId].size());
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size())) {
auto weightArg = computeOp.getWeightArgument(aHSliceId);
auto inputArg = computeOp.getInputArgument(aHSliceId);
if (!weightArg || !inputArg)
return failure();
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter,
gemmLoc,
currOutHSliceType,
computeOp.getWeightArgument(aHSliceId),
computeOp.getInputArgument(aHSliceId)));
*weightArg,
*inputArg));
}
if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
@@ -561,29 +566,31 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(body);
Value lane = batchOp.getLaneArgument();
Value weight = batchOp.getWeightArgument(0);
Value packedInput = batchOp.getInputArgument(0);
Value packedOutput = batchOp.getOutputArgument(0);
auto lane = batchOp.getLaneArgument();
auto weight = batchOp.getWeightArgument(0);
auto packedInput = batchOp.getInputArgument(0);
auto packedOutput = batchOp.getOutputArgument(0);
if (!lane || !weight || !packedInput || !packedOutput)
return failure();
SmallVector<OpFoldResult> inputOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> inputOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value row =
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides)
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, *packedInput, inputOffsets, inputSizes, unitStrides)
.getResult();
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult();
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, *weight, row).getResult();
Value laneResult = vmmResult;
if (sharedBias)
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
tensor::ParallelInsertSliceOp::create(
rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, unitStrides);
rewriter, loc, laneResult, *packedOutput, outputOffsets, outputSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp);
rewriter.replaceOp(gemmOp, batchOp.getResults());
@@ -27,13 +27,16 @@ static bool canPromoteInputBlockArgument(BlockArgument arg) {
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
}
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
return arg && canPromoteInputBlockArgument(*arg);
}
static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
}
template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
Block& block = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
continue;
@@ -104,20 +107,30 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
auto oldWeightArg = compute.getWeightArgument(weightIndex);
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
if (!oldWeightArg || !newWeightArg)
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg);
}
size_t newInputIdx = 0;
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
auto oldArg = compute.getInputArgument(oldInputIdx);
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing compute input block argument during rewrite");
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute input block argument");
mapper.map(*oldArg, *newInputArg);
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(oldArg, *clonedValue);
mapper.map(*oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator())
@@ -184,12 +197,15 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
newWeights,
newInputs);
auto laneArg = compute.getLaneArgument();
if (!laneArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
newBlockArgTypes.push_back(compute.getLaneArgument().getType());
newBlockArgLocs.push_back(compute.getLaneArgument().getLoc());
newBlockArgTypes.push_back(laneArg->getType());
newBlockArgLocs.push_back(laneArg->getLoc());
for (Value weight : newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
@@ -197,8 +213,11 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs);
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
newBlockArgTypes.push_back(resultType);
newBlockArgLocs.push_back(compute.getOutputArgument(resultIndex).getLoc());
newBlockArgLocs.push_back(outputArg->getLoc());
}
auto* newBlock = rewriter.createBlock(
@@ -211,25 +230,41 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
mapper.map(compute.getLaneArgument(), newCompute.getLaneArgument());
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
auto newLaneArg = newCompute.getLaneArgument();
if (!newLaneArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
mapper.map(*laneArg, *newLaneArg);
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
auto oldWeightArg = compute.getWeightArgument(weightIndex);
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
if (!oldWeightArg || !newWeightArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg);
}
size_t newInputIdx = 0;
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
auto oldArg = compute.getInputArgument(oldInputIdx);
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch input block argument during rewrite");
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch input block argument");
mapper.map(*oldArg, *newInputArg);
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
mapper.map(oldArg, *clonedValue);
mapper.map(*oldArg, *clonedValue);
}
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
}
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
mapper.map(compute.getOutputArgument(resultIndex),
newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
for (Operation& op : oldBlock)
rewriter.clone(op, mapper);