replace helper-op cleanup with canonicalization

clean up PIM pattern naming
remove unused ValueMap.hpp
This commit is contained in:
NiccoloN
2026-03-23 17:13:54 +01:00
parent 50c545539b
commit 461bdd808d
12 changed files with 21 additions and 123 deletions

View File

@@ -1,44 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
// Maps mlir::Value keys to T while listening to the rewriter's IR events:
// whenever an operation or block is erased, any entries keyed by its
// results / block arguments are dropped, so the map never holds dangling
// Values.
template <typename T>
class AutoCleaningValueMap : public mlir::RewriterBase::ForwardingListener {
public:
  llvm::DenseMap<mlir::Value, T> map;

  // Takes the wrapped listener by reference. The previous by-value
  // parameter sliced the polymorphic Listener and stored the address of
  // the parameter itself, leaving ForwardingListener with a dangling
  // pointer once the constructor returned. The caller must keep
  // `listener` alive for the lifetime of this object.
  AutoCleaningValueMap(mlir::OpBuilder::Listener& listener)
      : ForwardingListener(&listener) {}

  void notifyOperationErased(mlir::Operation* op) override {
    // Drop entries keyed by the erased op's results.
    for (mlir::Value result : op->getResults())
      map.erase(result);
    // Still forward the notification to the wrapped listener; the
    // original override silently swallowed it.
    ForwardingListener::notifyOperationErased(op);
  }

  void notifyBlockErased(mlir::Block* block) override {
    // Drop entries keyed by the erased block's arguments.
    for (mlir::BlockArgument arg : block->getArguments())
      map.erase(arg);
    ForwardingListener::notifyBlockErased(block);
  }
};
// Maps mlir::Value keys to T and asserts (debug builds only) that no
// Value currently used as a key is ever erased from the IR. Useful when
// the map's contents must stay valid for the whole rewrite.
template <typename T>
class NotErasableValueMap : public mlir::RewriterBase::ForwardingListener {
public:
  llvm::DenseMap<mlir::Value, T> map;

  // Takes the wrapped listener by reference. The previous by-value
  // parameter sliced the polymorphic Listener and stored the address of
  // the parameter itself, leaving ForwardingListener with a dangling
  // pointer once the constructor returned. The caller must keep
  // `listener` alive for the lifetime of this object.
  NotErasableValueMap(mlir::OpBuilder::Listener& listener)
      : ForwardingListener(&listener) {}

  void notifyOperationErased(mlir::Operation* op) override {
    // Programmer error if a mapped result is erased; no-op in NDEBUG.
    for (mlir::Value result : op->getResults())
      assert(!map.contains(result) && "Value contained in NotErasableValueMap can't be erased");
    // Still forward the notification to the wrapped listener; the
    // original override silently swallowed it.
    ForwardingListener::notifyOperationErased(op);
  }

  void notifyBlockErased(mlir::Block* block) override {
    for (mlir::BlockArgument arg : block->getArguments())
      assert(!map.contains(arg) && "Value contained in NotErasableValueMap can't be erased");
    ForwardingListener::notifyBlockErased(block);
  }
};

View File

@@ -4,8 +4,6 @@
#include "llvm/Support/JSON.h" #include "llvm/Support/JSON.h"
#include "onnx-mlir/Compiler/OMCompilerTypes.h" #include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "Common/ValueMap.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {

View File

@@ -8,9 +8,8 @@ add_pim_library(OMONNXToSpatial
Patterns/Math/MatMul.cpp Patterns/Math/MatMul.cpp
Patterns/NN/Pooling.cpp Patterns/NN/Pooling.cpp
Patterns/NN/ReduceMean.cpp Patterns/NN/ReduceMean.cpp
Patterns/Tensor/ONNXConcatToTensorConcat.cpp Patterns/Tensor/Concat.cpp
Patterns/Tensor/ONNXReshapeToTensorReshape.cpp Patterns/Tensor/Reshape.cpp
Patterns/Tensor/RemoveUnusedHelperOps.cpp
Utils/SpatialReducer.cpp Utils/SpatialReducer.cpp
Utils/WeightSubdivider.cpp Utils/WeightSubdivider.cpp
Utils/AnnotateReplication.cpp Utils/AnnotateReplication.cpp

View File

@@ -1,7 +1,9 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
@@ -117,12 +119,10 @@ void ONNXToSpatialPass::runOnOperation() {
} }
} }
// Remove trailing "helper ops" i.e. concat,img_concat,reshape. PassManager cleanupPM(ctx);
RewritePatternSet removeUnusedHelperOpsPatterns(ctx); cleanupPM.addPass(createCanonicalizerPass());
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx); if (failed(cleanupPM.run(moduleOp)))
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
annotateWeightsConstants(*entryFunc); annotateWeightsConstants(*entryFunc);

View File

