Files
Raptor/src/PIM/Pass/CountInstructionPass.cpp
NiccoloN 6e1de865bb add constant folding and verification pass for pim host operations
better validation scripts output
big refactors
2026-03-20 12:08:12 +01:00

61 lines
2.0 KiB
C++

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Module-level diagnostic pass that prints, to llvm::outs(), the number of
/// operations held by every spatial::SpatWeightedCompute and pim::PimCoreOp
/// found directly inside the first func::FuncOp of the module, followed by the
/// grand total.
///
/// Only the operations listed directly in the first block of each body are
/// counted: operations nested inside further regions are not visited, and
/// every entry of the block's operation list contributes to the count.
/// The pass does not modify the IR.
struct CountInstructionPass : public PassWrapper<CountInstructionPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CountInstructionPass)

  /// Command-line name under which the pass is registered.
  StringRef getArgument() const override { return "count-instruction-pass"; }

  /// One-line summary of what the pass does.
  StringRef getDescription() const override { return "Count instructions for each core/compute in the module"; }

  // Make sure that we have a valid default constructor and copy
  // constructor to make sure that the options are initialized properly.
  CountInstructionPass() {}
  CountInstructionPass(const CountInstructionPass& pass)
      : PassWrapper<CountInstructionPass, OperationPass<ModuleOp>>() {}

  /// Reports the per-compute and per-core operation counts of the first
  /// function in the module, then the overall total.
  void runOnOperation() final {
    ModuleOp module = getOperation();

    // The original code dereferenced begin() unconditionally; a module
    // without any function would have been undefined behaviour.
    auto funcs = module.getOps<func::FuncOp>();
    if (funcs.begin() == funcs.end()) {
      module.emitWarning("count-instruction-pass: module contains no function, nothing to count");
      llvm::outs() << "Total instruction count: 0\n";
      return;
    }
    func::FuncOp func = *funcs.begin();

    // Computes are reported first, then cores, then the total.
    unsigned totalInstructionCount = 0;
    totalInstructionCount += reportInstructionCounts<spatial::SpatWeightedCompute>(func, "Compute");
    totalInstructionCount += reportInstructionCounts<pim::PimCoreOp>(func, "Core");

    llvm::outs() << "Total instruction count: " << totalInstructionCount << "\n";
  }

private:
  /// Prints one "<label> <index>: <count> instructions" line for every op of
  /// type OpTy found directly inside `func`, numbering them from 0 in
  /// iteration order, and returns the sum of the printed counts.
  template <typename OpTy>
  static unsigned reportInstructionCounts(func::FuncOp func, StringRef label) {
    unsigned subtotal = 0;
    unsigned index = 0;
    for (auto op : func.getOps<OpTy>()) {
      // A body without any block has no operations; calling front() on it
      // would be invalid, so treat it as a count of zero.
      unsigned instructionCount = 0;
      if (!op.getBody().empty())
        instructionCount = static_cast<unsigned>(op.getBody().front().getOperations().size());

      llvm::outs() << label << " " << index << ": " << instructionCount << " instructions\n";
      subtotal += instructionCount;
      ++index;
    }
    return subtotal;
  }
};
} // namespace
/// Factory for the instruction-counting pass: returns a newly allocated
/// CountInstructionPass owned by the caller through the generic Pass handle.
std::unique_ptr<Pass> createCountInstructionPass() {
  std::unique_ptr<Pass> countingPass = std::make_unique<CountInstructionPass>();
  return countingPass;
}
} // namespace onnx_mlir