replace helper-op cleanup with canonicalization
clean up PIM pattern naming remove unused ValueMap.hpp
This commit is contained in:
@@ -1,44 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class AutoCleaningValueMap : public mlir::RewriterBase::ForwardingListener {
|
|
||||||
public:
|
|
||||||
llvm::DenseMap<mlir::Value, T> map;
|
|
||||||
|
|
||||||
AutoCleaningValueMap(mlir::OpBuilder::Listener listener)
|
|
||||||
: ForwardingListener(&listener) {}
|
|
||||||
|
|
||||||
void notifyOperationErased(mlir::Operation* op) override {
|
|
||||||
for (mlir::Value result : op->getResults())
|
|
||||||
map.erase(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
void notifyBlockErased(mlir::Block* block) override {
|
|
||||||
for (mlir::BlockArgument arg : block->getArguments())
|
|
||||||
map.erase(arg);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class NotErasableValueMap : public mlir::RewriterBase::ForwardingListener {
|
|
||||||
public:
|
|
||||||
llvm::DenseMap<mlir::Value, T> map;
|
|
||||||
|
|
||||||
NotErasableValueMap(mlir::OpBuilder::Listener listener)
|
|
||||||
: ForwardingListener(&listener) {}
|
|
||||||
|
|
||||||
void notifyOperationErased(mlir::Operation* op) override {
|
|
||||||
for (mlir::Value result : op->getResults())
|
|
||||||
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(result));
|
|
||||||
}
|
|
||||||
|
|
||||||
void notifyBlockErased(mlir::Block* block) override {
|
|
||||||
for (mlir::BlockArgument arg : block->getArguments())
|
|
||||||
assert("Value contained in NotErasableValueMap can't be erased" && !map.contains(arg));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
@@ -4,8 +4,6 @@
|
|||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
|
|
||||||
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||||
|
|
||||||
#include "Common/ValueMap.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|||||||
@@ -8,9 +8,8 @@ add_pim_library(OMONNXToSpatial
|
|||||||
Patterns/Math/MatMul.cpp
|
Patterns/Math/MatMul.cpp
|
||||||
Patterns/NN/Pooling.cpp
|
Patterns/NN/Pooling.cpp
|
||||||
Patterns/NN/ReduceMean.cpp
|
Patterns/NN/ReduceMean.cpp
|
||||||
Patterns/Tensor/ONNXConcatToTensorConcat.cpp
|
Patterns/Tensor/Concat.cpp
|
||||||
Patterns/Tensor/ONNXReshapeToTensorReshape.cpp
|
Patterns/Tensor/Reshape.cpp
|
||||||
Patterns/Tensor/RemoveUnusedHelperOps.cpp
|
|
||||||
Utils/SpatialReducer.cpp
|
Utils/SpatialReducer.cpp
|
||||||
Utils/WeightSubdivider.cpp
|
Utils/WeightSubdivider.cpp
|
||||||
Utils/AnnotateReplication.cpp
|
Utils/AnnotateReplication.cpp
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
@@ -117,12 +119,10 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove trailing "helper ops" i.e. concat,img_concat,reshape.
|
PassManager cleanupPM(ctx);
|
||||||
RewritePatternSet removeUnusedHelperOpsPatterns(ctx);
|
cleanupPM.addPass(createCanonicalizerPass());
|
||||||
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx);
|
if (failed(cleanupPM.run(moduleOp)))
|
||||||
|
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
||||||
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
|
|
||||||
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
|
|
||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
|
|
||||||
|
|||||||
@@ -17,8 +17,6 @@ void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns,
|
|||||||
|
|
||||||
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
|
struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
||||||
ONNXConcatToTensorConcat(MLIRContext* ctx)
|
Concat(MLIRContext* ctx)
|
||||||
: OpConversionPattern(ctx) {}
|
: OpConversionPattern(ctx) {}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
|
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
|
||||||
@@ -25,7 +25,7 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<ONNXConcatToTensorConcat>(ctx);
|
patterns.insert<Concat>(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
template <typename OpTy, typename OpAdaptorTy>
|
|
||||||
struct RemoveUnusedHelperOps : OpRewritePattern<OpTy> {
|
|
||||||
RemoveUnusedHelperOps(MLIRContext* ctx)
|
|
||||||
: OpRewritePattern<OpTy>(ctx) {}
|
|
||||||
|
|
||||||
void initialize() { this->setHasBoundedRewriteRecursion(); }
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final {
|
|
||||||
if (op.getResult().use_empty()) {
|
|
||||||
rewriter.eraseOp(op);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
|
||||||
patterns.insert<RemoveUnusedHelperOps<tensor::ConcatOp, tensor::ConcatOpAdaptor>>(ctx);
|
|
||||||
patterns.insert<RemoveUnusedHelperOps<spatial::SpatImgConcatOp, spatial::SpatImgConcatOpAdaptor>>(ctx);
|
|
||||||
patterns.insert<RemoveUnusedHelperOps<ONNXReshapeOp, ONNXReshapeOpAdaptor>>(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -77,7 +77,7 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
|
|||||||
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ONNXReshapeToTensorReshape : OpConversionPattern<ONNXReshapeOp> {
|
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
|
LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
|
||||||
@@ -115,7 +115,7 @@ struct ONNXReshapeToTensorReshape : OpConversionPattern<ONNXReshapeOp> {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<ONNXReshapeToTensorReshape>(ctx);
|
patterns.insert<Reshape>(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -57,10 +57,8 @@ private:
|
|||||||
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
|
||||||
void addReceiveOps(Value channelSourceOp,
|
void
|
||||||
spatial::SpatChannelNewOp& channel,
|
addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
|
||||||
bool useBroadcastOp,
|
|
||||||
IRRewriter& rewriter);
|
|
||||||
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||||
unsigned int argIndex,
|
unsigned int argIndex,
|
||||||
Value channelSourceOp,
|
Value channelSourceOp,
|
||||||
@@ -199,29 +197,11 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
|
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
|
||||||
|
|
||||||
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
|
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
|
||||||
// If this result has no uses, then just skip it
|
|
||||||
if (result.use_empty())
|
if (result.use_empty())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
auto yieldType = cast<TensorType>(yieldValue.getType());
|
auto yieldType = cast<TensorType>(yieldValue.getType());
|
||||||
|
|
||||||
/*
|
|
||||||
* Here we assume that ReturnOp are only reachable by the following patterns:
|
|
||||||
*
|
|
||||||
* 1)
|
|
||||||
* %0 = spat.compute([...])
|
|
||||||
* [%0 has one user, which is a ConcatOp]
|
|
||||||
* %1 = tensor.concat(%0)
|
|
||||||
* [%1 has one user, which is a ReturnOp]
|
|
||||||
* return %1
|
|
||||||
*
|
|
||||||
* 2)
|
|
||||||
* %0 = spat.compute([...])
|
|
||||||
* [%0 has one user, which is a ReturnOp]
|
|
||||||
* return %0
|
|
||||||
*
|
|
||||||
* If the IR is like 2), then we can store the tensor to the output global memory location
|
|
||||||
*/
|
|
||||||
auto resultUses = result.getUses();
|
auto resultUses = result.getUses();
|
||||||
auto numResultUses = rangeLength(resultUses);
|
auto numResultUses = rangeLength(resultUses);
|
||||||
if (numResultUses == 1) {
|
if (numResultUses == 1) {
|
||||||
@@ -549,7 +529,8 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
|||||||
receivedValue =
|
receivedValue =
|
||||||
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||||
else
|
else
|
||||||
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
receivedValue =
|
||||||
|
spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel);
|
||||||
|
|
||||||
Value replacementValue = receivedValue;
|
Value replacementValue = receivedValue;
|
||||||
if (consumerValue != channelSourceOp) {
|
if (consumerValue != channelSourceOp) {
|
||||||
@@ -577,7 +558,8 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
|||||||
replacementValue = cast<Value>(mapping.lookup(consumerValue));
|
replacementValue = cast<Value>(mapping.lookup(consumerValue));
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(replacementValue.getType() == blockArg.getType() && "Replayed channel use chain must match block argument type");
|
assert(replacementValue.getType() == blockArg.getType()
|
||||||
|
&& "Replayed channel use chain must match block argument type");
|
||||||
blockArg.replaceAllUsesWith(replacementValue);
|
blockArg.replaceAllUsesWith(replacementValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ add_pim_library(OMPimPasses
|
|||||||
CountInstructionPass.cpp
|
CountInstructionPass.cpp
|
||||||
MessagePass.cpp
|
MessagePass.cpp
|
||||||
PimConstantFolding/Common.cpp
|
PimConstantFolding/Common.cpp
|
||||||
PimConstantFolding/Patterns/ConstantPatterns.cpp
|
PimConstantFolding/Patterns/Constant.cpp
|
||||||
PimConstantFolding/PimConstantFoldingPass.cpp
|
PimConstantFolding/PimConstantFoldingPass.cpp
|
||||||
PimConstantFolding/Patterns/SubviewPatterns.cpp
|
PimConstantFolding/Patterns/Subview.cpp
|
||||||
PimHostVerificationPass.cpp
|
PimHostVerificationPass.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|||||||
Reference in New Issue
Block a user