add support for operations: reduceMean, add, mul, div, sigmoid
Some checks failed
Validate Operations / validate-operations (push) Failing after 51m52s
Some checks failed
Validate Operations / validate-operations (push) Failing after 51m52s
This commit is contained in:
@@ -239,6 +239,22 @@ def SpatSumOp : SpatOp<"sum", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVAvgOp : SpatOp<"vavg", []> {
|
||||
let summary = "Average all elements of the input tensor to a single scalar wrapped in a tensor";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
|
||||
let summary = "Element-wise sigmoid activation";
|
||||
|
||||
|
||||
@@ -361,7 +361,7 @@ struct ChannelBroadcastReceiveOpInterface
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the channel receive to pim.load using by creating a new global buffer
|
||||
* Turn the broadcast receive into a regular pim.receive from the broadcaster.
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
@@ -370,8 +370,21 @@ struct ChannelBroadcastReceiveOpInterface
|
||||
|
||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
auto outputSize = outputType.getNumElements() * outputType.getElementTypeBitWidth() / 8;
|
||||
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
|
||||
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
|
||||
|
||||
auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
|
||||
if (precomputedOtherCoreId) {
|
||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
cast<IntegerAttr>(precomputedOtherCoreId))
|
||||
.getOutput();
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
@@ -379,31 +392,30 @@ struct ChannelBroadcastReceiveOpInterface
|
||||
return failure();
|
||||
}
|
||||
|
||||
// The first 'broadcast' operation creates the buffer just after the
|
||||
// channelNewOp, while the other 'broadcast' operation need to find this
|
||||
// buffer allocation just after the channelNewOp
|
||||
Value bufferAllocation;
|
||||
if (auto allocOpAfterChannel = dyn_cast<memref::AllocOp>(channelNewOp->getNextNode())) {
|
||||
// Buffer already allocated, load from this buffer
|
||||
bufferAllocation = allocOpAfterChannel;
|
||||
}
|
||||
else {
|
||||
// Buffer was not allocated previously, allocate it after channelNewOp
|
||||
rewriter.setInsertionPointAfter(channelNewOp);
|
||||
bufferAllocation = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
}
|
||||
auto srcCoreId = [&]() -> FailureOr<uint32_t> {
|
||||
for (Operation* user : channelNewOp->getUsers()) {
|
||||
auto sendOp = dyn_cast<SpatChannelBroadcastSendOp>(user);
|
||||
if (!sendOp)
|
||||
continue;
|
||||
auto sendCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
|
||||
op->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, sendCoreIdAttr);
|
||||
return cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreId();
|
||||
}
|
||||
op->emitError("ChannelBroadcastReceiveOp has no matching ChannelBroadcastSendOp");
|
||||
return failure();
|
||||
}();
|
||||
if (failed(srcCoreId))
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
auto memCopyHostToDevOp = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
bufferAllocation,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(outputSize));
|
||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
||||
.getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getOutput());
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -428,8 +440,7 @@ struct ChannelBroadcastSendOpInterface
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the channel send into a device-to-host copy into the shared
|
||||
* broadcast buffer that receive ops load from later.
|
||||
* Turn the broadcast send into one pim.send per broadcast receiver.
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
@@ -448,32 +459,32 @@ struct ChannelBroadcastSendOpInterface
|
||||
return failure();
|
||||
}
|
||||
|
||||
// The first 'broadcast' operation creates the buffer just after the
|
||||
// channelNewOp, while the other 'broadcast' operation need to find this
|
||||
// buffer allocation just after the channelNewOp
|
||||
Value bufferAllocation;
|
||||
if (auto allocOpAfterChannel = dyn_cast<memref::AllocOp>(channelNewOp->getNextNode())) {
|
||||
// Buffer already allocated, load from this buffer
|
||||
bufferAllocation = allocOpAfterChannel;
|
||||
}
|
||||
else {
|
||||
// Buffer was not allocated previously, allocate it after channelNewOp
|
||||
rewriter.setInsertionPointAfter(channelNewOp);
|
||||
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
|
||||
}
|
||||
|
||||
auto srcType = cast<ShapedType>(srcTensor.getType());
|
||||
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
|
||||
auto srcCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
bufferAllocation.getType(),
|
||||
bufferAllocation,
|
||||
srcMemRef,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes));
|
||||
bool foundReceiver = false;
|
||||
for (Operation* user : channelNewOp->getUsers()) {
|
||||
auto receiveOp = dyn_cast<SpatChannelBroadcastReceiveOp>(user);
|
||||
if (!receiveOp)
|
||||
continue;
|
||||
|
||||
foundReceiver = true;
|
||||
auto dstCoreId = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreId();
|
||||
receiveOp->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, srcCoreIdAttr);
|
||||
pim::PimSendOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
srcMemRef,
|
||||
rewriter.getI32IntegerAttr(sizeInBytes),
|
||||
rewriter.getI32IntegerAttr(dstCoreId));
|
||||
}
|
||||
|
||||
if (!foundReceiver) {
|
||||
op->emitError("SpatChannelBroadcastSendOp has no matching ChannelBroadcastReceiveOp");
|
||||
return failure();
|
||||
}
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user