@@ -17,8 +17,6 @@ void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns,
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -8,8 +8,8 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> { struct Concat : public OpConversionPattern<ONNXConcatOp> {
ONNXConcatToTensorConcat(MLIRContext* ctx) Concat(MLIRContext* ctx)
: OpConversionPattern(ctx) {} : OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp, LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
@@ -25,7 +25,7 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
}; };
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) { void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXConcatToTensorConcat>(ctx); patterns.insert<Concat>(ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,35 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
// Rewrite pattern that erases "helper" ops (concat, img_concat, reshape)
// whose results are no longer used, so they do not linger after the
// ONNX-to-Spatial conversion. The unused OpAdaptorTy template parameter
// has been removed.
template <typename OpTy>
struct RemoveUnusedHelperOps : OpRewritePattern<OpTy> {
  RemoveUnusedHelperOps(MLIRContext* ctx)
      : OpRewritePattern<OpTy>(ctx) {}

  // Erasing one helper op can make its producers dead in turn; bounded
  // recursion lets the greedy driver keep re-applying this pattern.
  void initialize() { this->setHasBoundedRewriteRecursion(); }

  LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final {
    // op->use_empty() checks all results; the original op.getResult()
    // form assumed exactly one result.
    if (!op->use_empty())
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

// Registers the dead-helper-op elimination pattern for each helper op
// kind produced by the conversion.
void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<RemoveUnusedHelperOps<tensor::ConcatOp>>(ctx);
  patterns.insert<RemoveUnusedHelperOps<spatial::SpatImgConcatOp>>(ctx);
  patterns.insert<RemoveUnusedHelperOps<ONNXReshapeOp>>(ctx);
}
} // namespace onnx_mlir

View File

@@ -77,7 +77,7 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size(); return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
} }
struct ONNXReshapeToTensorReshape : OpConversionPattern<ONNXReshapeOp> { struct Reshape : OpConversionPattern<ONNXReshapeOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp, LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
@@ -115,7 +115,7 @@ struct ONNXReshapeToTensorReshape : OpConversionPattern<ONNXReshapeOp> {
} // namespace } // namespace
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXReshapeToTensorReshape>(ctx); patterns.insert<Reshape>(ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -57,10 +57,8 @@ private:
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter); void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void addReceiveOps(Value channelSourceOp, void
spatial::SpatChannelNewOp& channel, addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
bool useBroadcastOp,
IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex, unsigned int argIndex,
Value channelSourceOp, Value channelSourceOp,
@@ -199,29 +197,11 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands"); llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) { for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
// If this result has no uses, then just skip it
if (result.use_empty()) if (result.use_empty())
continue; continue;
auto yieldType = cast<TensorType>(yieldValue.getType()); auto yieldType = cast<TensorType>(yieldValue.getType());
/*
* Here we assume that ReturnOp are only reachable by the following patterns:
*
* 1)
* %0 = spat.compute([...])
* [%0 has one user, which is a ConcatOp]
* %1 = tensor.concat(%0)
* [%1 has one user, which is a ReturnOp]
* return %1
*
* 2)
* %0 = spat.compute([...])
* [%0 has one user, which is a ReturnOp]
* return %0
*
* If the IR is like 2), then we can store the tensor to the output global memory location
*/
auto resultUses = result.getUses(); auto resultUses = result.getUses();
auto numResultUses = rangeLength(resultUses); auto numResultUses = rangeLength(resultUses);
if (numResultUses == 1) { if (numResultUses == 1) {
@@ -549,7 +529,8 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
receivedValue = receivedValue =
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel); spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
else else
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel); receivedValue =
spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
Value replacementValue = receivedValue; Value replacementValue = receivedValue;
if (consumerValue != channelSourceOp) { if (consumerValue != channelSourceOp) {
@@ -577,7 +558,8 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
replacementValue = cast<Value>(mapping.lookup(consumerValue)); replacementValue = cast<Value>(mapping.lookup(consumerValue));
} }
assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type"); assert(replacementValue.getType() == blockArg.getType()
&& "Replayed channel use chain must match block argument type");
blockArg.replaceAllUsesWith(replacementValue); blockArg.replaceAllUsesWith(replacementValue);
} }

View File

@@ -2,9 +2,9 @@ add_pim_library(OMPimPasses
CountInstructionPass.cpp CountInstructionPass.cpp
MessagePass.cpp MessagePass.cpp
PimConstantFolding/Common.cpp PimConstantFolding/Common.cpp
PimConstantFolding/Patterns/ConstantPatterns.cpp PimConstantFolding/Patterns/Constant.cpp
PimConstantFolding/PimConstantFoldingPass.cpp PimConstantFolding/PimConstantFoldingPass.cpp
PimConstantFolding/Patterns/SubviewPatterns.cpp PimConstantFolding/Patterns/Subview.cpp
PimHostVerificationPass.cpp PimHostVerificationPass.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS