179 lines
5.6 KiB
C++
179 lines
5.6 KiB
C++
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir::spatial {
|
|
|
|
namespace {
|
|
|
|
/// Extracts the compile-time integer constant feeding `value` as a signed
/// 64-bit integer.
///
/// Returns failure when `value` is not produced by an op that matches
/// `m_ConstantInt`. It also returns failure when the constant needs more than
/// 64 significant bits as a signed value (e.g. a wide i128 constant), which
/// `APInt::getSExtValue` cannot represent.
static FailureOr<int64_t> getConstantI64(Value value) {
  APInt constantValue;
  if (!matchPattern(value, m_ConstantInt(&constantValue)))
    return failure();
  // getSExtValue() asserts (and truncates in release builds) when the value
  // does not fit in int64_t; report failure instead of crashing or wrapping.
  if (!constantValue.isSignedIntN(64))
    return failure();
  return constantValue.getSExtValue();
}
|
|
|
|
/// Extracts the compile-time integer constant feeding `value` as a signed
/// 32-bit integer.
///
/// Returns failure when `value` is not produced by an op that matches
/// `m_ConstantInt`. It also returns failure when the constant's signed value
/// does not fit in int32_t. Truncating would silently alias distinct ids,
/// e.g. 2^32 + 1 would become 1.
static FailureOr<int32_t> getConstantI32(Value value) {
  APInt constantValue;
  if (!matchPattern(value, m_ConstantInt(&constantValue)))
    return failure();
  // Range-check before narrowing. This also guarantees getSExtValue() below
  // cannot hit its "too many bits" assertion.
  if (!constantValue.isSignedIntN(32))
    return failure();
  return static_cast<int32_t>(constantValue.getSExtValue());
}
|
|
|
|
/// Reads the channel id operand of a spat.channel_send.
/// Fails unless that operand is a constant integer (see getConstantI64).
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelSendOp sendOp) {
  Value channelIdOperand = sendOp.getChannelId();
  return getConstantI64(channelIdOperand);
}
|
|
|
|
/// Reads the channel id operand of a spat.channel_receive.
/// Fails unless that operand is a constant integer (see getConstantI64).
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelReceiveOp receiveOp) {
  Value channelIdOperand = receiveOp.getChannelId();
  return getConstantI64(channelIdOperand);
}
|
|
|
|
/// Reads the sourceCoreId operand of a spat.channel_send as a constant int32.
static FailureOr<int32_t> getSourceCoreId(SpatChannelSendOp sendOp) {
  Value coreIdOperand = sendOp.getSourceCoreId();
  return getConstantI32(coreIdOperand);
}
|
|
|
|
/// Reads the sourceCoreId operand of a spat.channel_receive as a constant int32.
static FailureOr<int32_t> getSourceCoreId(SpatChannelReceiveOp receiveOp) {
  Value coreIdOperand = receiveOp.getSourceCoreId();
  return getConstantI32(coreIdOperand);
}
|
|
|
|
/// Reads the targetCoreId operand of a spat.channel_send as a constant int32.
static FailureOr<int32_t> getTargetCoreId(SpatChannelSendOp sendOp) {
  Value coreIdOperand = sendOp.getTargetCoreId();
  return getConstantI32(coreIdOperand);
}
|
|
|
|
/// Reads the targetCoreId operand of a spat.channel_receive as a constant int32.
static FailureOr<int32_t> getTargetCoreId(SpatChannelReceiveOp receiveOp) {
  Value coreIdOperand = receiveOp.getTargetCoreId();
  return getConstantI32(coreIdOperand);
}
|
|
|
|
/// Checks that a send/receive pair sharing a channel id agrees on
/// sourceCoreId, targetCoreId, and the transferred value type.
///
/// Returns failure without a diagnostic if either endpoint is missing. All
/// other failures emit an op error on the send endpoint. The checks run in
/// the order source core id, target core id, type, and stop at the first
/// mismatch.
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
  auto sendOp = endpoints.send;
  auto receiveOp = endpoints.receive;
  if (!(sendOp && receiveOp))
    return failure();

  // Both sides must carry the same constant for a core-id operand. The
  // diagnostic is always reported on the send op.
  auto checkCoreIdsAgree = [&](FailureOr<int32_t> fromSend,
                               FailureOr<int32_t> fromReceive,
                               const char* nonConstantMessage,
                               const char* mismatchMessage) -> LogicalResult {
    if (failed(fromSend) || failed(fromReceive)) {
      sendOp.emitOpError(nonConstantMessage);
      return failure();
    }
    if (*fromSend != *fromReceive) {
      sendOp.emitOpError(mismatchMessage);
      return failure();
    }
    return success();
  };

  if (failed(checkCoreIdsAgree(getSourceCoreId(sendOp), getSourceCoreId(receiveOp),
                               "channel endpoints must use constant sourceCoreId operands",
                               "sourceCoreId does not match paired spat.channel_receive")))
    return failure();

  if (failed(checkCoreIdsAgree(getTargetCoreId(sendOp), getTargetCoreId(receiveOp),
                               "channel endpoints must use constant targetCoreId operands",
                               "targetCoreId does not match paired spat.channel_receive")))
    return failure();

  // The value handed to the send must have exactly the receive's result type.
  auto sentType = sendOp.getInput().getType();
  auto receivedType = receiveOp.getOutput().getType();
  if (sentType != receivedType) {
    sendOp.emitOpError("input type does not match paired spat.channel_receive result type");
    return failure();
  }
  return success();
}
|
|
|
|
} // namespace
|
|
|
|
/// Builds the channel table from every spat.channel_send and
/// spat.channel_receive nested inside `funcOp`.
///
/// A null `funcOp` leaves the table empty. Ops whose channel id is not a
/// constant are skipped by insertSend/insertReceive.
Channels::Channels(func::FuncOp funcOp) {
  if (funcOp) {
    // Two separate walks: every send is registered before any receive.
    funcOp.walk([this](SpatChannelSendOp op) { insertSend(op); });
    funcOp.walk([this](SpatChannelReceiveOp op) { insertReceive(op); });
  }
}
|
|
|
|
/// Returns the current nextChannelId and advances it by one.
/// insertSend/insertReceive keep nextChannelId above every channel id they
/// record, so the returned id does not collide with ids recorded so far.
Channels::ChannelId Channels::allocate() {
  ChannelId freshId = nextChannelId;
  ++nextChannelId;
  return freshId;
}
|
|
|
|
/// Records `sendOp` as the send endpoint of its channel and raises
/// nextChannelId past that channel's id.
///
/// Sends with a non-constant channel id are ignored. If a send with the same
/// id was already recorded, it is overwritten.
/// NOTE(review): a duplicate send therefore goes unnoticed by verify().
void Channels::insertSend(SpatChannelSendOp sendOp) {
  FailureOr<ChannelId> maybeId = getChannelId(sendOp);
  if (succeeded(maybeId)) {
    ChannelId id = *maybeId;
    if (nextChannelId < id + 1)
      nextChannelId = id + 1;
    endpoints[id].send = sendOp;
  }
}
|
|
|
|
/// Records `receiveOp` as the receive endpoint of its channel and raises
/// nextChannelId past that channel's id.
///
/// Receives with a non-constant channel id are ignored. If a receive with the
/// same id was already recorded, it is overwritten.
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
  FailureOr<ChannelId> maybeId = getChannelId(receiveOp);
  if (succeeded(maybeId)) {
    ChannelId id = *maybeId;
    if (nextChannelId < id + 1)
      nextChannelId = id + 1;
    endpoints[id].receive = receiveOp;
  }
}
|
|
|
|
/// Forgets the send endpoint of `sendOp`'s channel. The whole entry is removed
/// once the channel has no receive either.
///
/// This is a no-op when the channel id is not constant or the channel is not
/// tracked.
/// NOTE(review): the slot is cleared even if the recorded send is a different
/// op with the same id.
void Channels::eraseSend(SpatChannelSendOp sendOp) {
  FailureOr<ChannelId> maybeId = getChannelId(sendOp);
  if (failed(maybeId))
    return;
  auto entry = endpoints.find(*maybeId);
  if (entry != endpoints.end()) {
    entry->second.send = {};
    if (!entry->second.receive)
      endpoints.erase(entry);
  }
}
|
|
|
|
/// Forgets the receive endpoint of `receiveOp`'s channel. The whole entry is
/// removed once the channel has no send either.
///
/// This is a no-op when the channel id is not constant or the channel is not
/// tracked.
/// NOTE(review): the slot is cleared even if the recorded receive is a
/// different op with the same id.
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
  FailureOr<ChannelId> maybeId = getChannelId(receiveOp);
  if (failed(maybeId))
    return;
  auto entry = endpoints.find(*maybeId);
  if (entry != endpoints.end()) {
    entry->second.receive = {};
    if (!entry->second.send)
      endpoints.erase(entry);
  }
}
|
|
|
|
/// Returns a copy of the endpoints recorded for channel `id`, or failure if
/// the channel is not tracked. Either endpoint in the result may be null.
FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
  auto found = endpoints.find(id);
  if (found != endpoints.end())
    return found->second;
  return failure();
}
|
|
|
|
/// Returns the receive recorded on the same channel as `sendOp`.
/// Fails if the send's channel id is not constant, the channel is untracked,
/// or no receive has been recorded for it.
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
  FailureOr<ChannelId> maybeId = getChannelId(sendOp);
  if (failed(maybeId))
    return failure();
  FailureOr<ChannelEndpoints> entry = lookup(*maybeId);
  if (succeeded(entry) && entry->receive)
    return entry->receive;
  return failure();
}
|
|
|
|
/// Returns the send recorded on the same channel as `receiveOp`.
/// Fails if the receive's channel id is not constant, the channel is
/// untracked, or no send has been recorded for it.
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
  FailureOr<ChannelId> maybeId = getChannelId(receiveOp);
  if (failed(maybeId))
    return failure();
  FailureOr<ChannelEndpoints> entry = lookup(*maybeId);
  if (succeeded(entry) && entry->send)
    return entry->send;
  return failure();
}
|
|
|
|
/// Verifies every tracked channel and stops at the first offending one.
///
/// A channel with both endpoints must pass verifyEndpointPair. A channel
/// with only one endpoint gets a "missing a paired ..." error on the endpoint
/// that is present. An entry with neither endpoint fails without a diagnostic.
LogicalResult Channels::verify() const {
  for (const auto& [channelId, pair] : endpoints) {
    const bool hasSend = static_cast<bool>(pair.send);
    const bool hasReceive = static_cast<bool>(pair.receive);

    if (hasSend && hasReceive) {
      if (failed(verifyEndpointPair(pair)))
        return failure();
      continue;
    }

    // emitOpError is non-const, so copy the op handle out of the const entry.
    if (hasSend) {
      auto orphanSend = pair.send;
      orphanSend.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_receive";
    }
    else if (hasReceive) {
      auto orphanReceive = pair.receive;
      orphanReceive.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_send";
    }
    return failure();
  }
  return success();
}
|
|
|
|
} // namespace onnx_mlir::spatial
|