#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"

#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"

#include <algorithm>
#include <cstdint>
#include <limits>

using namespace mlir;

namespace onnx_mlir::spatial {

namespace {

/// Returns the sign-extended value of `value` when it is produced by an integer
/// constant (as recognized by `m_ConstantInt`); fails otherwise.
///
/// Constants that need more than 64 signed bits are rejected instead of being
/// passed to `APInt::getSExtValue()`, which asserts on such values.
static FailureOr<int64_t> getConstantI64(Value value) {
  APInt constantValue;
  if (!matchPattern(value, m_ConstantInt(&constantValue)))
    return failure();
  if (!constantValue.isSignedIntN(64))
    return failure();
  return constantValue.getSExtValue();
}

/// Returns the sign-extended value of `value` as an `int32_t` when it is an
/// integer constant whose value fits in 32 signed bits; fails otherwise.
///
/// Out-of-range constants are rejected rather than truncated: truncation would
/// make e.g. 1 and 4294967297 compare equal in the endpoint checks below.
static FailureOr<int32_t> getConstantI32(Value value) {
  FailureOr<int64_t> wideValue = getConstantI64(value);
  if (failed(wideValue))
    return failure();
  if (*wideValue < std::numeric_limits<int32_t>::min() ||
      *wideValue > std::numeric_limits<int32_t>::max())
    return failure();
  return static_cast<int32_t>(*wideValue);
}

/// Constant channel id of a send op, or failure if it is not a constant.
static FailureOr<int64_t> getChannelId(SpatChannelSendOp sendOp) {
  return getConstantI64(sendOp.getChannelId());
}

/// Constant channel id of a receive op, or failure if it is not a constant.
static FailureOr<int64_t> getChannelId(SpatChannelReceiveOp receiveOp) {
  return getConstantI64(receiveOp.getChannelId());
}

/// Constant source core id of a send op (must fit in i32).
static FailureOr<int32_t> getSourceCoreId(SpatChannelSendOp sendOp) {
  return getConstantI32(sendOp.getSourceCoreId());
}

/// Constant source core id of a receive op (must fit in i32).
static FailureOr<int32_t> getSourceCoreId(SpatChannelReceiveOp receiveOp) {
  return getConstantI32(receiveOp.getSourceCoreId());
}

/// Constant target core id of a send op (must fit in i32).
static FailureOr<int32_t> getTargetCoreId(SpatChannelSendOp sendOp) {
  return getConstantI32(sendOp.getTargetCoreId());
}

/// Constant target core id of a receive op (must fit in i32).
static FailureOr<int32_t> getTargetCoreId(SpatChannelReceiveOp receiveOp) {
  return getConstantI32(receiveOp.getTargetCoreId());
}

/// Checks that a complete send/receive pair agrees on source core id, target
/// core id, and payload type. Diagnostics are emitted on the send op.
///
/// Fails without a diagnostic if either endpoint is missing; `Channels::verify`
/// reports that case itself before calling this helper.
/// A core id constant that does not fit in 32 bits is reported with the same
/// "must use constant ..." diagnostic as a non-constant operand.
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
  if (!endpoints.send || !endpoints.receive)
    return failure();

  FailureOr<int32_t> sendSourceCoreId = getSourceCoreId(endpoints.send);
  FailureOr<int32_t> receiveSourceCoreId = getSourceCoreId(endpoints.receive);
  if (failed(sendSourceCoreId) || failed(receiveSourceCoreId)) {
    endpoints.send.emitOpError(
        "channel endpoints must use constant sourceCoreId operands");
    return failure();
  }
  if (*sendSourceCoreId != *receiveSourceCoreId) {
    endpoints.send.emitOpError(
        "sourceCoreId does not match paired spat.channel_receive");
    return failure();
  }

  FailureOr<int32_t> sendTargetCoreId = getTargetCoreId(endpoints.send);
  FailureOr<int32_t> receiveTargetCoreId = getTargetCoreId(endpoints.receive);
  if (failed(sendTargetCoreId) || failed(receiveTargetCoreId)) {
    endpoints.send.emitOpError(
        "channel endpoints must use constant targetCoreId operands");
    return failure();
  }
  if (*sendTargetCoreId != *receiveTargetCoreId) {
    endpoints.send.emitOpError(
        "targetCoreId does not match paired spat.channel_receive");
    return failure();
  }

  if (endpoints.send.getInput().getType() !=
      endpoints.receive.getOutput().getType()) {
    endpoints.send.emitOpError(
        "input type does not match paired spat.channel_receive result type");
    return failure();
  }
  return success();
}

} // namespace

