sightly better bufferization
minor fixes
This commit is contained in:
@@ -252,25 +252,6 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
|
||||
rewriter.replaceOp(concatOp, concatenated);
|
||||
}
|
||||
|
||||
static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatWeightedVMMOp> wvmmOps;
|
||||
funcOp.walk([&](spatial::SpatWeightedVMMOp wvmmOp) {
|
||||
if (wvmmOp->getParentOfType<pim::PimCoreOp>() || wvmmOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
wvmmOps.push_back(wvmmOp);
|
||||
});
|
||||
|
||||
for (auto wvmmOp : wvmmOps) {
|
||||
rewriter.setInsertionPoint(wvmmOp);
|
||||
auto outputType = cast<ShapedType>(wvmmOp.getOutput().getType());
|
||||
Value outputBuffer = createEmptyTensorFromShaped(rewriter, wvmmOp.getLoc(), outputType).getResult();
|
||||
rewriter.replaceOpWithNewOp<pim::PimVMMOp>(wvmmOp,
|
||||
wvmmOp.getOutput().getType(),
|
||||
rewriter.getI32IntegerAttr(wvmmOp.getWeightIndex()),
|
||||
wvmmOp.getInput(),
|
||||
outputBuffer);
|
||||
}
|
||||
}
|
||||
|
||||
static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatMapOp> mapOps;
|
||||
funcOp.walk([&](spatial::SpatMapOp mapOp) {
|
||||
@@ -736,7 +717,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
SmallVector<pim::PimCoreOp> coreOps;
|
||||
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
||||
for (auto coreOp : coreOps) {
|
||||
if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) {
|
||||
if (failed(applyPartialConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -745,15 +726,13 @@ void SpatialToPimPass::runOnOperation() {
|
||||
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
|
||||
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
||||
for (auto coreBatchOp : coreBatchOps) {
|
||||
if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) {
|
||||
if (failed(applyPartialConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lowerRemainingSpatialMathOps(funcOp, rewriter);
|
||||
|
||||
RewritePatternSet channelPatterns(ctx);
|
||||
populateWithGenerated(channelPatterns);
|
||||
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
|
||||
|
||||
Reference in New Issue
Block a user