Transpose and Refactor of Patterns
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 13:23:31 +02:00
parent 82b44a6387
commit 832bd7f1f7
37 changed files with 285 additions and 153 deletions
@@ -3,11 +3,12 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial
ConversionPatterns.cpp
Patterns.cpp
CompileTime.cpp
ONNXToSpatialVerifier.cpp
PrePatterns.cpp
PostPatterns.cpp
Patterns/Pre.cpp
Patterns/Post.cpp
Patterns/GeneratedConversion.cpp
Patterns/Math/Conv.cpp
Patterns/Math/Elementwise.cpp
Patterns/Math/Gemm.cpp
@@ -22,6 +23,7 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Resize.cpp
Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp
Common/ComputeRegionBuilder.cpp
Common/ShapeTilingUtils.cpp
@@ -33,6 +35,7 @@ add_pim_library(OMONNXToSpatial
ONNXToSpatialIncGen
LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect
MLIRTosaDialect
OMCompilerOptions
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
@@ -43,8 +44,8 @@ bool isWeightLikeComputeOperand(Value value) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
value = transposeOp.getInput();
continue;
}
@@ -80,7 +81,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(definingOp))
return failure();
IRMapping localMapper;
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -171,6 +172,16 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm:
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited);
if (!inputAttr)
return nullptr;
SmallVector<int64_t> perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end());
auto transposedAttr = transposeDenseElements(inputAttr, perm);
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
if (!inputAttr)
@@ -226,6 +237,9 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op))
return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength);
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
@@ -14,10 +15,8 @@
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -86,30 +85,6 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
returnOp.setOperand(index, computeResult);
}
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
Block& entryBlock = funcOp.getFunctionBody().front();
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
if (!transposeOp || isCompileTimeOp(transposeOp))
continue;
// Transpose stays globally legal because constant/view-only cases are
// allowed on the host. Any residual runtime transpose must be sunk into
// spat.compute before the host legality check.
auto resultType = transposeOp.getResult().getType();
rewriter.setInsertionPoint(transposeOp);
auto computeOp = createSpatCompute<1>(
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
});
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
}
}
void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();
@@ -117,6 +92,7 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget preTarget(*ctx);
preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
@@ -156,11 +132,13 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXTransposeOp>();
target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>();
@@ -187,9 +165,14 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
RewritePatternSet transposePatterns(ctx);
populateTransposePatterns(transposePatterns, ctx);
walkAndApplyPatterns(moduleOp, std::move(transposePatterns));
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
@@ -205,6 +188,7 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
@@ -222,8 +206,6 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
wrapTopLevelRuntimeTransposes(*entryFunc);
if (failed(verifyONNXToSpatial(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure();
@@ -1,19 +1,16 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGeneratedPrePatterns(patterns, ctx);
}
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGeneratedConversionPatterns(patterns, ctx);
populateElementwisePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
@@ -27,6 +24,11 @@ void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRCon
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
populateTransposePatterns(patterns, ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateWeightPromotionPatterns(patterns, ctx);
}
} // namespace onnx_mlir
@@ -1,38 +1,39 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGeneratedConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateWeightPromotionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateGeneratedConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
}
} // namespace onnx_mlir
@@ -7,7 +7,7 @@
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -10,7 +10,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -7,7 +7,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -4,7 +4,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,5 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
@@ -9,7 +10,7 @@
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -20,7 +21,7 @@ namespace onnx_mlir {
namespace {
static bool isWeightMaterializationHelperUser(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op);
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(op);
}
static bool canPromoteInputBlockArgument(BlockArgument arg) {
@@ -276,7 +277,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
} // namespace
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
void populateWeightPromotionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
@@ -1,6 +1,5 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir;
@@ -12,7 +11,7 @@ namespace {
} // namespace
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<onnxToArithConstant>(ctx);
patterns.add<convAddToConvWithBiasLeft>(ctx);
patterns.add<convAddToConvWithBiasRight>(ctx);
@@ -6,7 +6,7 @@
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -5,7 +5,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -6,7 +6,7 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -3,7 +3,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -0,0 +1,75 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static Value createTransposeInit(Value input,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
SmallVector<OpFoldResult> sizes;
sizes.reserve(resultType.getRank());
for (auto [resultDim, sourceDim] : llvm::zip_equal(resultType.getShape(), permutation)) {
if (!ShapedType::isDynamic(resultDim)) {
sizes.push_back(rewriter.getIndexAttr(resultDim));
continue;
}
sizes.push_back(tensor::DimOp::create(rewriter, loc, input, sourceDim).getResult());
}
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
}
static SmallVector<int64_t> getTransposePermutation(ONNXTransposeOp transposeOp) {
auto inputType = cast<RankedTensorType>(transposeOp.getData().getType());
SmallVector<int64_t> permutation;
if (auto permAttr = transposeOp.getPermAttr()) {
permutation.reserve(permAttr.size());
for (IntegerAttr attr : permAttr.getAsRange<IntegerAttr>())
permutation.push_back(attr.getInt());
return permutation;
}
permutation.reserve(inputType.getRank());
for (int64_t dim = inputType.getRank() - 1; dim >= 0; --dim)
permutation.push_back(dim);
return permutation;
}
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXTransposeOp transposeOp,
ONNXTransposeOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(transposeOp.getResult().getType());
if (!inputType || !resultType)
return failure();
SmallVector<int64_t> permutation = getTransposePermutation(transposeOp);
Value init = createTransposeInit(adaptor.getData(), resultType, permutation, rewriter, transposeOp.getLoc());
Value transposed =
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, permutation)
.getResult()[0];
rewriter.replaceOp(transposeOp, transposed);
return success();
}
};
} // namespace
void populateTransposePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<TransposeToLinalgTranspose>(ctx);
}
} // namespace onnx_mlir
@@ -1,18 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,10 +0,0 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir