This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
|
||||
@@ -9,19 +10,62 @@ namespace onnx_mlir::spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); }
|
||||
static FailureOr<int64_t> getConstantI64(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return constantValue.getSExtValue();
|
||||
}
|
||||
|
||||
static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); }
|
||||
static FailureOr<int32_t> getConstantI32(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||
}
|
||||
|
||||
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelSendOp sendOp) {
|
||||
return getConstantI64(sendOp.getChannelId());
|
||||
}
|
||||
|
||||
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelReceiveOp receiveOp) {
|
||||
return getConstantI64(receiveOp.getChannelId());
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getSourceCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getSourceCoreId()); }
|
||||
|
||||
static FailureOr<int32_t> getSourceCoreId(SpatChannelReceiveOp receiveOp) {
|
||||
return getConstantI32(receiveOp.getSourceCoreId());
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getTargetCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getTargetCoreId()); }
|
||||
|
||||
static FailureOr<int32_t> getTargetCoreId(SpatChannelReceiveOp receiveOp) {
|
||||
return getConstantI32(receiveOp.getTargetCoreId());
|
||||
}
|
||||
|
||||
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
|
||||
if (!endpoints.send || !endpoints.receive)
|
||||
return failure();
|
||||
|
||||
if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) {
|
||||
FailureOr<int32_t> sendSourceCoreId = getSourceCoreId(endpoints.send);
|
||||
FailureOr<int32_t> receiveSourceCoreId = getSourceCoreId(endpoints.receive);
|
||||
if (failed(sendSourceCoreId) || failed(receiveSourceCoreId)) {
|
||||
endpoints.send.emitOpError("channel endpoints must use constant sourceCoreId operands");
|
||||
return failure();
|
||||
}
|
||||
if (*sendSourceCoreId != *receiveSourceCoreId) {
|
||||
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
|
||||
return failure();
|
||||
}
|
||||
if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) {
|
||||
|
||||
FailureOr<int32_t> sendTargetCoreId = getTargetCoreId(endpoints.send);
|
||||
FailureOr<int32_t> receiveTargetCoreId = getTargetCoreId(endpoints.receive);
|
||||
if (failed(sendTargetCoreId) || failed(receiveTargetCoreId)) {
|
||||
endpoints.send.emitOpError("channel endpoints must use constant targetCoreId operands");
|
||||
return failure();
|
||||
}
|
||||
if (*sendTargetCoreId != *receiveTargetCoreId) {
|
||||
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
|
||||
return failure();
|
||||
}
|
||||
@@ -46,20 +90,26 @@ Channels::Channels(func::FuncOp funcOp) {
|
||||
Channels::ChannelId Channels::allocate() { return nextChannelId++; }
|
||||
|
||||
void Channels::insertSend(SpatChannelSendOp sendOp) {
|
||||
ChannelId channelId = getChannelId(sendOp);
|
||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
||||
endpoints[channelId].send = sendOp;
|
||||
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
nextChannelId = std::max(nextChannelId, *channelId + 1);
|
||||
endpoints[*channelId].send = sendOp;
|
||||
}
|
||||
|
||||
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
|
||||
ChannelId channelId = getChannelId(receiveOp);
|
||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
||||
endpoints[channelId].receive = receiveOp;
|
||||
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
nextChannelId = std::max(nextChannelId, *channelId + 1);
|
||||
endpoints[*channelId].receive = receiveOp;
|
||||
}
|
||||
|
||||
void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
||||
ChannelId channelId = getChannelId(sendOp);
|
||||
auto it = endpoints.find(channelId);
|
||||
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
auto it = endpoints.find(*channelId);
|
||||
if (it == endpoints.end())
|
||||
return;
|
||||
it->second.send = {};
|
||||
@@ -68,8 +118,10 @@ void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
||||
}
|
||||
|
||||
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
|
||||
ChannelId channelId = getChannelId(receiveOp);
|
||||
auto it = endpoints.find(channelId);
|
||||
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
auto it = endpoints.find(*channelId);
|
||||
if (it == endpoints.end())
|
||||
return;
|
||||
it->second.receive = {};
|
||||
@@ -85,14 +137,20 @@ FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
|
||||
}
|
||||
|
||||
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
|
||||
auto endpointsOr = lookup(getChannelId(sendOp));
|
||||
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||
if (failed(channelId))
|
||||
return failure();
|
||||
auto endpointsOr = lookup(*channelId);
|
||||
if (failed(endpointsOr) || !endpointsOr->receive)
|
||||
return failure();
|
||||
return endpointsOr->receive;
|
||||
}
|
||||
|
||||
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
|
||||
auto endpointsOr = lookup(getChannelId(receiveOp));
|
||||
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||
if (failed(channelId))
|
||||
return failure();
|
||||
auto endpointsOr = lookup(*channelId);
|
||||
if (failed(endpointsOr) || !endpointsOr->send)
|
||||
return failure();
|
||||
return endpointsOr->send;
|
||||
|
||||
Reference in New Issue
Block a user