sightly better bufferization

minor fixes
This commit is contained in:
NiccoloN
2026-05-07 17:53:47 +02:00
parent f2fe147961
commit f6c8cc4aa5
19 changed files with 150 additions and 141 deletions
@@ -252,25 +252,6 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
rewriter.replaceOp(concatOp, concatenated);
}
static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatWeightedVMMOp> wvmmOps;
funcOp.walk([&](spatial::SpatWeightedVMMOp wvmmOp) {
if (wvmmOp->getParentOfType<pim::PimCoreOp>() || wvmmOp->getParentOfType<pim::PimCoreBatchOp>())
wvmmOps.push_back(wvmmOp);
});
for (auto wvmmOp : wvmmOps) {
rewriter.setInsertionPoint(wvmmOp);
auto outputType = cast<ShapedType>(wvmmOp.getOutput().getType());
Value outputBuffer = createEmptyTensorFromShaped(rewriter, wvmmOp.getLoc(), outputType).getResult();
rewriter.replaceOpWithNewOp<pim::PimVMMOp>(wvmmOp,
wvmmOp.getOutput().getType(),
rewriter.getI32IntegerAttr(wvmmOp.getWeightIndex()),
wvmmOp.getInput(),
outputBuffer);
}
}
static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatMapOp> mapOps;
funcOp.walk([&](spatial::SpatMapOp mapOp) {
@@ -736,7 +717,7 @@ void SpatialToPimPass::runOnOperation() {
SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) {
if (failed(applyPartialConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
@@ -745,15 +726,13 @@ void SpatialToPimPass::runOnOperation() {
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) {
if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) {
if (failed(applyPartialConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
}
}
lowerRemainingSpatialMathOps(funcOp, rewriter);
RewritePatternSet channelPatterns(ctx);
populateWithGenerated(channelPatterns);
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {