add PIM accelerator

This commit is contained in:
NiccoloN
2026-02-24 15:09:18 +01:00
parent b24a0df8d7
commit a6e928bdd7
67 changed files with 9109 additions and 1 deletions

View File

@@ -0,0 +1,19 @@
add_onnx_mlir_library(OMPIMCommon
PIMCommon.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT}
${PIM_ONNX_MLIR_SRC_ROOT}
${PIM_SRC_ROOT}
${PIM_BIN_ROOT}
${PIM_INCLUDE_PATH}
LINK_LIBS PUBLIC
onnx
OMPimCompilerUtils
SpatialOps
PimOps
)

View File

@@ -0,0 +1,67 @@
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
llvm::FailureOr<Operation *> getOtherEndOfChannel(
Operation *op, bool opIsReceive, RewriterBase &rewriter) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError(
"User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
// channelNewOp should have two users: `op` and a
// `ChannelSendOp`/`ChannelReceiveOp`
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
usersIterator++;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
usersIterator++;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"more than two found.");
return failure();
}
Operation *notOpUser;
if (firstUser == op) {
notOpUser = secondUser;
} else if (secondUser == op) {
notOpUser = firstUser;
} else {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"and one of them must be me, but"
"none of them is actually me.");
return failure();
}
if (opIsReceive) {
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelSendOp.");
return failure();
}
return notOpUser;
} else {
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelReceiveOp.");
return failure();
}
return notOpUser;
}
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,16 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/StringRef.h"
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME =
"pim.constant.should_allocate";
namespace onnx_mlir {
llvm::FailureOr<mlir::Operation *> getOtherEndOfChannel(
mlir::Operation *op, bool opIsReceive, mlir::RewriterBase &rewriter);
} // namespace onnx_mlir

View File

@@ -0,0 +1,44 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
template <typename T>
class AutoCleaningValueMap : public mlir::RewriterBase::ForwardingListener {
public:
llvm::DenseMap<mlir::Value, T> map;
AutoCleaningValueMap(mlir::OpBuilder::Listener listener)
: ForwardingListener(&listener) {}
void notifyOperationErased(mlir::Operation* op) override {
for (mlir::Value result : op->getResults())
map.erase(result);
}
void notifyBlockErased(mlir::Block* block) override {
for (mlir::BlockArgument arg : block->getArguments())
map.erase(arg);
}
};
template <typename T>
class NotErasableValueMap : public mlir::RewriterBase::ForwardingListener {
public:
llvm::DenseMap<mlir::Value, T> map;
NotErasableValueMap(mlir::OpBuilder::Listener listener)
: ForwardingListener(&listener) {}
void notifyOperationErased(mlir::Operation* op) override {
for (mlir::Value result : op->getResults())
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(result));
}
void notifyBlockErased(mlir::Block* block) override {
for (mlir::BlockArgument arg : block->getArguments())
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(arg));
}
};