36 lines
1.1 KiB
C++
36 lines
1.1 KiB
C++
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
template <typename OpTy, typename OpAdaptorTy>
|
|
struct RemoveUnusedHelperOps : OpRewritePattern<OpTy> {
|
|
RemoveUnusedHelperOps(MLIRContext* ctx)
|
|
: OpRewritePattern<OpTy>(ctx) {}
|
|
|
|
void initialize() { this->setHasBoundedRewriteRecursion(); }
|
|
|
|
LogicalResult matchAndRewrite(OpTy op, PatternRewriter& rewriter) const final {
|
|
if (op.getResult().use_empty()) {
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// Adds one RemoveUnusedHelperOps instance to `patterns` for each of these op
/// kinds, in this order:
///   1. tensor::ConcatOp
///   2. spatial::SpatImgConcatOp
///   3. ONNXReshapeOp
///
/// @param patterns Pattern set that receives the new patterns.
/// @param ctx      Context passed to every pattern's constructor.
void populateRemoveUnusedHelperOpsPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  using TensorConcatCleanup = RemoveUnusedHelperOps<tensor::ConcatOp, tensor::ConcatOpAdaptor>;
  using ImgConcatCleanup =
      RemoveUnusedHelperOps<spatial::SpatImgConcatOp, spatial::SpatImgConcatOpAdaptor>;
  using ReshapeCleanup = RemoveUnusedHelperOps<ONNXReshapeOp, ONNXReshapeOpAdaptor>;

  // The variadic insert constructs the patterns left to right, matching the
  // order of the three separate insert calls it replaces.
  patterns.insert<TensorConcatCleanup, ImgConcatCleanup, ReshapeCleanup>(ctx);
}
|
|
|
|
} // namespace onnx_mlir
|