add automatic patches to onnx-mlir with CMakeLists.txt

This commit is contained in:
NiccoloN
2026-02-25 13:00:55 +01:00
parent 77f815a7a2
commit 5ca8916f4f
10 changed files with 169 additions and 63 deletions

View File

@@ -30,4 +30,59 @@ raptor_ensure_symlink(
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM" "${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
) )
# Patch onnx-mlir sources for PIM accelerator support.
# Each patch searches for a context-aware anchor string rather than relying on
# line numbers, so that moderate upstream changes are tolerated.
#
# raptor_apply_patch(<file_path> <anchor> <replacement> <description>)
#
#   file_path   - file that is patched in place.
#   anchor      - literal text expected in the unpatched file.
#   replacement - literal text substituted for every occurrence of <anchor>.
#   description - short label used in the status and error messages.
#
# The call is idempotent: if <replacement> is already present the file is left
# untouched. Configuration stops with FATAL_ERROR when the file is missing or
# when neither <replacement> nor <anchor> can be found in it.
#
# Example:
#   raptor_apply_patch("${dir}/CMakeLists.txt" "foo(BAR)" "#foo(BAR)" "Disable foo")
function(raptor_apply_patch file_path anchor replacement description)
  # Resolve relative paths against the current source directory, which is the
  # same base file(READ) uses, so the checks below look at the file that is
  # actually read.
  get_filename_component(resolved_path "${file_path}" ABSOLUTE)

  # Report a missing target explicitly instead of relying on the generic
  # file(READ) failure, which does not say which patch was being applied.
  if(NOT EXISTS "${resolved_path}")
    message(FATAL_ERROR
      "Patch target not found; is the onnx-mlir checkout complete?\n"
      " Patch : ${description}\n"
      " File : ${file_path}"
    )
  endif()

  # Re-run the configure step whenever the target file changes (for example
  # when a checkout restores the pristine version), so the patch is re-applied
  # instead of being silently lost until the next manual reconfigure.
  set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${resolved_path}")

  file(READ "${file_path}" contents)

  # Already applied: the replacement text is present.
  string(FIND "${contents}" "${replacement}" already_applied_pos)
  if(NOT already_applied_pos EQUAL -1)
    message(STATUS "Patch already applied: ${description}")
    return()
  endif()

  # The anchor must exist for the patch to be applicable.
  string(FIND "${contents}" "${anchor}" anchor_pos)
  if(anchor_pos EQUAL -1)
    message(FATAL_ERROR
      "Patch anchor not found; onnx-mlir may have changed.\n"
      " Patch : ${description}\n"
      " File : ${file_path}\n"
      " Anchor: ${anchor}"
    )
  endif()

  string(REPLACE "${anchor}" "${replacement}" patched "${contents}")
  file(WRITE "${file_path}" "${patched}")
  message(STATUS "Patch applied: ${description}")
endfunction()
# Root of the onnx-mlir source tree that the patches below modify in place.
set(ONNX_MLIR_DIR "${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir")
# 1) Disable the absl dependency
#    Comments out the `find_package(absl REQUIRED)` call in the top-level
#    onnx-mlir CMakeLists.txt.
raptor_apply_patch(
"${ONNX_MLIR_DIR}/CMakeLists.txt"
"find_package(absl REQUIRED)"
"#find_package(absl REQUIRED)"
"Disable find_package(absl)"
)
# 2) Register PIM compiler options alongside NNPA
#    Keeps the existing NNPACompilerOptions.hpp include and adds an include of
#    PimCompilerOptions.hpp on the line right after it.
raptor_apply_patch(
"${ONNX_MLIR_DIR}/src/Accelerators/Accelerator.hpp"
"#include \"src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp\""
"#include \"src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp\"\n#include \"src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp\""
"Add PIM compiler options include"
)
# 3) Short-circuit output emission for the PIM accelerator
#    Inserts, just before the `switch (emissionTarget)` whose first case is
#    EmitObj, an early `return CompilerSuccess;` taken when `maccel` contains
#    the PIM accelerator kind.
#    NOTE(review): the anchor matches the exact whitespace between the switch
#    line and the first case; confirm it against the upstream file.
raptor_apply_patch(
"${ONNX_MLIR_DIR}/src/Compiler/CompilerUtils.cpp"
"switch (emissionTarget) {\n case EmitObj: {"
"if (llvm::is_contained(maccel, accel::Accelerator::Kind::PIM))\n return CompilerSuccess;\n switch (emissionTarget) {\n case EmitObj: {"
"Skip output emission for PIM accelerator"
)
add_subdirectory(onnx-mlir) add_subdirectory(onnx-mlir)