/// Collects every send and receive op nested in `funcOp`, keyed by channel id.
/// Ops whose channel id is not a constant are skipped. A null `funcOp` yields
/// an empty table.
Channels::Channels(func::FuncOp funcOp) {
  if (!funcOp)
    return;
  funcOp.walk([&](SpatChannelSendOp sendOp) { insertSend(sendOp); });
  funcOp.walk([&](SpatChannelReceiveOp receiveOp) { insertReceive(receiveOp); });
}

/// Returns a fresh channel id and advances the counter. The counter is kept
/// above every id registered via insertSend/insertReceive.
Channels::ChannelId Channels::allocate() { return nextChannelId++; }

/// Registers `sendOp` as the send endpoint of its (constant) channel id.
/// NOTE(review): a second send with the same id overwrites the first silently,
/// so duplicates are not caught by verify(); confirm this is intended.
void Channels::insertSend(SpatChannelSendOp sendOp) {
  FailureOr<int64_t> channelId = getChannelId(sendOp);
  if (failed(channelId))
    return;
  nextChannelId = std::max(nextChannelId, *channelId + 1);
  endpoints[*channelId].send = sendOp;
}

/// Registers `receiveOp` as the receive endpoint of its (constant) channel id.
/// NOTE(review): same silent-overwrite behavior as insertSend for duplicates.
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
  FailureOr<int64_t> channelId = getChannelId(receiveOp);
  if (failed(channelId))
    return;
  nextChannelId = std::max(nextChannelId, *channelId + 1);
  endpoints[*channelId].receive = receiveOp;
}

/// Clears the send endpoint of `sendOp`'s channel; drops the entry entirely
/// once neither endpoint remains. No-op for unknown or non-constant ids.
void Channels::eraseSend(SpatChannelSendOp sendOp) {
  FailureOr<int64_t> channelId = getChannelId(sendOp);
  if (failed(channelId))
    return;
  auto it = endpoints.find(*channelId);
  if (it == endpoints.end())
    return;
  it->second.send = {};
  if (!it->second.receive)
    endpoints.erase(it);
}

/// Clears the receive endpoint of `receiveOp`'s channel; drops the entry
/// entirely once neither endpoint remains. No-op for unknown or non-constant ids.
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
  FailureOr<int64_t> channelId = getChannelId(receiveOp);
  if (failed(channelId))
    return;
  auto it = endpoints.find(*channelId);
  if (it == endpoints.end())
    return;
  it->second.receive = {};
  if (!it->second.send)
    endpoints.erase(it);
}

/// Returns (a copy of) the endpoints registered for `id`, or failure if none.
FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
  auto it = endpoints.find(id);
  if (it == endpoints.end())
    return failure();
  return it->second;
}

/// Returns the receive op paired with `sendOp`, or failure if the channel id is
/// not constant, unknown, or has no receive endpoint.
FailureOr<SpatChannelReceiveOp>
Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
  FailureOr<int64_t> channelId = getChannelId(sendOp);
  if (failed(channelId))
    return failure();
  auto endpointsOr = lookup(*channelId);
  if (failed(endpointsOr) || !endpointsOr->receive)
    return failure();
  return endpointsOr->receive;
}

/// Returns the send op paired with `receiveOp`, or failure if the channel id is
/// not constant, unknown, or has no send endpoint.
FailureOr<SpatChannelSendOp>
Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
  FailureOr<int64_t> channelId = getChannelId(receiveOp);
  if (failed(channelId))
    return failure();
  auto endpointsOr = lookup(*channelId);
  if (failed(endpointsOr) || !endpointsOr->send)
    return failure();
  return endpointsOr->send;
}

/// Checks every registered channel: each must have both endpoints, and the
/// pair must pass verifyEndpointPair. Stops at (and reports) the first failure.
LogicalResult Channels::verify() const {
  for (const auto &[channelId, pair] : endpoints) {
    if (!pair.send || !pair.receive) {
      if (pair.send) {
        // Copy the op handle: emitOpError is not callable through `const`.
        auto sendOp = pair.send;
        sendOp.emitOpError("channel_id ")
            << channelId << " is missing a paired spat.channel_receive";
      } else if (pair.receive) {
        auto receiveOp = pair.receive;
        receiveOp.emitOpError("channel_id ")
            << channelId << " is missing a paired spat.channel_send";
      }
      return failure();
    }
    if (failed(verifyEndpointPair(pair)))
      return failure();
  }
  return success();
}

} // namespace onnx_mlir::spatial