conv now lowers correctly down to bufferized pim

This commit is contained in:
NiccoloN
2026-03-20 12:55:09 +01:00
parent 6e1de865bb
commit db3f52a647
4 changed files with 114 additions and 100 deletions

View File

@@ -300,7 +300,8 @@ struct ChannelBroadcastReceiveOpInterface
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
auto outputSize = cast<ShapedType>(outputTensor.getType()).getNumElements();
auto outputType = cast<ShapedType>(outputTensor.getType());
auto outputSize = outputType.getNumElements() * outputType.getElementTypeBitWidth() / 8;
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
if (!channelNewOp) {
@@ -356,7 +357,8 @@ struct ChannelBroadcastSendOpInterface
}
/*
* Turn the channel send to pim.send
* Turn the channel send into a device-to-host copy into the shared
* broadcast buffer that receive ops load from later.
*/
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
@@ -389,8 +391,18 @@ struct ChannelBroadcastSendOpInterface
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
}
auto srcType = cast<ShapedType>(srcTensor.getType());
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
rewriter.setInsertionPoint(op);
replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef});
rewriter.create<pim::PimMemCopyDevToHostOp>(op->getLoc(),
bufferAllocation.getType(),
bufferAllocation,
srcMemRef,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes));
rewriter.eraseOp(op);
return success();
}
};