View File

@@ -18,8 +18,9 @@ add_subdirectory(Common)
add_onnx_mlir_library(OMPIMAccel add_onnx_mlir_library(OMPIMAccel
PimAccelerator.cpp PimAccelerator.cpp
Transforms/PimBufferizationPass.cpp Transforms/PimBufferizationPass.cpp
Pass/MessagePass.cpp
Pass/CountInstructionPass.cpp Pass/CountInstructionPass.cpp
Pass/EmitPimJsonPass.cpp
Pass/MessagePass.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -433,7 +433,7 @@ std::string getMemorySizeAsString(size_t size) {
return std::to_string(size) + " Bytes"; return std::to_string(size) + " Bytes";
} }
int compileModuleToPIMJSON(const OwningOpRef<ModuleOp>& moduleOpRef, std::string& outputDirPath) { OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(const OwningOpRef<ModuleOp>& moduleOpRef, std::string& outputDirPath) {
ModuleOp moduleOp = moduleOpRef.get(); ModuleOp moduleOp = moduleOpRef.get();
if (pimEmissionTarget != EmitPimCodegen) { if (pimEmissionTarget != EmitPimCodegen) {

View File

@@ -17,40 +17,44 @@
namespace onnx_mlir { namespace onnx_mlir {
llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget( llvm::cl::opt<std::string> pimOutputDir("pim-output-dir",
llvm::cl::desc("[Optional] Choose PIM-related target to emit " llvm::cl::desc("Directory where pim json code will be emitted"),
"(once selected it will cancel the other targets):"), llvm::cl::init("pim"),
llvm::cl::values(clEnumVal(EmitSpatial, "Lower model to spatial IR")), llvm::cl::cat(OnnxMlirOptions));
llvm::cl::values(clEnumVal(EmitPim, "Lower model to PIM IR")),
llvm::cl::values(
clEnumVal(EmitPimBufferized, "Lower model to PIM IR and bufferize it")),
llvm::cl::values(clEnumVal(EmitPimCodegen, "Lower model to PIM IR and "
"generate code for PIM")),
llvm::cl::init(EmitPimCodegen), llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimOnlyCodegen("pim-only-codegen", llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::desc("Only generate code for PIM (assume input is already in " llvm::cl::desc("[Optional] Choose PIM-related target to emit (once selected it will cancel the other targets):"),
"bufferized PIM IR)"), llvm::cl::values(clEnumVal(EmitSpatial, "Lower model to spatial IR")),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); llvm::cl::values(clEnumVal(EmitPim, "Lower model to PIM IR")),
llvm::cl::values(clEnumVal(EmitPimBufferized, "Lower model to PIM IR and bufferize it")),
llvm::cl::values(clEnumVal(EmitPimCodegen, "Lower model to PIM IR and generate code for PIM")),
llvm::cl::init(EmitPimCodegen),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool>
pimOnlyCodegen("pim-only-codegen",
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl", llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
llvm::cl::desc("Use experimental implementation for convolution"), llvm::cl::desc("Use experimental implementation for convolution"),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<size_t> crossbarSize("crossbar-size", llvm::cl::opt<size_t>
llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2)); crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<size_t> crossbarCountInCore("crossbar-count", llvm::cl::opt<size_t>
llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2)); crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2));
llvm::cl::opt<long> coresCount("core-count", llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum " llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
"amount of cores."), llvm::cl::init(-1));
llvm::cl::init(-1));
llvm::cl::opt<bool> ignoreConcatError("ignore-concat-error", llvm::cl::opt<bool>
llvm::cl::desc( ignoreConcatError("ignore-concat-error",
"Ignore ConcatOp corner case: do not assert and do a simplification"), llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
llvm::cl::init(false)); llvm::cl::init(false));
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -21,7 +21,8 @@ typedef enum {
} PimEmissionTargetType; } PimEmissionTargetType;
extern llvm::cl::OptionCategory OnnxMlirOptions; extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::opt<onnx_mlir::PimEmissionTargetType> pimEmissionTarget; extern llvm::cl::opt<std::string> pimOutputDir;
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<bool> pimOnlyCodegen; extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> useExperimentalConvImpl; extern llvm::cl::opt<bool> useExperimentalConvImpl;

View File

@@ -13,7 +13,7 @@ void addPassesPim(mlir::OwningOpRef<mlir::ModuleOp>& module,
EmissionTargetType& emissionTarget, EmissionTargetType& emissionTarget,
std::string outputNameNoExt); std::string outputNameNoExt);
int compileModuleToPIMJSON(const mlir::OwningOpRef<mlir::ModuleOp>& moduleOpRef, OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(const mlir::OwningOpRef<mlir::ModuleOp>& moduleOpRef,
std::string& outputDirName); std::string& outputDirName);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -0,0 +1,45 @@
#include "mlir/Pass/Pass.h"
#include <filesystem>
#include "Compiler/PimCompilerOptions.hpp"
#include "Compiler/PimCompilerUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Module-level pass that creates the directory named by the `pimOutputDir`
/// option and then calls compileModuleToPIMJSON() on the module being
/// processed. The pass is marked as failed when the directory cannot be
/// created or when the compile call does not return CompilerSuccess.
struct EmitPimJsonPass : PassWrapper<EmitPimJsonPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EmitPimJsonPass);

  /// Command-line argument under which the pass is registered.
  StringRef getArgument() const override { return "emit-pim-json-pass"; }

  /// One-line description of the pass.
  StringRef getDescription() const override { return "Emit json code for the pim simulators"; }

  EmitPimJsonPass() {}
  EmitPimJsonPass(const EmitPimJsonPass& pass) {}

  void runOnOperation() final {
    ModuleOp moduleOp = getOperation();

    // Make sure the output directory exists; the error_code overload reports
    // failures without throwing.
    std::filesystem::path pimDir(pimOutputDir.getValue());
    std::error_code error_code;
    std::filesystem::create_directories(pimDir, error_code);
    if (error_code) {
      moduleOp.emitError("Failed to create PIM output directory: " + error_code.message());
      signalPassFailure();
      return;
    }

    // compileModuleToPIMJSON() takes a `const OwningOpRef<ModuleOp>&`. Passing
    // `moduleOp` directly would materialise a temporary OwningOpRef through its
    // implicit constructor, and the temporary's destructor would erase the
    // module, which this pass does not own. Wrap it explicitly and release the
    // reference again once the call has returned.
    OwningOpRef<ModuleOp> borrowedModule(moduleOp);
    const auto compiler_error_code = compileModuleToPIMJSON(borrowedModule, pimOutputDir);
    (void) borrowedModule.release();

    if (compiler_error_code != CompilerSuccess)
      signalPassFailure();
  }
};
} // namespace
/// Factory returning a newly allocated EmitPimJsonPass as a generic Pass.
std::unique_ptr<Pass> createEmitPimJsonPass() {
  auto emitPass = std::make_unique<EmitPimJsonPass>();
  return emitPass;
}
} // namespace onnx_mlir

