fix missed failing tests for channels

moderate refactor
This commit is contained in:
NiccoloN
2026-04-14 12:26:41 +02:00
parent 30ee9640d4
commit eade488d13
30 changed files with 115 additions and 50 deletions

View File

@@ -1,13 +1,13 @@
add_pim_library(OMPimPasses
CountInstructionPass.cpp
MessagePass.cpp
Pim/ConstantFolding/Common.cpp
Pim/ConstantFolding/Patterns/Constant.cpp
Pim/ConstantFolding/ConstantFoldingPass.cpp
Pim/ConstantFolding/Patterns/Subview.cpp
Pim/MaterializeConstantsPass.cpp
Pim/VerificationPass.cpp
Pim/EmitPimJsonPass.cpp
PimCodegen/HostConstantFolding/Common.cpp
PimCodegen/HostConstantFolding/Patterns/Constant.cpp
PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp
PimCodegen/HostConstantFolding/Patterns/Subview.cpp
PimCodegen/MaterializeHostConstantsPass.cpp
PimCodegen/VerificationPass.cpp
PimCodegen/EmitPimJsonPass.cpp
EXCLUDE_FROM_OM_LIBS

View File

@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
std::unique_ptr<mlir::Pass> createMergeComputeNodePass();
std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass();

View File

@@ -11,10 +11,10 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostConstantFoldingPass)
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
StringRef getArgument() const override { return "pim-host-constant-folding-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
LogicalResult initialize(MLIRContext* context) override {
@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<Modu
} // namespace
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
std::unique_ptr<Pass> createPimHostConstantFoldingPass() { return std::make_unique<HostConstantFoldingPass>(); }
} // namespace onnx_mlir

View File

@@ -31,10 +31,10 @@ static int64_t getValueSizeInBytes(Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
StringRef getArgument() const override { return "materialize-pim-constants"; }
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
StringRef getDescription() const override {
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
}
@@ -126,6 +126,8 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
} // namespace
std::unique_ptr<Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
return std::make_unique<MaterializeHostConstantsPass>();
}
} // namespace onnx_mlir