replace helper-op cleanup with canonicalization

clean up PIM pattern naming
remove unused ValueMap.hpp
This commit is contained in:
NiccoloN
2026-03-23 17:13:54 +01:00
parent 50c545539b
commit 461bdd808d
12 changed files with 21 additions and 123 deletions

View File

@@ -8,9 +8,8 @@ add_pim_library(OMONNXToSpatial
Patterns/Math/MatMul.cpp
Patterns/NN/Pooling.cpp
Patterns/NN/ReduceMean.cpp
Patterns/Tensor/ONNXConcatToTensorConcat.cpp
Patterns/Tensor/ONNXReshapeToTensorReshape.cpp
Patterns/Tensor/RemoveUnusedHelperOps.cpp
Patterns/Tensor/Concat.cpp
Patterns/Tensor/Reshape.cpp
Utils/SpatialReducer.cpp
Utils/WeightSubdivider.cpp
Utils/AnnotateReplication.cpp

View File

@@ -1,7 +1,9 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"
@@ -117,12 +119,10 @@ void ONNXToSpatialPass::runOnOperation() {
}
}
// Remove trailing "helper ops" i.e. concat,img_concat,reshape.
RewritePatternSet removeUnusedHelperOpsPatterns(ctx);
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass());
if (failed(cleanupPM.run(moduleOp)))
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
annotateWeightsConstants(*entryFunc);

View File

@@ -17,8 +17,6 @@ void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns,
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir

View File

@@ -8,8 +8,8 @@ using namespace mlir;
namespace onnx_mlir {
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
ONNXConcatToTensorConcat(MLIRContext* ctx)
struct Concat : public OpConversionPattern<ONNXConcatOp> {
Concat(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
@@ -25,7 +25,7 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
};
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXConcatToTensorConcat>(ctx);
patterns.insert<Concat>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,35 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
template <typename OpTy, typename OpAdaptorTy>
struct RemoveUnusedHelperOps : OpRewritePattern<OpTy> {
RemoveUnusedHelperOps(MLIRContext* ctx)
: OpRewritePattern<OpTy>(ctx) {}
void initialize() { this->setHasBoundedRewriteRecursion(); }
LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final {
if (op.getResult().use_empty()) {
rewriter.eraseOp(op);
return success();
}
return failure();
}
};
void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<RemoveUnusedHelperOps<tensor::ConcatOp, tensor::ConcatOpAdaptor>>(ctx);
patterns.insert<RemoveUnusedHelperOps<spatial::SpatImgConcatOp, spatial::SpatImgConcatOpAdaptor>>(ctx);
patterns.insert<RemoveUnusedHelperOps<ONNXReshapeOp, ONNXReshapeOpAdaptor>>(ctx);
}
} // namespace onnx_mlir

View File

@@ -77,7 +77,7 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}
struct ONNXReshapeToTensorReshape : OpConversionPattern<ONNXReshapeOp> {
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
@@ -115,7 +115,7 @@ struct ONNXReshapeToTensorReshape : OpConversionPattern<ONNXReshapeOp> {
} // namespace
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXReshapeToTensorReshape>(ctx);
patterns.insert<Reshape>(ctx);
}
} // namespace onnx_mlir