refactor Pim passes layout

This commit is contained in:
NiccoloN
2026-03-23 20:14:59 +01:00
parent f869925b64
commit da01e6d697
18 changed files with 34 additions and 35 deletions

View File

@@ -16,7 +16,6 @@ add_pim_library(OMPimCompilerOptions
add_pim_library(OMPimCompilerUtils add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp PimCompilerUtils.cpp
PimCodeGen.cpp PimCodeGen.cpp
../Pass/EmitPimJsonPass.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -3,7 +3,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerPasses.hpp"
#define DEBUG_TYPE "PimCompilerUtils" #define DEBUG_TYPE "PimCompilerUtils"
@@ -45,10 +45,10 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
} }
if (pimEmissionTarget >= EmitPimCodegen) { if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimConstantFoldingPass()); pm.addPass(createConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded")); pm.addPass(createMessagePass("Pim constants folded"));
pm.addPass(createPimMaterializeConstantsPass()); pm.addPass(createMaterializeConstantsPass());
pm.addPass(createPimVerificationPass()); pm.addPass(createVerificationPass());
pm.addPass(createMessagePass("Pim verified")); pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass()); pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());

View File

@@ -15,7 +15,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"

View File

@@ -12,7 +12,7 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0) #define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)

View File

@@ -24,7 +24,7 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
using namespace mlir; using namespace mlir;

View File

@@ -9,7 +9,7 @@
#include "Compiler/PimCodeGen.hpp" #include "Compiler/PimCodeGen.hpp"
#include "Dialect/Pim/PimOps.hpp" #include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
using namespace mlir; using namespace mlir;

View File

@@ -1,12 +1,13 @@
add_pim_library(OMPimPasses add_pim_library(OMPimPasses
CountInstructionPass.cpp CountInstructionPass.cpp
MessagePass.cpp MessagePass.cpp
PimConstantFolding/Common.cpp Pim/ConstantFolding/Common.cpp
PimConstantFolding/Patterns/Constant.cpp Pim/ConstantFolding/Patterns/Constant.cpp
PimConstantFolding/PimConstantFoldingPass.cpp Pim/ConstantFolding/ConstantFoldingPass.cpp
PimConstantFolding/Patterns/Subview.cpp Pim/ConstantFolding/Patterns/Subview.cpp
PimMaterializeConstantsPass.cpp Pim/MaterializeConstantsPass.cpp
PimVerificationPass.cpp Pim/VerificationPass.cpp
Pim/EmitPimJsonPass.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createBufferizePimPass(); std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass(); std::unique_ptr<mlir::Pass> createConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass(); std::unique_ptr<mlir::Pass> createMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass(); std::unique_ptr<mlir::Pass> createVerificationPass();
std::unique_ptr<mlir::Pass> createEmitPimJsonPass(); std::unique_ptr<mlir::Pass> createEmitPimJsonPass();

View File

@@ -12,8 +12,8 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> { struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
StringRef getArgument() const override { return "pim-constant-folding-pass"; } StringRef getArgument() const override { return "pim-constant-folding-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; } StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
@@ -48,6 +48,6 @@ struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPas
} // namespace } // namespace
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); } std::unique_ptr<Pass> createConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -31,9 +31,8 @@ static int64_t getValueSizeInBytes(Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8; return type.getNumElements() * type.getElementTypeBitWidth() / 8;
} }
struct PimMaterializeConstantsPass struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
: PassWrapper<PimMaterializeConstantsPass, OperationPass<ModuleOp>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass)
StringRef getArgument() const override { return "materialize-pim-constants"; } StringRef getArgument() const override { return "materialize-pim-constants"; }
StringRef getDescription() const override { StringRef getDescription() const override {
@@ -128,8 +127,8 @@ struct PimMaterializeConstantsPass
} // namespace } // namespace
std::unique_ptr<Pass> createPimMaterializeConstantsPass() { std::unique_ptr<Pass> createMaterializeConstantsPass() {
return std::make_unique<PimMaterializeConstantsPass>(); return std::make_unique<MaterializeConstantsPass>();
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -43,16 +43,16 @@ static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
return false; return false;
} }
struct PimVerificationPass : PassWrapper<PimVerificationPass, OperationPass<ModuleOp>> { struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
StringRef getArgument() const override { return "verify-pim-pass"; } StringRef getArgument() const override { return "verify-pim-pass"; }
StringRef getDescription() const override { StringRef getDescription() const override {
return "Verify that bufferized PIM IR contains only explicit host/device transfers"; return "Verify that bufferized PIM IR contains only explicit host/device transfers";
} }
PimVerificationPass() {} VerificationPass() {}
PimVerificationPass(const PimVerificationPass& pass) {} VerificationPass(const VerificationPass& pass) {}
void runOnOperation() override { void runOnOperation() override {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
@@ -189,6 +189,6 @@ private:
} // namespace } // namespace
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<PimVerificationPass>(); } std::unique_ptr<Pass> createVerificationPass() { return std::make_unique<VerificationPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -17,7 +17,7 @@
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Accelerators/PIM/PimAccelerator.hpp" #include "src/Accelerators/PIM/PimAccelerator.hpp"
#define DEBUG_TYPE "PimAccelerator" #define DEBUG_TYPE "PimAccelerator"
@@ -73,9 +73,9 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPimPass); registerPass(createSpatialToPimPass);
registerPass(createBufferizePimPass); registerPass(createBufferizePimPass);
registerPass(createPimConstantFoldingPass); registerPass(createConstantFoldingPass);
registerPass(createPimMaterializeConstantsPass); registerPass(createMaterializeConstantsPass);
registerPass(createPimVerificationPass); registerPass(createVerificationPass);
registerPass(createEmitPimJsonPass); registerPass(createEmitPimJsonPass);
} }