add support for operations: reduceMean, add, mul, div, sigmoid
Some checks failed
Validate Operations / validate-operations (push) Failing after 51m52s

This commit is contained in:
NiccoloN
2026-03-30 15:41:12 +02:00
parent 5e7114f517
commit 39830be888
32 changed files with 1057 additions and 224 deletions

View File

@@ -239,6 +239,22 @@ def SpatSumOp : SpatOp<"sum", []> {
}];
}
def SpatVAvgOp : SpatOp<"vavg", []> {
let summary = "Average all elements of the input tensor to a single scalar wrapped in a tensor";
let arguments = (ins
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatSigmoidOp : SpatOp<"sigmoid", []> {
let summary = "Element-wise sigmoid activation";

View File

@@ -361,7 +361,7 @@ struct ChannelBroadcastReceiveOpInterface
}
/*
* Turn the channel receive to pim.load using by creating a new global buffer
* Turn the broadcast receive into a regular pim.receive from the broadcaster.
*/
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
@@ -370,8 +370,21 @@ struct ChannelBroadcastReceiveOpInterface
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
auto outputType = cast<ShapedType>(outputTensor.getType());
auto outputSize = outputType.getNumElements() * outputType.getElementTypeBitWidth() / 8;
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
if (precomputedOtherCoreId) {
Value newValue = pim::PimReceiveOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
cast<IntegerAttr>(precomputedOtherCoreId))
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
if (!channelNewOp) {
@@ -379,31 +392,30 @@ struct ChannelBroadcastReceiveOpInterface
return failure();
}
// The first 'broadcast' operation creates the buffer just after the
// channelNewOp, while the other 'broadcast' operation need to find this
// buffer allocation just after the channelNewOp
Value bufferAllocation;
if (auto allocOpAfterChannel = dyn_cast<memref::AllocOp>(channelNewOp->getNextNode())) {
// Buffer already allocated, load from this buffer
bufferAllocation = allocOpAfterChannel;
}
else {
// Buffer was not allocated previously, allocate it after channelNewOp
rewriter.setInsertionPointAfter(channelNewOp);
bufferAllocation = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
}
auto srcCoreId = [&]() -> FailureOr<uint32_t> {
for (Operation* user : channelNewOp->getUsers()) {
auto sendOp = dyn_cast<SpatChannelBroadcastSendOp>(user);
if (!sendOp)
continue;
auto sendCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
op->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, sendCoreIdAttr);
return cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreId();
}
op->emitError("ChannelBroadcastReceiveOp has no matching ChannelBroadcastSendOp");
return failure();
}();
if (failed(srcCoreId))
return failure();
rewriter.setInsertionPoint(op);
auto memCopyHostToDevOp = pim::PimMemCopyHostToDevOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
outputTensor,
bufferAllocation,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(outputSize));
Value newValue = pim::PimReceiveOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(srcCoreId.value()))
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, memCopyHostToDevOp.getOutput());
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
@@ -428,8 +440,7 @@ struct ChannelBroadcastSendOpInterface
}
/*
* Turn the channel send into a device-to-host copy into the shared
* broadcast buffer that receive ops load from later.
* Turn the broadcast send into one pim.send per broadcast receiver.
*/
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
@@ -448,32 +459,32 @@ struct ChannelBroadcastSendOpInterface
return failure();
}
// The first 'broadcast' operation creates the buffer just after the
// channelNewOp, while the other 'broadcast' operation need to find this
// buffer allocation just after the channelNewOp
Value bufferAllocation;
if (auto allocOpAfterChannel = dyn_cast<memref::AllocOp>(channelNewOp->getNextNode())) {
// Buffer already allocated, load from this buffer
bufferAllocation = allocOpAfterChannel;
}
else {
// Buffer was not allocated previously, allocate it after channelNewOp
rewriter.setInsertionPointAfter(channelNewOp);
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
}
auto srcType = cast<ShapedType>(srcTensor.getType());
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
auto srcCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
rewriter.setInsertionPoint(op);
pim::PimMemCopyDevToHostOp::create(rewriter,
op->getLoc(),
bufferAllocation.getType(),
bufferAllocation,
srcMemRef,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes));
bool foundReceiver = false;
for (Operation* user : channelNewOp->getUsers()) {
auto receiveOp = dyn_cast<SpatChannelBroadcastReceiveOp>(user);
if (!receiveOp)
continue;
foundReceiver = true;
auto dstCoreId = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreId();
receiveOp->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, srcCoreIdAttr);
pim::PimSendOp::create(rewriter,
op->getLoc(),
srcMemRef,
rewriter.getI32IntegerAttr(sizeInBytes),
rewriter.getI32IntegerAttr(dstCoreId));
}
if (!foundReceiver) {
op->emitError("SpatChannelBroadcastSendOp has no matching ChannelBroadcastReceiveOp");
return failure();
}
rewriter.eraseOp(op);
return success();
}