refactor Pim passes layout

This commit is contained in:
NiccoloN
2026-03-23 20:14:59 +01:00
parent f869925b64
commit da01e6d697
18 changed files with 34 additions and 35 deletions

View File

@@ -16,7 +16,6 @@ add_pim_library(OMPimCompilerOptions
add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp
PimCodeGen.cpp
../Pass/EmitPimJsonPass.cpp
EXCLUDE_FROM_OM_LIBS

View File

@@ -3,7 +3,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerPasses.hpp"
#define DEBUG_TYPE "PimCompilerUtils"
@@ -45,10 +45,10 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
}
if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimConstantFoldingPass());
pm.addPass(createConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded"));
pm.addPass(createPimMaterializeConstantsPass());
pm.addPass(createPimVerificationPass());
pm.addPass(createMaterializeConstantsPass());
pm.addPass(createVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass());

View File

@@ -15,7 +15,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

View File

@@ -12,7 +12,7 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)

View File

@@ -24,7 +24,7 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;

View File

@@ -9,7 +9,7 @@
#include "Compiler/PimCodeGen.hpp"
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;

View File

@@ -1,12 +1,13 @@
add_pim_library(OMPimPasses
CountInstructionPass.cpp
MessagePass.cpp
PimConstantFolding/Common.cpp
PimConstantFolding/Patterns/Constant.cpp
PimConstantFolding/PimConstantFoldingPass.cpp
PimConstantFolding/Patterns/Subview.cpp
PimMaterializeConstantsPass.cpp
PimVerificationPass.cpp
Pim/ConstantFolding/Common.cpp
Pim/ConstantFolding/Patterns/Constant.cpp
Pim/ConstantFolding/ConstantFoldingPass.cpp
Pim/ConstantFolding/Patterns/Subview.cpp
Pim/MaterializeConstantsPass.cpp
Pim/VerificationPass.cpp
Pim/EmitPimJsonPass.cpp
EXCLUDE_FROM_OM_LIBS

View File

@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createBufferizePimPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass();
std::unique_ptr<mlir::Pass> createVerificationPass();
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();

View File

@@ -12,8 +12,8 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
@@ -48,6 +48,6 @@ struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPas
} // namespace
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<PimConstantFoldingPass>(); }
std::unique_ptr<Pass> createConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
} // namespace onnx_mlir

View File

@@ -31,9 +31,8 @@ static int64_t getValueSizeInBytes(Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
struct PimMaterializeConstantsPass
: PassWrapper<PimMaterializeConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimMaterializeConstantsPass)
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
StringRef getArgument() const override { return "materialize-pim-constants"; }
StringRef getDescription() const override {
@@ -128,8 +127,8 @@ struct PimMaterializeConstantsPass
} // namespace
std::unique_ptr<Pass> createPimMaterializeConstantsPass() {
return std::make_unique<PimMaterializeConstantsPass>();
std::unique_ptr<Pass> createMaterializeConstantsPass() {
return std::make_unique<MaterializeConstantsPass>();
}
} // namespace onnx_mlir

View File

@@ -43,16 +43,16 @@ static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
return false;
}
struct PimVerificationPass : PassWrapper<PimVerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimVerificationPass)
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
StringRef getArgument() const override { return "verify-pim-pass"; }
StringRef getDescription() const override {
return "Verify that bufferized PIM IR contains only explicit host/device transfers";
}
PimVerificationPass() {}
PimVerificationPass(const PimVerificationPass& pass) {}
VerificationPass() {}
VerificationPass(const VerificationPass& pass) {}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
@@ -189,6 +189,6 @@ private:
} // namespace
std::unique_ptr<Pass> createPimVerificationPass() { return std::make_unique<PimVerificationPass>(); }
std::unique_ptr<Pass> createVerificationPass() { return std::make_unique<VerificationPass>(); }
} // namespace onnx_mlir

View File

@@ -17,7 +17,7 @@
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Accelerators/PIM/PimAccelerator.hpp"
#define DEBUG_TYPE "PimAccelerator"
@@ -73,9 +73,9 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPimPass);
registerPass(createBufferizePimPass);
registerPass(createPimConstantFoldingPass);
registerPass(createPimMaterializeConstantsPass);
registerPass(createPimVerificationPass);
registerPass(createConstantFoldingPass);
registerPass(createMaterializeConstantsPass);
registerPass(createVerificationPass);
registerPass(createEmitPimJsonPass);
}