44 lines
1.1 KiB
C++
44 lines
1.1 KiB
C++
#pragma once
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
namespace onnx_mlir::spatial {
|
|
|
|
struct ChannelEndpoints {
|
|
SpatChannelSendOp send;
|
|
SpatChannelReceiveOp receive;
|
|
};
|
|
|
|
class Channels {
|
|
public:
|
|
using ChannelId = int64_t;
|
|
|
|
explicit Channels(mlir::func::FuncOp funcOp);
|
|
|
|
ChannelId allocate();
|
|
|
|
void insertSend(SpatChannelSendOp sendOp);
|
|
void insertReceive(SpatChannelReceiveOp receiveOp);
|
|
void eraseSend(SpatChannelSendOp sendOp);
|
|
void eraseReceive(SpatChannelReceiveOp receiveOp);
|
|
|
|
llvm::FailureOr<ChannelEndpoints> lookup(ChannelId id) const;
|
|
llvm::FailureOr<SpatChannelReceiveOp> getReceiveFor(SpatChannelSendOp sendOp) const;
|
|
llvm::FailureOr<SpatChannelSendOp> getSendFor(SpatChannelReceiveOp receiveOp) const;
|
|
|
|
mlir::LogicalResult verify() const;
|
|
|
|
private:
|
|
ChannelId nextChannelId = 0;
|
|
llvm::DenseMap<ChannelId, ChannelEndpoints> endpoints;
|
|
};
|
|
|
|
} // namespace onnx_mlir::spatial
|