add better createSpatCompute helper

This commit is contained in:
NiccoloN
2026-03-30 16:14:26 +02:00
parent 39830be888
commit 3625edc80a
5 changed files with 259 additions and 239 deletions

View File

@@ -155,18 +155,10 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
gemvOps.push_back(gemvOp.getY());
}
auto concatComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
auto* concatBlock = new Block();
for (auto gemvOp : gemvOps)
concatBlock->addArgument(gemvOp.getType(), loc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments();
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, blockArgs);
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
});
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
@@ -289,25 +281,17 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp =
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
createSpatCompute(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
vmmOutputs.push_back(
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
auto* computeBlock = new Block();
for (auto aHSlice : aHSlices[coreId])
computeBlock->addArgument(aHSlice.getType(), gemmLoc);
computeOp.getBody().push_back(computeBlock);
rewriter.setInsertionPointToStart(computeBlock);
auto computeArgs = computeBlock->getArguments();
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(computeArgs.size());
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
vmmOutputs.push_back(
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
rewriter.setInsertionPointAfter(computeOp);
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
});
partialResults.push_back(computeOp.getResult(0));
}
@@ -318,34 +302,20 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
}
auto reduceComputeOp =
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
auto* reduceBlock = new Block();
for (auto partialResult : partialResults)
reduceBlock->addArgument(partialResult.getType(), gemmLoc);
reduceComputeOp.getBody().push_back(reduceBlock);
rewriter.setInsertionPointToStart(reduceBlock);
auto blockArgs = reduceBlock->getArguments();
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
rewriter.setInsertionPointAfter(reduceComputeOp);
createSpatCompute(rewriter, gemmLoc, currOutHSliceType, {}, partialResults, [&](ValueRange blockArgs) {
SmallVector<Value> values(blockArgs.begin(), blockArgs.end());
Value outHSlice = sumTensors(values, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
});
outHSlices.push_back(reduceComputeOp.getResult(0));
}
auto concatComputeOp =
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
auto* concatBlock = new Block();
for (auto outHSlice : outHSlices)
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments();
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
});
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();