slightly better bufferization

minor fixes
This commit is contained in:
NiccoloN
2026-05-07 17:53:47 +02:00
parent f2fe147961
commit f6c8cc4aa5
19 changed files with 150 additions and 141 deletions
@@ -381,7 +381,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
vmmOutputs.push_back(
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
@@ -527,7 +527,7 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
rewriter.setInsertionPointToEnd(body);
Value vmmResult = spatial::SpatWeightedVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
Value laneResult = vmmResult;
if (sharedBias)
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
+1 -1
View File
@@ -95,7 +95,7 @@ bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
return false;
}
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) {
assert("Only support operations with a single result" && operation->getNumResults() == 1);
mlir::Value result = operation->getResult(0);
auto resultType = result.getType();
+1 -1
View File
@@ -41,7 +41,7 @@ mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
inline mlir::tensor::EmptyOp
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
@@ -16,13 +16,13 @@ def onnxToPimTranspose : Pat<
>;
def spatToPimVMM : Pat<
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
(SpatVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimMVM : Pat<
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
(SpatMVMOp:$srcOpRes $weightIndex, $vector),
(PimMVMOp $weightIndex, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
@@ -252,25 +252,6 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
rewriter.replaceOp(concatOp, concatenated);
}
/// Rewrites each spatial::SpatWeightedVMMOp nested inside a pim::PimCoreOp or
/// pim::PimCoreBatchOp into a pim::PimVMMOp.
///
/// The replacement keeps the original result type, weight index and input,
/// and is given a freshly created empty tensor (built from the result's
/// shaped type) as its output buffer.
///
/// \param funcOp   Function whose body is walked for candidate ops.
/// \param rewriter Rewriter used to create the buffer and replace the ops.
static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Phase 1: gather the matching ops; rewriting happens only after the walk.
  SmallVector<spatial::SpatWeightedVMMOp> pendingOps;
  funcOp.walk([&](spatial::SpatWeightedVMMOp candidate) {
    if (!candidate->getParentOfType<pim::PimCoreOp>() && !candidate->getParentOfType<pim::PimCoreBatchOp>())
      return;
    pendingOps.push_back(candidate);
  });

  // Phase 2: replace each gathered op in place.
  for (auto vmmOp : pendingOps) {
    rewriter.setInsertionPoint(vmmOp);

    auto resultType = vmmOp.getOutput().getType();
    auto emptyTensor = createEmptyTensorFromShaped(rewriter, vmmOp.getLoc(), cast<ShapedType>(resultType));
    Value destination = emptyTensor.getResult();

    auto weightIndexAttr = rewriter.getI32IntegerAttr(vmmOp.getWeightIndex());
    rewriter.replaceOpWithNewOp<pim::PimVMMOp>(vmmOp, resultType, weightIndexAttr, vmmOp.getInput(), destination);
  }
}
static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatMapOp> mapOps;
funcOp.walk([&](spatial::SpatMapOp mapOp) {
@@ -736,7 +717,7 @@ void SpatialToPimPass::runOnOperation() {
SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) {
if (failed(applyPartialConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
@@ -745,15 +726,13 @@ void SpatialToPimPass::runOnOperation() {
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) {
if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) {
if (failed(applyPartialConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
}
}
lowerRemainingSpatialMathOps(funcOp, rewriter);
RewritePatternSet channelPatterns(ctx);
populateWithGenerated(channelPatterns);
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {