refactor Pim passes layout

This commit is contained in:
NiccoloN
2026-03-23 20:14:59 +01:00
parent f869925b64
commit da01e6d697
18 changed files with 34 additions and 35 deletions

View File

@@ -1,12 +1,13 @@
add_pim_library(OMPimPasses
CountInstructionPass.cpp
MessagePass.cpp
PimConstantFolding/Common.cpp
PimConstantFolding/Patterns/Constant.cpp
PimConstantFolding/PimConstantFoldingPass.cpp
PimConstantFolding/Patterns/Subview.cpp
PimMaterializeConstantsPass.cpp
PimVerificationPass.cpp
Pim/ConstantFolding/Common.cpp
Pim/ConstantFolding/Patterns/Constant.cpp
Pim/ConstantFolding/ConstantFoldingPass.cpp
Pim/ConstantFolding/Patterns/Subview.cpp
Pim/MaterializeConstantsPass.cpp
Pim/VerificationPass.cpp
Pim/EmitPimJsonPass.cpp
EXCLUDE_FROM_OM_LIBS

View File

@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass();
std::unique_ptr<mlir::Pass> createVerificationPass();
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();

View File

@@ -12,8 +12,8 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
@@ -48,6 +48,6 @@ struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPas
} // namespace
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
std::unique_ptr<Pass> createConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
} // namespace onnx_mlir

View File

@@ -31,9 +31,8 @@ static int64_t getValueSizeInBytes(Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
struct PimMaterializeConstantsPass
: PassWrapper<PimMaterializeConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass)
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
StringRef getArgument() const override { return "materialize-pim-constants"; }
StringRef getDescription() const override {
@@ -128,8 +127,8 @@ struct PimMaterializeConstantsPass
} // namespace
std::unique_ptr<Pass> createPimMaterializeConstantsPass() {
return std::make_unique<PimMaterializeConstantsPass>();
std::unique_ptr<Pass> createMaterializeConstantsPass() {
return std::make_unique<MaterializeConstantsPass>();
}
} // namespace onnx_mlir

View File

@@ -43,16 +43,16 @@ static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
return false;
}
struct PimVerificationPass : PassWrapper<PimVerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass)
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
StringRef getArgument() const override { return "verify-pim-pass"; }
StringRef getDescription() const override {
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
}
PimVerificationPass() {}
PimVerificationPass(const PimVerificationPass& pass) {}
VerificationPass() {}
VerificationPass(const VerificationPass& pass) {}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
@@ -189,6 +189,6 @@ private:
} // namespace
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<PimVerificationPass>(); }
std::unique_ptr<Pass> createVerificationPass() { return std::make_unique<VerificationPass>(); }
} // namespace onnx_mlir