remove other dead logic
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-05-25 21:22:08 +02:00
parent 0f240af271
commit addfc8a86e
14 changed files with 10 additions and 462 deletions
+1 -1
View File
@@ -98,7 +98,7 @@ Supporting pieces:
core count, DCP window, experimental conv impl, concat error handling, …)
and `PimCodeGen` entry points.
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
- `src/PIM/Pass` — auxiliary passes (`MessagePass`)
and the `PIMPasses.h` registry used by `PimAccelerator`.
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
+6 -9
View File
@@ -19,23 +19,21 @@ void markWeightAlways(mlir::Operation* op) {
namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
template <typename VMMOpTy, typename ParentOpTy>
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex);
if (!weightArg)
return false;
bool found = false;
parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeight() == *weightArg;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeight() == *weightArg;
});
return found;
}
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
template <typename VMMOpTy, typename ParentOpTy>
void walkVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeight = [&](mlir::Value weight) {
@@ -49,7 +47,6 @@ void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpO
}
};
parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
}
@@ -63,7 +60,7 @@ bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false;
return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
return hasVmmWeightUse<spatial::SpatVMMOp>(computeOp, operandIndex);
}
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
-4
View File
@@ -30,20 +30,17 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitSpatial) {
pm.addPass(createONNXToSpatialPass());
pm.addPass(createMergeComputeNodesPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
}
if (pimEmissionTarget >= EmitPim) {
pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim"));
}
if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createPimBufferizationPass());
pm.addPass(createPimStaticMemoryCoalescingPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim bufferized"));
}
@@ -54,7 +51,6 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimCodePass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim code emitted"));
}
}
-10
View File
@@ -74,16 +74,6 @@ def PimHaltOp : PimOp<"halt", [Terminator]> {
}];
}
def PimYieldOp : PimOp<"yield", [Terminator]> {
let summary = "Yield results from a Pim region";
let arguments = (ins
Variadic<PimTensor>:$outputs
);
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//
-34
View File
@@ -227,40 +227,6 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
return parser.parseRegion(*body, regionArgs);
}
void PimYieldOp::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueSequence(printer, getOutputs());
printer.printOptionalAttrDict((*this)->getAttrs());
printer << " : ";
printCompressedTypeSequence(printer, getOutputs().getTypes());
}
ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand> outputs;
SmallVector<Type> outputTypes;
OpAsmParser::UnresolvedOperand firstOutput;
OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput);
if (firstOutputResult.has_value()) {
if (failed(*firstOutputResult))
return failure();
if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedOperandEntry(parser, outputs))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
return failure();
if (outputs.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match");
return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands);
}
void PimConcatOp::print(OpAsmPrinter& printer) {
printer << " axis " << getAxis() << " ";
printCompressedValueSequence(printer, getInputs());
-1
View File
@@ -2,7 +2,6 @@ add_onnx_mlir_dialect(Spatial spat)
add_onnx_mlir_dialect_doc(spat Spatial.td)
add_pim_library(SpatialOps
Channels.cpp
SpatialOps.cpp
SpatialOpsAsm.cpp
SpatialOpsVerify.cpp
-178
View File
@@ -1,178 +0,0 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
using namespace mlir;
namespace onnx_mlir::spatial {
namespace {
static FailureOr<int64_t> getConstantI64(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return constantValue.getSExtValue();
}
static FailureOr<int32_t> getConstantI32(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return static_cast<int32_t>(constantValue.getSExtValue());
}
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelSendOp sendOp) {
return getConstantI64(sendOp.getChannelId());
}
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelReceiveOp receiveOp) {
return getConstantI64(receiveOp.getChannelId());
}
static FailureOr<int32_t> getSourceCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getSourceCoreId()); }
static FailureOr<int32_t> getSourceCoreId(SpatChannelReceiveOp receiveOp) {
return getConstantI32(receiveOp.getSourceCoreId());
}
static FailureOr<int32_t> getTargetCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getTargetCoreId()); }
static FailureOr<int32_t> getTargetCoreId(SpatChannelReceiveOp receiveOp) {
return getConstantI32(receiveOp.getTargetCoreId());
}
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
if (!endpoints.send || !endpoints.receive)
return failure();
FailureOr<int32_t> sendSourceCoreId = getSourceCoreId(endpoints.send);
FailureOr<int32_t> receiveSourceCoreId = getSourceCoreId(endpoints.receive);
if (failed(sendSourceCoreId) || failed(receiveSourceCoreId)) {
endpoints.send.emitOpError("channel endpoints must use constant sourceCoreId operands");
return failure();
}
if (*sendSourceCoreId != *receiveSourceCoreId) {
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
return failure();
}
FailureOr<int32_t> sendTargetCoreId = getTargetCoreId(endpoints.send);
FailureOr<int32_t> receiveTargetCoreId = getTargetCoreId(endpoints.receive);
if (failed(sendTargetCoreId) || failed(receiveTargetCoreId)) {
endpoints.send.emitOpError("channel endpoints must use constant targetCoreId operands");
return failure();
}
if (*sendTargetCoreId != *receiveTargetCoreId) {
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
return failure();
}
if (endpoints.send.getInput().getType() != endpoints.receive.getOutput().getType()) {
endpoints.send.emitOpError("input type does not match paired spat.channel_receive result type");
return failure();
}
return success();
}
} // namespace
Channels::Channels(func::FuncOp funcOp) {
if (!funcOp)
return;
funcOp.walk([&](SpatChannelSendOp sendOp) { insertSend(sendOp); });
funcOp.walk([&](SpatChannelReceiveOp receiveOp) { insertReceive(receiveOp); });
}
Channels::ChannelId Channels::allocate() { return nextChannelId++; }
void Channels::insertSend(SpatChannelSendOp sendOp) {
FailureOr<ChannelId> channelId = getChannelId(sendOp);
if (failed(channelId))
return;
nextChannelId = std::max(nextChannelId, *channelId + 1);
endpoints[*channelId].send = sendOp;
}
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
if (failed(channelId))
return;
nextChannelId = std::max(nextChannelId, *channelId + 1);
endpoints[*channelId].receive = receiveOp;
}
void Channels::eraseSend(SpatChannelSendOp sendOp) {
FailureOr<ChannelId> channelId = getChannelId(sendOp);
if (failed(channelId))
return;
auto it = endpoints.find(*channelId);
if (it == endpoints.end())
return;
it->second.send = {};
if (!it->second.receive)
endpoints.erase(it);
}
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
if (failed(channelId))
return;
auto it = endpoints.find(*channelId);
if (it == endpoints.end())
return;
it->second.receive = {};
if (!it->second.send)
endpoints.erase(it);
}
FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
auto it = endpoints.find(id);
if (it == endpoints.end())
return failure();
return it->second;
}
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
FailureOr<ChannelId> channelId = getChannelId(sendOp);
if (failed(channelId))
return failure();
auto endpointsOr = lookup(*channelId);
if (failed(endpointsOr) || !endpointsOr->receive)
return failure();
return endpointsOr->receive;
}
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
if (failed(channelId))
return failure();
auto endpointsOr = lookup(*channelId);
if (failed(endpointsOr) || !endpointsOr->send)
return failure();
return endpointsOr->send;
}
LogicalResult Channels::verify() const {
for (const auto& [channelId, pair] : endpoints) {
if (!pair.send || !pair.receive) {
if (pair.send) {
auto sendOp = pair.send;
sendOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_receive";
}
else if (pair.receive) {
auto receiveOp = pair.receive;
receiveOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_send";
}
return failure();
}
if (failed(verifyEndpointPair(pair)))
return failure();
}
return success();
}
} // namespace onnx_mlir::spatial
-43
View File
@@ -1,43 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir::spatial {
struct ChannelEndpoints {
SpatChannelSendOp send;
SpatChannelReceiveOp receive;
};
class Channels {
public:
using ChannelId = int64_t;
explicit Channels(mlir::func::FuncOp funcOp);
ChannelId allocate();
void insertSend(SpatChannelSendOp sendOp);
void insertReceive(SpatChannelReceiveOp receiveOp);
void eraseSend(SpatChannelSendOp sendOp);
void eraseReceive(SpatChannelReceiveOp receiveOp);
llvm::FailureOr<ChannelEndpoints> lookup(ChannelId id) const;
llvm::FailureOr<SpatChannelReceiveOp> getReceiveFor(SpatChannelSendOp sendOp) const;
llvm::FailureOr<SpatChannelSendOp> getSendFor(SpatChannelReceiveOp receiveOp) const;
mlir::LogicalResult verify() const;
private:
ChannelId nextChannelId = 0;
llvm::DenseMap<ChannelId, ChannelEndpoints> endpoints;
};
} // namespace onnx_mlir::spatial
-35
View File
@@ -217,25 +217,6 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
}];
}
def SpatMVMOp : SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a weighted compute operation";
let arguments = (ins
SpatTensor:$weight,
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
}];
}
def SpatVAddOp : SpatOp<"vadd", []> {
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
@@ -272,22 +253,6 @@ def SpatVMulOp : SpatOp<"vmul", []> {
}];
}
def SpatSumOp : SpatOp<"sum", []> {
let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";
let arguments = (ins
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatVAvgOp : SpatOp<"vavg", []> {
let summary = "Average all elements of the input tensor to a single scalar wrapped in a tensor";
@@ -26,68 +26,6 @@ namespace spatial {
namespace {
inline LogicalResult mvmOpVerifySize2(SpatMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitter->emitError("matrix, vector and output must have rank 2");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");
int64_t vectorM = vectorShape[0];
int64_t vector1 = vectorShape[1];
if (vectorM != M || vector1 != 1)
return emitter->emitError("vector shape must be (M, 1)");
int64_t outputN = outputShape[0];
int64_t output1 = outputShape[1];
if (outputN != N || output1 != 1)
return emitter->emitError("output shape must be (N, 1)");
return success();
}
inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4)
return emitter->emitError("matrix, vector and output must have rank 4");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
int64_t matrix1First = matrixShape[2];
int64_t matrix1Second = matrixShape[3];
if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1)
return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");
int64_t vector1First = vectorShape[0];
int64_t vectorM = vectorShape[1];
int64_t vector1Second = vectorShape[2];
int64_t vector1Third = vectorShape[3];
if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) {
if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) {
// This is ok, it was caused by the simplification of the concat error.
}
else {
return emitter->emitError("vector shape must be (1, M, 1, 1)");
}
}
int64_t output1First = outputShape[0];
int64_t outputN = outputShape[1];
int64_t output1Second = outputShape[2];
int64_t output1Third = outputShape[3];
if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1)
return emitter->emitError("output shape must be (1, N, 1, 1)");
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
auto shapedType = dyn_cast<ShapedType>(weight.getType());
if (!shapedType)
@@ -287,21 +225,6 @@ static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
} // namespace
LogicalResult SpatMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
if (failed(matrixShapeOpt))
return emitError("weight must be a shaped value");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
if (matrixShape.size() == 2)
return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
if (matrixShape.size() == 4)
return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
return emitError("matrix rank must be 2 or 4");
}
LogicalResult SpatVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
if (failed(matrixShapeOpt))
@@ -136,7 +136,6 @@ static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
struct ComputeMotifInfo {
uint64_t instructionCount = 0;
uint64_t weightedMvmCount = 0;
uint64_t weightedVmmCount = 0;
};
@@ -285,8 +284,6 @@ void emitMotifProfile(func::FuncOp funcOp) {
ComputeMotifInfo& info = computeInfos[index];
for (Operation& op : compute.getBody().front()) {
info.instructionCount++;
if (isa<spatial::SpatMVMOp>(&op))
info.weightedMvmCount++;
if (isa<spatial::SpatVMMOp>(&op))
info.weightedVmmCount++;
}
@@ -400,7 +397,7 @@ void emitMotifProfile(func::FuncOp funcOp) {
wideWeightedVmmLevels256 += count >= 256;
}
using ShapeKey = std::tuple<uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t>;
using ShapeKey = std::tuple<uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t>;
SmallVector<ShapeKey> weightedVmmShapeKeys;
for (auto [index, compute] : llvm::enumerate(computes)) {
const ComputeMotifInfo& info = computeInfos[index];
@@ -408,7 +405,6 @@ void emitMotifProfile(func::FuncOp funcOp) {
continue;
weightedVmmShapeKeys.push_back({info.instructionCount,
info.weightedVmmCount,
info.weightedMvmCount,
static_cast<uint64_t>(compute.getWeights().size()),
static_cast<uint64_t>(compute.getInputs().size()),
static_cast<uint64_t>(parents[index].size()),
@@ -461,14 +457,13 @@ void emitMotifProfile(func::FuncOp funcOp) {
for (size_t rank = 0, end = std::min<size_t>(weightedVmmShapeCounts.size(), 5); rank < end; ++rank) {
auto [count, shape] = weightedVmmShapeCounts[rank];
auto [insts, vmmOps, mvmOps, weights, inputs, fanIn, fanOut] = shape;
auto [insts, vmmOps, weights, inputs, fanIn, fanOut] = shape;
llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmShape rank={0} count={1} insts={2} vmmOps={3} "
"mvmOps={4} weights={5} inputs={6} fanIn={7} fanOut={8}\n",
"weights={4} inputs={5} fanIn={6} fanOut={7}\n",
rank,
count,
insts,
vmmOps,
mvmOps,
weights,
inputs,
fanIn,
-1
View File
@@ -1,5 +1,4 @@
add_pim_library(OMPimPasses
CountInstructionPass.cpp
MessagePass.cpp
PimCodegen/HostConstantFolding/Common.cpp
PimCodegen/HostConstantFolding/Patterns/Constant.cpp
-59
View File
@@ -1,59 +0,0 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct CountInstructionPass : public PassWrapper<CountInstructionPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CountInstructionPass)
StringRef getArgument() const override { return "count-instruction-pass"; }
StringRef getDescription() const override { return "Count instructions for each core/compute in the module"; }
// Make sure that we have a valid default constructor and copy
// constructor to make sure that the options are initialized properly.
CountInstructionPass() {}
CountInstructionPass(const CountInstructionPass& pass)
: PassWrapper<CountInstructionPass, OperationPass<ModuleOp>>() {}
void runOnOperation() final {
ModuleOp module = getOperation();
func::FuncOp func = *module.getOps<func::FuncOp>().begin();
unsigned totalInstructionCount = 0;
unsigned computeId = 0;
for (auto computeOp : func.getOps<spatial::SpatCompute>()) {
unsigned instructionCount = 0;
instructionCount += computeOp.getBody().front().getOperations().size();
llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n";
totalInstructionCount += instructionCount;
computeId++;
}
unsigned coreId = 0;
for (auto coreOp : func.getOps<pim::PimCoreOp>()) {
unsigned instructionCount = 0;
instructionCount += coreOp.getBody().front().getOperations().size();
llvm::outs() << "Core " << coreId << ": " << instructionCount << " instructions\n";
totalInstructionCount += instructionCount;
coreId++;
}
llvm::outs() << "Total instruction count: " << totalInstructionCount << "\n";
}
};
} // namespace
std::unique_ptr<Pass> createCountInstructionPass() { return std::make_unique<CountInstructionPass>(); }
} // namespace onnx_mlir
-2
View File
@@ -29,6 +29,4 @@ std::unique_ptr<mlir::Pass> createEmitPimCodePass();
std::unique_ptr<mlir::Pass> createMessagePass(std::string message);
std::unique_ptr<mlir::Pass> createCountInstructionPass();
} // namespace onnx_mlir