compact spatial IR through different new operations and dedicated syntax
fast spatial node merging with batch operations
This commit is contained in:
@@ -0,0 +1,120 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); }
|
||||
|
||||
static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); }
|
||||
|
||||
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
|
||||
if (!endpoints.send || !endpoints.receive)
|
||||
return failure();
|
||||
|
||||
if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) {
|
||||
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
|
||||
return failure();
|
||||
}
|
||||
if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) {
|
||||
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
|
||||
return failure();
|
||||
}
|
||||
if (endpoints.send.getInput().getType() != endpoints.receive.getOutput().getType()) {
|
||||
endpoints.send.emitOpError("input type does not match paired spat.channel_receive result type");
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Channels::Channels(func::FuncOp funcOp) {
|
||||
if (!funcOp)
|
||||
return;
|
||||
|
||||
funcOp.walk([&](SpatChannelSendOp sendOp) { insertSend(sendOp); });
|
||||
funcOp.walk([&](SpatChannelReceiveOp receiveOp) { insertReceive(receiveOp); });
|
||||
}
|
||||
|
||||
Channels::ChannelId Channels::allocate() { return nextChannelId++; }
|
||||
|
||||
void Channels::insertSend(SpatChannelSendOp sendOp) {
|
||||
ChannelId channelId = getChannelId(sendOp);
|
||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
||||
endpoints[channelId].send = sendOp;
|
||||
}
|
||||
|
||||
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
|
||||
ChannelId channelId = getChannelId(receiveOp);
|
||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
||||
endpoints[channelId].receive = receiveOp;
|
||||
}
|
||||
|
||||
void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
||||
ChannelId channelId = getChannelId(sendOp);
|
||||
auto it = endpoints.find(channelId);
|
||||
if (it == endpoints.end())
|
||||
return;
|
||||
it->second.send = {};
|
||||
if (!it->second.receive)
|
||||
endpoints.erase(it);
|
||||
}
|
||||
|
||||
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
|
||||
ChannelId channelId = getChannelId(receiveOp);
|
||||
auto it = endpoints.find(channelId);
|
||||
if (it == endpoints.end())
|
||||
return;
|
||||
it->second.receive = {};
|
||||
if (!it->second.send)
|
||||
endpoints.erase(it);
|
||||
}
|
||||
|
||||
FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
|
||||
auto it = endpoints.find(id);
|
||||
if (it == endpoints.end())
|
||||
return failure();
|
||||
return it->second;
|
||||
}
|
||||
|
||||
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
|
||||
auto endpointsOr = lookup(getChannelId(sendOp));
|
||||
if (failed(endpointsOr) || !endpointsOr->receive)
|
||||
return failure();
|
||||
return endpointsOr->receive;
|
||||
}
|
||||
|
||||
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
|
||||
auto endpointsOr = lookup(getChannelId(receiveOp));
|
||||
if (failed(endpointsOr) || !endpointsOr->send)
|
||||
return failure();
|
||||
return endpointsOr->send;
|
||||
}
|
||||
|
||||
LogicalResult Channels::verify() const {
|
||||
for (const auto& [channelId, pair] : endpoints) {
|
||||
if (!pair.send || !pair.receive) {
|
||||
if (pair.send) {
|
||||
auto sendOp = pair.send;
|
||||
sendOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_receive";
|
||||
}
|
||||
else if (pair.receive) {
|
||||
auto receiveOp = pair.receive;
|
||||
receiveOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_send";
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
if (failed(verifyEndpointPair(pair)))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
Reference in New Issue
Block a user