Files
Raptor/src/PIM/Dialect/Spatial/Channels.cpp
T
NiccoloN a50e77ff38
Validate Operations / validate-operations (push) Has been cancelled
refactorone
2026-05-20 19:06:41 +02:00

179 lines
5.6 KiB
C++

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
using namespace mlir;
namespace onnx_mlir::spatial {
namespace {
/// Returns the value of `value` as a signed 64-bit integer when it is produced
/// by an integer constant (anything `m_ConstantInt` matches).
///
/// Returns failure when `value` is not a constant integer, or when the constant
/// does not fit in int64_t as a signed quantity (e.g. a wide i128 constant).
/// `APInt::getSExtValue` asserts on such values, so they are rejected up front
/// instead of crashing the pass.
static FailureOr<int64_t> getConstantI64(Value value) {
  APInt constantValue;
  if (!matchPattern(value, m_ConstantInt(&constantValue)))
    return failure();
  // Guard the precondition of getSExtValue().
  if (!constantValue.isSignedIntN(64))
    return failure();
  return constantValue.getSExtValue();
}
/// Returns the value of `value` as a signed 32-bit integer when it is produced
/// by an integer constant whose sign-extended value fits in int32_t.
///
/// Returns failure for non-constant values and for constants outside the
/// int32_t range. Out-of-range constants are rejected rather than narrowed,
/// so distinct wide values cannot alias after the conversion (which matters
/// when core ids are compared between paired endpoints).
static FailureOr<int32_t> getConstantI32(Value value) {
  APInt constantValue;
  if (!matchPattern(value, m_ConstantInt(&constantValue)))
    return failure();
  // Checking the signed width also guarantees getSExtValue() is valid below.
  if (!constantValue.isSignedIntN(32))
    return failure();
  return static_cast<int32_t>(constantValue.getSExtValue());
}
/// Reads the channel id of a spat.channel_send, which must be an integer
/// constant; otherwise returns failure.
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelSendOp sendOp) {
  Value idOperand = sendOp.getChannelId();
  return getConstantI64(idOperand);
}
/// Reads the channel id of a spat.channel_receive, which must be an integer
/// constant; otherwise returns failure.
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelReceiveOp receiveOp) {
  Value idOperand = receiveOp.getChannelId();
  return getConstantI64(idOperand);
}
/// Reads the constant `sourceCoreId` operand of a spat.channel_send as int32;
/// returns failure if the operand is not an integer constant.
static FailureOr<int32_t> getSourceCoreId(SpatChannelSendOp sendOp) {
  Value coreIdOperand = sendOp.getSourceCoreId();
  return getConstantI32(coreIdOperand);
}
/// Reads the constant `sourceCoreId` operand of a spat.channel_receive as
/// int32; returns failure if the operand is not an integer constant.
static FailureOr<int32_t> getSourceCoreId(SpatChannelReceiveOp receiveOp) {
  Value coreIdOperand = receiveOp.getSourceCoreId();
  return getConstantI32(coreIdOperand);
}
/// Reads the constant `targetCoreId` operand of a spat.channel_send as int32;
/// returns failure if the operand is not an integer constant.
static FailureOr<int32_t> getTargetCoreId(SpatChannelSendOp sendOp) {
  Value coreIdOperand = sendOp.getTargetCoreId();
  return getConstantI32(coreIdOperand);
}
/// Reads the constant `targetCoreId` operand of a spat.channel_receive as
/// int32; returns failure if the operand is not an integer constant.
static FailureOr<int32_t> getTargetCoreId(SpatChannelReceiveOp receiveOp) {
  Value coreIdOperand = receiveOp.getTargetCoreId();
  return getConstantI32(coreIdOperand);
}
/// Checks that a fully populated send/receive pair agrees on its endpoints.
///
/// Requirements, checked in this order:
///   1. both endpoints use integer-constant sourceCoreId operands, and the
///      values are equal;
///   2. both endpoints use integer-constant targetCoreId operands, and the
///      values are equal;
///   3. the send's input type equals the receive's result type.
/// On the first violation an error is emitted and failure is returned. A
/// non-constant operand is reported on the op that actually carries it, and
/// mismatches get a note pointing at the paired spat.channel_receive.
///
/// Returns failure without emitting anything when either endpoint is null;
/// the caller is responsible for reporting unpaired endpoints.
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
  // Mutable local handles: emitOpError is a non-const member.
  SpatChannelSendOp sendOp = endpoints.send;
  SpatChannelReceiveOp receiveOp = endpoints.receive;
  if (!sendOp || !receiveOp)
    return failure();

  FailureOr<int32_t> sendSourceCoreId = getSourceCoreId(sendOp);
  FailureOr<int32_t> receiveSourceCoreId = getSourceCoreId(receiveOp);
  if (failed(sendSourceCoreId))
    return sendOp.emitOpError("channel endpoints must use constant sourceCoreId operands");
  if (failed(receiveSourceCoreId))
    return receiveOp.emitOpError("channel endpoints must use constant sourceCoreId operands");
  if (*sendSourceCoreId != *receiveSourceCoreId) {
    InFlightDiagnostic diag = sendOp.emitOpError("sourceCoreId does not match paired spat.channel_receive");
    diag.attachNote(receiveOp.getLoc()) << "paired spat.channel_receive has sourceCoreId " << *receiveSourceCoreId;
    return diag;
  }

  FailureOr<int32_t> sendTargetCoreId = getTargetCoreId(sendOp);
  FailureOr<int32_t> receiveTargetCoreId = getTargetCoreId(receiveOp);
  if (failed(sendTargetCoreId))
    return sendOp.emitOpError("channel endpoints must use constant targetCoreId operands");
  if (failed(receiveTargetCoreId))
    return receiveOp.emitOpError("channel endpoints must use constant targetCoreId operands");
  if (*sendTargetCoreId != *receiveTargetCoreId) {
    InFlightDiagnostic diag = sendOp.emitOpError("targetCoreId does not match paired spat.channel_receive");
    diag.attachNote(receiveOp.getLoc()) << "paired spat.channel_receive has targetCoreId " << *receiveTargetCoreId;
    return diag;
  }

  Type sentType = sendOp.getInput().getType();
  Type receivedType = receiveOp.getOutput().getType();
  if (sentType != receivedType) {
    InFlightDiagnostic diag = sendOp.emitOpError("input type does not match paired spat.channel_receive result type");
    diag.attachNote(receiveOp.getLoc()) << "paired spat.channel_receive produces " << receivedType << " but "
                                        << sentType << " is sent";
    return diag;
  }
  return success();
}
} // namespace
/// Builds the channel table from every spat.channel_send and
/// spat.channel_receive nested in `funcOp`.
///
/// Endpoints whose channel id is not an integer constant are skipped by
/// insertSend/insertReceive. A null `funcOp` leaves the table as it was
/// default-initialized.
/// NOTE(review): a later endpoint with an already-registered channel id
/// overwrites the earlier one of the same kind, so duplicate ids are not
/// detected by verify().
Channels::Channels(func::FuncOp funcOp) {
  if (funcOp) {
    // One traversal; sends and receives each keep their relative visit order.
    funcOp.walk([&](Operation* op) {
      if (auto sendOp = dyn_cast<SpatChannelSendOp>(op))
        insertSend(sendOp);
      else if (auto receiveOp = dyn_cast<SpatChannelReceiveOp>(op))
        insertReceive(receiveOp);
    });
  }
}
/// Returns the current value of the id counter and advances it by one.
/// insertSend/insertReceive keep the counter above every id they register, so
/// the result does not collide with ids seen so far (assuming nothing else
/// lowers nextChannelId).
Channels::ChannelId Channels::allocate() {
  const ChannelId freshId = nextChannelId;
  ++nextChannelId;
  return freshId;
}
/// Registers `sendOp` as the send endpoint of its constant channel id and
/// raises nextChannelId past that id. Sends with a non-constant channel id are
/// ignored; an existing send for the same id is replaced.
void Channels::insertSend(SpatChannelSendOp sendOp) {
  FailureOr<ChannelId> id = getChannelId(sendOp);
  if (succeeded(id)) {
    if (*id >= nextChannelId)
      nextChannelId = *id + 1;
    endpoints[*id].send = sendOp;
  }
}
/// Registers `receiveOp` as the receive endpoint of its constant channel id
/// and raises nextChannelId past that id. Receives with a non-constant channel
/// id are ignored; an existing receive for the same id is replaced.
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
  FailureOr<ChannelId> id = getChannelId(receiveOp);
  if (succeeded(id)) {
    if (*id >= nextChannelId)
      nextChannelId = *id + 1;
    endpoints[*id].receive = receiveOp;
  }
}
/// Clears the send endpoint stored under `sendOp`'s channel id, and drops the
/// entry entirely once neither endpoint remains. Does nothing when the id is
/// not constant or not tracked.
/// NOTE(review): the stored send is cleared even if it is a different op than
/// `sendOp` that happens to share the channel id.
void Channels::eraseSend(SpatChannelSendOp sendOp) {
  FailureOr<ChannelId> id = getChannelId(sendOp);
  if (failed(id))
    return;
  auto entry = endpoints.find(*id);
  if (entry == endpoints.end())
    return;
  ChannelEndpoints& pair = entry->second;
  pair.send = {};
  if (pair.receive)
    return;
  endpoints.erase(entry);
}
/// Clears the receive endpoint stored under `receiveOp`'s channel id, and
/// drops the entry entirely once neither endpoint remains. Does nothing when
/// the id is not constant or not tracked.
/// NOTE(review): the stored receive is cleared even if it is a different op
/// than `receiveOp` that happens to share the channel id.
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
  FailureOr<ChannelId> id = getChannelId(receiveOp);
  if (failed(id))
    return;
  auto entry = endpoints.find(*id);
  if (entry == endpoints.end())
    return;
  ChannelEndpoints& pair = entry->second;
  pair.receive = {};
  if (pair.send)
    return;
  endpoints.erase(entry);
}
/// Returns a copy of the endpoint pair registered for `id`, or failure if the
/// id is not tracked. Either endpoint in the returned pair may be null.
FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
  if (auto entry = endpoints.find(id); entry != endpoints.end())
    return entry->second;
  return failure();
}
/// Finds the spat.channel_receive paired with `sendOp` through its constant
/// channel id. Fails when the id is non-constant, untracked, or has no
/// registered receive.
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
  FailureOr<ChannelId> id = getChannelId(sendOp);
  if (failed(id))
    return failure();
  FailureOr<ChannelEndpoints> pair = lookup(*id);
  if (succeeded(pair) && pair->receive)
    return pair->receive;
  return failure();
}
/// Finds the spat.channel_send paired with `receiveOp` through its constant
/// channel id. Fails when the id is non-constant, untracked, or has no
/// registered send.
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
  FailureOr<ChannelId> id = getChannelId(receiveOp);
  if (failed(id))
    return failure();
  FailureOr<ChannelEndpoints> pair = lookup(*id);
  if (succeeded(pair) && pair->send)
    return pair->send;
  return failure();
}
/// Verifies every tracked channel: each must have both a send and a receive,
/// and the pair must pass verifyEndpointPair.
///
/// Stops at the first broken channel, after emitting a diagnostic for it
/// (verifyEndpointPair emits its own). Which channel is reported first follows
/// the iteration order of `endpoints`, whose container type is declared in the
/// header.
LogicalResult Channels::verify() const {
  for (const auto& [channelId, pair] : endpoints) {
    const bool complete = pair.send && pair.receive;
    if (complete) {
      if (failed(verifyEndpointPair(pair)))
        return failure();
      continue;
    }
    // emitOpError is non-const, so bind mutable copies of the op handles.
    if (SpatChannelSendOp sendOp = pair.send)
      sendOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_receive";
    else if (SpatChannelReceiveOp receiveOp = pair.receive)
      receiveOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_send";
    return failure();
  }
  return success();
}
} // namespace onnx_mlir::spatial