From 461bdd808d52e41b060826d78a0a133959c38e3e Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 23 Mar 2026 17:13:54 +0100 Subject: [PATCH] replace helper-op cleanup with canonicalization clean up PIM pattern naming remove unused ValueMap.hpp --- src/PIM/Common/ValueMap.hpp | 44 ------------------- src/PIM/Compiler/PimCodeGen.hpp | 2 - .../Conversion/ONNXToSpatial/CMakeLists.txt | 5 +-- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 12 ++--- src/PIM/Conversion/ONNXToSpatial/Patterns.hpp | 2 - ...NNXConcatToTensorConcat.cpp => Concat.cpp} | 6 +-- .../Patterns/Tensor/RemoveUnusedHelperOps.cpp | 35 --------------- ...ReshapeToTensorReshape.cpp => Reshape.cpp} | 4 +- .../SpatialToPim/SpatialToPimPass.cpp | 30 +++---------- src/PIM/Pass/CMakeLists.txt | 4 +- .../{ConstantPatterns.cpp => Constant.cpp} | 0 .../{SubviewPatterns.cpp => Subview.cpp} | 0 12 files changed, 21 insertions(+), 123 deletions(-) delete mode 100644 src/PIM/Common/ValueMap.hpp rename src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/{ONNXConcatToTensorConcat.cpp => Concat.cpp} (81%) delete mode 100644 src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/RemoveUnusedHelperOps.cpp rename src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/{ONNXReshapeToTensorReshape.cpp => Reshape.cpp} (96%) rename src/PIM/Pass/PimConstantFolding/Patterns/{ConstantPatterns.cpp => Constant.cpp} (100%) rename src/PIM/Pass/PimConstantFolding/Patterns/{SubviewPatterns.cpp => Subview.cpp} (100%) diff --git a/src/PIM/Common/ValueMap.hpp b/src/PIM/Common/ValueMap.hpp deleted file mode 100644 index c8e249a..0000000 --- a/src/PIM/Common/ValueMap.hpp +++ /dev/null @@ -1,44 +0,0 @@ -#pragma once - -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" - -#include "llvm/ADT/DenseMap.h" - -template -class AutoCleaningValueMap : public mlir::RewriterBase::ForwardingListener { -public: - llvm::DenseMap map; - - AutoCleaningValueMap(mlir::OpBuilder::Listener listener) - : ForwardingListener(&listener) {} - - void notifyOperationErased(mlir::Operation* op) override { - for (mlir::Value result : op->getResults()) - map.erase(result); - } - - void notifyBlockErased(mlir::Block* block) override { - for (mlir::BlockArgument arg : block->getArguments()) - map.erase(arg); - } -}; - -template -class NotErasableValueMap : public mlir::RewriterBase::ForwardingListener { -public: - llvm::DenseMap map; - - NotErasableValueMap(mlir::OpBuilder::Listener listener) - : ForwardingListener(&listener) {} - - void notifyOperationErased(mlir::Operation* op) override { - for (mlir::Value result : op->getResults()) - assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(result)); - } - - void notifyBlockErased(mlir::Block* block) override { - for (mlir::BlockArgument arg : block->getArguments()) - assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(arg)); - } -}; diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index df5b170..38c4bc0 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -4,8 +4,6 @@ #include "llvm/Support/JSON.h" #include "onnx-mlir/Compiler/OMCompilerTypes.h" - -#include "Common/ValueMap.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" namespace onnx_mlir { diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index ce76d78..2ecfe24 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -8,9 +8,8 @@ add_pim_library(OMONNXToSpatial Patterns/Math/MatMul.cpp Patterns/NN/Pooling.cpp Patterns/NN/ReduceMean.cpp - Patterns/Tensor/ONNXConcatToTensorConcat.cpp - Patterns/Tensor/ONNXReshapeToTensorReshape.cpp - Patterns/Tensor/RemoveUnusedHelperOps.cpp + Patterns/Tensor/Concat.cpp + Patterns/Tensor/Reshape.cpp Utils/SpatialReducer.cpp Utils/WeightSubdivider.cpp Utils/AnnotateReplication.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index b6faf7d..bb929e1 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -1,7 +1,9 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_os_ostream.h" @@ -117,12 +119,10 @@ void ONNXToSpatialPass::runOnOperation() { } } - // Remove trailing "helper ops" i.e. concat,img_concat,reshape. - RewritePatternSet removeUnusedHelperOpsPatterns(ctx); - populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx); - - if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns)))) - llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n"; + PassManager cleanupPM(ctx); + cleanupPM.addPass(createCanonicalizerPass()); + if (failed(cleanupPM.run(moduleOp))) + llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n"; annotateWeightsConstants(*entryFunc); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp index 1b7ffaa..4311851 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp @@ -17,8 +17,6 @@ void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/ONNXConcatToTensorConcat.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp similarity index 81% rename from src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/ONNXConcatToTensorConcat.cpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp index a73c547..36ffc95 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/ONNXConcatToTensorConcat.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp @@ -8,8 +8,8 @@ using namespace mlir; namespace onnx_mlir { -struct ONNXConcatToTensorConcat : public OpConversionPattern { - ONNXConcatToTensorConcat(MLIRContext* ctx) +struct Concat : public OpConversionPattern { + Concat(MLIRContext* ctx) : OpConversionPattern(ctx) {} LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp, @@ -25,7 +25,7 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern { }; void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert(ctx); + patterns.insert(ctx); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/RemoveUnusedHelperOps.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/RemoveUnusedHelperOps.cpp deleted file mode 100644 index d609c14..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/RemoveUnusedHelperOps.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/PatternMatch.h" - -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { - -template -struct RemoveUnusedHelperOps : OpRewritePattern { - RemoveUnusedHelperOps(MLIRContext* ctx) - : OpRewritePattern(ctx) {} - - void initialize() { this->setHasBoundedRewriteRecursion(); } - - LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final { - if (op.getResult().use_empty()) { - rewriter.eraseOp(op); - return success(); - } - - return failure(); - } -}; - -void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert>(ctx); - patterns.insert>(ctx); - patterns.insert>(ctx); -} - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/ONNXReshapeToTensorReshape.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp similarity index 96% rename from src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/ONNXReshapeToTensorReshape.cpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp index ca023aa..dcf200d 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/ONNXReshapeToTensorReshape.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp @@ -77,7 +77,7 @@ static bool inferExpandReassociation(ArrayRef sourceShape, return sourceIdx == sourceShape.size() && resultIdx == resultShape.size(); } -struct ONNXReshapeToTensorReshape : OpConversionPattern { +struct Reshape : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp, @@ -115,7 +115,7 @@ struct ONNXReshapeToTensorReshape : OpConversionPattern { } // namespace void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert(ctx); + patterns.insert(ctx); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index f614179..5bfd8f8 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -57,10 +57,8 @@ private: LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter); - void addReceiveOps(Value channelSourceOp, - spatial::SpatChannelNewOp& channel, - bool useBroadcastOp, - IRRewriter& rewriter); + void + addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter); void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, unsigned int argIndex, Value channelSourceOp, @@ -199,29 +197,11 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR llvm_unreachable("ComputeOp must have same number of results as yieldOp operands"); for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) { - // If this result has no uses, then just skip it if (result.use_empty()) continue; auto yieldType = cast(yieldValue.getType()); - /* - * Here we assume that ReturnOp are only reachable by the following patterns: - * - * 1) - * %0 = spat.compute([...]) - * [%0 has one user, which is a ConcatOp] - * %1 = tensor.concat(%0) - * [%1 has one user, which is a ReturnOp] - * return %1 - * - * 2) - * %0 = spat.compute([...]) - * [%0 has one user, which is a ReturnOp] - * return %0 - * - * If the IR is like 2), then we can store the tensor to the output global memory location - */ auto resultUses = result.getUses(); auto numResultUses = rangeLength(resultUses); if (numResultUses == 1) { @@ -549,7 +529,8 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu receivedValue = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel); else - receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel); + receivedValue = + spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel); Value replacementValue = receivedValue; if (consumerValue != channelSourceOp) { @@ -577,7 +558,8 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu replacementValue = cast(mapping.lookup(consumerValue)); } - assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type"); + assert(replacementValue.getType() == blockArg.getType() + && "Replayed channel use chain must match block argument type"); blockArg.replaceAllUsesWith(replacementValue); } diff --git a/src/PIM/Pass/CMakeLists.txt b/src/PIM/Pass/CMakeLists.txt index 48c66ec..8c4dfd6 100644 --- a/src/PIM/Pass/CMakeLists.txt +++ b/src/PIM/Pass/CMakeLists.txt @@ -2,9 +2,9 @@ add_pim_library(OMPimPasses CountInstructionPass.cpp MessagePass.cpp PimConstantFolding/Common.cpp - PimConstantFolding/Patterns/ConstantPatterns.cpp + PimConstantFolding/Patterns/Constant.cpp PimConstantFolding/PimConstantFoldingPass.cpp - PimConstantFolding/Patterns/SubviewPatterns.cpp + PimConstantFolding/Patterns/Subview.cpp PimHostVerificationPass.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Pass/PimConstantFolding/Patterns/ConstantPatterns.cpp b/src/PIM/Pass/PimConstantFolding/Patterns/Constant.cpp similarity index 100% rename from src/PIM/Pass/PimConstantFolding/Patterns/ConstantPatterns.cpp rename to src/PIM/Pass/PimConstantFolding/Patterns/Constant.cpp diff --git a/src/PIM/Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp b/src/PIM/Pass/PimConstantFolding/Patterns/Subview.cpp similarity index 100% rename from src/PIM/Pass/PimConstantFolding/Patterns/SubviewPatterns.cpp rename to src/PIM/Pass/PimConstantFolding/Patterns/Subview.cpp