conv now lowers correctly down to bufferized pim
This commit is contained in:
@@ -300,7 +300,8 @@ struct ChannelBroadcastReceiveOpInterface
|
||||
|
||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
auto outputSize = cast<ShapedType>(outputTensor.getType()).getNumElements();
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
auto outputSize = outputType.getNumElements() * outputType.getElementTypeBitWidth() / 8;
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
@@ -356,7 +357,8 @@ struct ChannelBroadcastSendOpInterface
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the channel send to pim.send
|
||||
* Turn the channel send into a device-to-host copy into the shared
|
||||
* broadcast buffer that receive ops load from later.
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
@@ -389,8 +391,18 @@ struct ChannelBroadcastSendOpInterface
|
||||
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
|
||||
}
|
||||
|
||||
auto srcType = cast<ShapedType>(srcTensor.getType());
|
||||
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef});
|
||||
rewriter.create<pim::PimMemCopyDevToHostOp>(op->getLoc(),
|
||||
bufferAllocation.getType(),
|
||||
bufferAllocation,
|
||||
srcMemRef,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user