fix much stuff

This commit is contained in:
NiccoloN
2026-05-22 18:53:38 +02:00
parent 8337a11ce9
commit 2c1da813b5
18 changed files with 502 additions and 191 deletions
@@ -422,12 +422,17 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlices[coreId].size());
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size())) {
auto weightArg = computeOp.getWeightArgument(aHSliceId);
auto inputArg = computeOp.getInputArgument(aHSliceId);
if (!weightArg || !inputArg)
return failure();
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter,
gemmLoc,
currOutHSliceType,
computeOp.getWeightArgument(aHSliceId),
computeOp.getInputArgument(aHSliceId)));
*weightArg,
*inputArg));
}
if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
@@ -561,29 +566,31 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(body);
Value lane = batchOp.getLaneArgument();
Value weight = batchOp.getWeightArgument(0);
Value packedInput = batchOp.getInputArgument(0);
Value packedOutput = batchOp.getOutputArgument(0);
auto lane = batchOp.getLaneArgument();
auto weight = batchOp.getWeightArgument(0);
auto packedInput = batchOp.getInputArgument(0);
auto packedOutput = batchOp.getOutputArgument(0);
if (!lane || !weight || !packedInput || !packedOutput)
return failure();
SmallVector<OpFoldResult> inputOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> inputOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value row =
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides)
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, *packedInput, inputOffsets, inputSizes, unitStrides)
.getResult();
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult();
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, *weight, row).getResult();
Value laneResult = vmmResult;
if (sharedBias)
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
tensor::ParallelInsertSliceOp::create(
rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, unitStrides);
rewriter, loc, laneResult, *packedOutput, outputOffsets, outputSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp);
rewriter.replaceOp(gemmOp, batchOp.getResults());