View File

@@ -15,6 +15,8 @@ std::unique_ptr<Pass> createSpatialToPIMPass();
std::unique_ptr<Pass> createBufferizePimPass(); std::unique_ptr<Pass> createBufferizePimPass();
std::unique_ptr<Pass> createEmitPimJsonPass();
std::unique_ptr<Pass> createMessagePass(std::string message); std::unique_ptr<Pass> createMessagePass(std::string message);
std::unique_ptr<Pass> createCountInstructionPass(); std::unique_ptr<Pass> createCountInstructionPass();

View File

@@ -8,7 +8,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
@@ -19,40 +19,40 @@
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Accelerators/PIM/PimAccelerator.hpp" #include "src/Accelerators/PIM/PimAccelerator.hpp"
#include <memory>
#define DEBUG_TYPE "PimAccelerator" #define DEBUG_TYPE "PimAccelerator"
namespace onnx_mlir { namespace onnx_mlir {
namespace accel { namespace accel {
Accelerator *createPIM() { return PimAccelerator::getInstance(); } Accelerator* createPIM() { return PimAccelerator::getInstance(); }
PimAccelerator *PimAccelerator::instance = nullptr; PimAccelerator* PimAccelerator::instance = nullptr;
PimAccelerator *PimAccelerator::getInstance() { PimAccelerator* PimAccelerator::getInstance() {
if (instance == nullptr) if (instance == nullptr)
instance = new PimAccelerator(); instance = new PimAccelerator();
return instance; return instance;
} }
PimAccelerator::PimAccelerator() : Accelerator(Accelerator::Kind::PIM) { PimAccelerator::PimAccelerator()
: Accelerator(Kind::PIM) {
LLVM_DEBUG(llvm::dbgs() << "Creating a PIM accelerator\n"); LLVM_DEBUG(llvm::dbgs() << "Creating a PIM accelerator\n");
acceleratorTargets.push_back(this); acceleratorTargets.push_back(this);
}; }
PimAccelerator::~PimAccelerator() { delete instance; } PimAccelerator::~PimAccelerator() { delete instance; }
uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; } uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; }
void PimAccelerator::addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module, void PimAccelerator::addPasses(OwningOpRef<ModuleOp>& module,
mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget, PassManager& pm,
std::string outputNameNoExt) const { EmissionTargetType& emissionTarget,
std::string outputNameNoExt) const {
LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n"); LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n");
addPassesPim(module, pm, emissionTarget, outputNameNoExt); addPassesPim(module, pm, emissionTarget, outputNameNoExt);
} }
void PimAccelerator::registerDialects(mlir::DialectRegistry &registry) const { void PimAccelerator::registerDialects(DialectRegistry& registry) const {
LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n"); LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n");
registry.insert<tensor::TensorDialect>(); registry.insert<tensor::TensorDialect>();
registry.insert<tosa::TosaDialect>(); registry.insert<tosa::TosaDialect>();
@@ -61,8 +61,7 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry &registry) const {
registry.insert<spatial::SpatialDialect>(); registry.insert<spatial::SpatialDialect>();
tensor::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry);
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
registry);
spatial::registerBufferizableOpInterfaceExternalModels(registry); spatial::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry); spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
pim::registerBufferizableOpInterfaceExternalModels(registry); pim::registerBufferizableOpInterfaceExternalModels(registry);
@@ -70,11 +69,11 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry &registry) const {
void PimAccelerator::registerPasses(int optLevel) const { void PimAccelerator::registerPasses(int optLevel) const {
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n"); LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
// Register here all the passes that could be used registerPass(createONNXToSpatialPass);
mlir::registerPass(createONNXToSpatialPass); registerPass(createSpatialToGraphvizPass);
mlir::registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToPIMPass);
mlir::registerPass(createSpatialToPIMPass); registerPass(createBufferizePimPass);
mlir::registerPass(createBufferizePimPass); registerPass(createEmitPimJsonPass);
} }
void PimAccelerator::configurePasses() const { void PimAccelerator::configurePasses() const {
@@ -82,27 +81,26 @@ void PimAccelerator::configurePasses() const {
// TODO: This does nothing for now. // TODO: This does nothing for now.
} }
mlir::MemRefType PimAccelerator::convertTensorTypeToMemRefType( MemRefType PimAccelerator::convertTensorTypeToMemRefType(const TensorType tensorType) const {
const mlir::TensorType tensorType) const {
// Do not convert tensor types to memref types. // Do not convert tensor types to memref types.
return nullptr; return nullptr;
} }
void PimAccelerator::conversionTargetONNXToKrnl( void PimAccelerator::conversionTargetONNXToKrnl(ConversionTarget& target) const {
mlir::ConversionTarget &target) const {
target.addLegalDialect<pim::PimDialect>(); target.addLegalDialect<pim::PimDialect>();
} }
void PimAccelerator::rewritePatternONNXToKrnl(mlir::RewritePatternSet &patterns, void PimAccelerator::rewritePatternONNXToKrnl(RewritePatternSet& patterns,
mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const { TypeConverter& typeConverter,
MLIRContext* ctx) const {
// TODO: Add patterns for conversion // TODO: Add patterns for conversion
} }
void PimAccelerator::conversionTargetKrnlToLLVM( void PimAccelerator::conversionTargetKrnlToLLVM(ConversionTarget& target) const {}
mlir::ConversionTarget &target) const {}
void PimAccelerator::rewritePatternKrnlToLLVM(mlir::RewritePatternSet &patterns, void PimAccelerator::rewritePatternKrnlToLLVM(RewritePatternSet& patterns,
mlir::LLVMTypeConverter &typeConverter, mlir::MLIRContext *ctx) const { LLVMTypeConverter& typeConverter,
MLIRContext* ctx) const {
// We should not need this, since we offload it all to PIM. // We should not need this, since we offload it all to PIM.
} }

View File

@@ -43,7 +43,7 @@ def build_dump_ranges(config_path, outputs_descriptor):
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges): def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges):
subprocess.run( subprocess.run(
["cargo", "run", "--package", "pim-simulator", "--bin", "pim-simulator", "--", ["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges], "-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
cwd=simulator_dir, check=True cwd=simulator_dir, check=True
) )