From 0a5e73c3ea6a3767157b7edfb7874947e664fad7 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Wed, 3 Jun 2026 12:26:31 +0200 Subject: [PATCH] better transpose pattern and cleanup --- .../ONNXToSpatial/Common/ShapeTilingUtils.cpp | 10 -- .../ONNXToSpatial/Common/ShapeTilingUtils.hpp | 6 - .../ONNXToSpatial/ONNXToSpatialPass.cpp | 31 +---- src/PIM/Conversion/ONNXToSpatial/Patterns.cpp | 1 + .../ONNXToSpatial/Patterns/Math/Gemm.cpp | 35 ++--- .../ONNXToSpatial/Patterns/Math/MatMul.cpp | 123 +++++------------- .../ONNXToSpatial/Patterns/NN/Softmax.cpp | 10 +- .../Patterns/Tensor/Transpose.cpp | 24 +++- 8 files changed, 75 insertions(+), 165 deletions(-) diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp index 40eeec2..771ca03 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp @@ -10,8 +10,6 @@ #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; @@ -68,14 +66,6 @@ FailureOr> getTransposePermutationChecked(std::optional permutation, PatternRewriter& rewriter, Location loc) { - auto buildTranspose = [&](Value input) -> Value { - return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult(); - }; - return materializeOrComputeUnary(value, resultType, rewriter, loc, buildTranspose); -} - SmallVector getUnitStrides(PatternRewriter& rewriter, int64_t rank) { return SmallVector(rank, rewriter.getIndexAttr(1)); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp index 84fd5a3..4c265c1 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp @@ -78,12 +78,6 @@ llvm::SmallVector invertPermutation(mlir::ArrayRef permutation mlir::FailureOr> getTransposePermutationChecked(std::optional permAttr, int64_t rank); -mlir::Value transposeMaybeInCompute(mlir::Value value, - mlir::RankedTensorType resultType, - mlir::ArrayRef permutation, - mlir::PatternRewriter& rewriter, - mlir::Location loc); - llvm::SmallVector getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank); llvm::SmallVector getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 2ea2215..cbcc67a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -8,7 +8,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/SmallVector.h" @@ -114,21 +113,6 @@ void ONNXToSpatialPass::runOnOperation() { return; } - RewritePatternSet matmulPatterns(ctx); - populateMatMulRewritePatterns(matmulPatterns, ctx); - walkAndApplyPatterns(moduleOp, std::move(matmulPatterns)); - - bool hasUnloweredMatMul = false; - moduleOp.walk([&](ONNXMatMulOp matmulOp) { - hasUnloweredMatMul = true; - matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion"); - }); - if (hasUnloweredMatMul) { - moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion"); - signalPassFailure(); - return; - } - ConversionTarget target(*ctx); target.addLegalDialect createONNXToSpatialPass() { return std::make_unique(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp index bfe02c6..ffa0b1f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp @@ -10,6 +10,7 @@ void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { popula void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { populateGeneratedConversionPatterns(patterns, ctx); populateElementwisePatterns(patterns, ctx); + populateMatMulRewritePatterns(patterns, ctx); populateGemmPatterns(patterns, ctx); populateConvPatterns(patterns, ctx); populatePoolPatterns(patterns, ctx); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index e073507..e068259 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -331,14 +331,6 @@ static Value extractDynamicGemmBColumn( return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult(); } -static Value extractTransposedBRow( - Value transposedB, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) { - SmallVector offsets {row, rewriter.getIndexAttr(0)}; - SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))}; - SmallVector strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - return tensor::ExtractSliceOp::create(rewriter, loc, vectorType, transposedB, offsets, sizes, strides).getResult(); -} - static Value extractDynamicGemmRowVector( Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) { SmallVector offsets {row, rewriter.getIndexAttr(0)}; @@ -424,7 +416,6 @@ static FailureOr createVvdmulBatch(Value a, RankedTensorType bType, RankedTensorType scalarPiecesType, RankedTensorType outType, - bool bAlreadyTransposed, ConversionPatternRewriter& rewriter, Location loc) { const int64_t numOutRows = outType.getDimSize(0); @@ -446,8 +437,7 @@ static FailureOr createVvdmulBatch(Value a, auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc); - Value bVector = bAlreadyTransposed ? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc) - : extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc); + Value bVector = extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc); Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult(); SmallVector outputOffsets {args.lane, rewriter.getIndexAttr(0)}; @@ -739,6 +729,13 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, const int64_t numOutCols = outType.getDimSize(1); const int64_t reductionSize = aType.getDimSize(1); + if (gemmOpAdaptor.getTransB()) { + auto bShape = bType.getShape(); + auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType(), bType.getEncoding()); + b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})).getResult(); + bType = transposedType; + } + if (!isCompileTimeComputable(b)) { bool hasC = hasGemmBias(c); float alpha = gemmOpAdaptor.getAlpha().convertToFloat(); @@ -758,10 +755,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, biasType = *verifiedBiasType; } - const int64_t expectedBRows = gemmOpAdaptor.getTransB() ? numOutCols : reductionSize; - const int64_t expectedBCols = gemmOpAdaptor.getTransB() ? reductionSize : numOutCols; - if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != expectedBRows - || bType.getDimSize(1) != expectedBCols) { + if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize + || bType.getDimSize(1) != numOutCols) { gemmOp.emitOpError("has inconsistent A, B, and output shapes"); return failure(); } @@ -773,8 +768,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, } auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType()); - auto batchOp = - createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc); + auto batchOp = createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, rewriter, loc); if (failed(batchOp)) return failure(); auto outputCompute = createDynamicGemmOutputCompute( @@ -793,13 +787,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, b = *scaledB; bType = cast(b.getType()); - if (gemmOpAdaptor.getTransB()) { - auto bShape = bType.getShape(); - auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType()); - b = transposeMaybeInCompute(b, transposedType, {1, 0}, rewriter, loc); - bType = cast(b.getType()); - } - if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) { gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling"); return failure(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index c1dba86..922c05f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -123,39 +123,16 @@ static Value extractBatchMatrix(Value value, static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) { auto type = cast(value.getType()); auto shape = type.getShape(); - RankedTensorType transposedType; - SmallVector perm; + auto createONNXTranspose = [&](RankedTensorType resultType, ArrayRef permutation) { + return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation)).getResult(); + }; if (type.getRank() == 2) { - transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType()); - perm = {1, 0}; - } - else { - transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType()); - perm = {0, 2, 1}; + auto resultType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType(), type.getEncoding()); + return createONNXTranspose(resultType, {1, 0}); } - return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc); -} - -static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) { - auto type = cast(value.getType()); - auto shape = type.getShape(); - RankedTensorType transposedType; - SmallVector perm; - if (type.getRank() == 2) { - transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType()); - perm = {1, 0}; - } - else { - transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType()); - perm = {0, 2, 1}; - } - - auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) { - Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm)); - spatial::SpatYieldOp::create(rewriter, loc, transposed); - }); - return transposeCompute.getResult(0); + auto resultType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType(), type.getEncoding()); + return createONNXTranspose(resultType, {0, 2, 1}); } static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) { @@ -372,32 +349,6 @@ static Value extractDynamicBatchedBColumn(Value matrix, .getResult(); } -static Value extractDynamicBatchedBRow(Value matrix, - int64_t sourceBatchCount, - Value batch, - Value row, - RankedTensorType vectorType, - PatternRewriter& rewriter, - Location loc) { - auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType()); - SmallVector offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) - : OpFoldResult(batch), - row, - rewriter.getIndexAttr(0)}; - SmallVector sizes { - rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))}; - auto rowSlice = - tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3)); - return tensor::CollapseShapeOp::create(rewriter, - loc, - vectorType, - rowSlice, - SmallVector { - {0, 1}, - {2} - }); -} - static Value extractDynamicBatchedRowVector(Value matrix, int64_t sourceBatchCount, Value batch, @@ -432,7 +383,6 @@ static FailureOr createBatchedVvdmulBatch(Value a, RankedTensorType bType, RankedTensorType scalarPiecesType, RankedTensorType outType, - bool bAlreadyTransposed, PatternRewriter& rewriter, Location loc) { const int64_t numBatches = outType.getDimSize(0); @@ -459,9 +409,7 @@ static FailureOr createBatchedVvdmulBatch(Value a, Value aVector = extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc); Value bVector = - bAlreadyTransposed - ? extractDynamicBatchedBRow(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc) - : extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc); + extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc); Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult(); SmallVector outputOffsets {args.lane, rewriter.getIndexAttr(0)}; SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; @@ -537,15 +485,6 @@ static FailureOr createBatchedDynamicOutputCompute(Value scalarPieces, return computeOp->getResult(0); } -static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) { - auto transposeCompute = - createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) { - Value transposed = ONNXTransposeOp::create(rewriter, loc, outputType, input, rewriter.getI64ArrayAttr({0, 2, 1})); - spatial::SpatYieldOp::create(rewriter, loc, transposed); - }); - return transposeCompute.getResult(0); -} - static Value extractBatchedReductionPiece(Value partialPiecesArg, Value batch, Value hSlice, @@ -764,7 +703,7 @@ struct MatMulToGemm : OpRewritePattern { int64_t gemmK = shapeInfo->k; int64_t gemmN = shapeInfo->n; if (useTransposedForm) { - lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc); + lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc); lhsBatchForGemm = shapeInfo->rhsBatch; rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc); rhsBatchForGemm = shapeInfo->lhsBatch; @@ -787,15 +726,10 @@ struct MatMulToGemm : OpRewritePattern { rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)) .getY(); - if (useTransposedForm) { - auto transposeCompute = - createSpatCompute<1>(rewriter, loc, TypeRange {shapeInfo->outType}, {}, gemmResult, [&](Value input) { - Value transposed = - ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, input, rewriter.getI64ArrayAttr({1, 0})); - spatial::SpatYieldOp::create(rewriter, loc, transposed); - }); - gemmResult = transposeCompute.getResult(0); - } + if (useTransposedForm) + gemmResult = + ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0})) + .getResult(); rewriter.replaceOp(matmulOp, gemmResult); return success(); } @@ -822,7 +756,7 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern { int64_t gemmK = shapeInfo->k; int64_t gemmN = shapeInfo->n; if (useTransposedForm) { - lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc); + lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc); lhsBatchForGemm = shapeInfo->rhsBatch; rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc); rhsBatchForGemm = shapeInfo->lhsBatch; @@ -880,12 +814,14 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern { if (failed(result)) return failure(); Value finalResult = *result; - if (useTransposedForm) - finalResult = transposeBatchedOutput( - finalResult, - RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()), - rewriter, - loc); + if (useTransposedForm) { + auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, + shapeInfo->outType.getElementType(), + shapeInfo->outType.getEncoding()); + finalResult = + ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1})) + .getResult(); + } finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc); rewriter.replaceOp(matmulOp, finalResult); return success(); @@ -901,7 +837,6 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern { rhsBatchedType, scalarPiecesType, directOutType, - false, rewriter, loc); if (failed(batchOp)) @@ -911,12 +846,14 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern { if (failed(result)) return failure(); Value finalResult = *result; - if (useTransposedForm) - finalResult = transposeBatchedOutput( - finalResult, - RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()), - rewriter, - loc); + if (useTransposedForm) { + auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, + shapeInfo->outType.getElementType(), + shapeInfo->outType.getEncoding()); + finalResult = + ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1})) + .getResult(); + } finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc); rewriter.replaceOp(matmulOp, finalResult); return success(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp index 374bd8e..8b0503e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp @@ -137,11 +137,17 @@ struct SoftmaxToSpatialCompute : OpConversionPattern { auto transposedType = RankedTensorType::get( permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding()); - Value transposedInput = transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc()); + Value transposedInput = + ONNXTransposeOp::create( + rewriter, softmaxOp.getLoc(), transposedType, input, rewriter.getI64ArrayAttr(permutation)) + .getResult(); auto transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc()); if (failed(transposedResult)) return failure(); - result = transposeMaybeInCompute(*transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc()); + result = + ONNXTransposeOp::create( + rewriter, softmaxOp.getLoc(), inputType, *transposedResult, rewriter.getI64ArrayAttr(inversePermutation)) + .getResult(); } rewriter.replaceOp(softmaxOp, result); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp index fe733f7..2ed6957 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp @@ -15,6 +15,10 @@ using namespace mlir; namespace onnx_mlir { namespace { +static bool isInsideSpatialComputeRegion(Operation* op) { + return op->getParentOfType() || op->getParentOfType(); +} + static Value createTransposeInit(Value input, RankedTensorType resultType, ArrayRef permutation, @@ -102,10 +106,22 @@ struct TransposeToLinalgTranspose : OpConversionPattern { return success(); } } - Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc()); - Value transposed = - linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation).getResult()[0]; - rewriter.replaceOp(transposeOp, transposed); + + auto buildTranspose = [&](Value input) -> Value { + Value init = createTransposeInit(input, resultType, *permutation, rewriter, transposeOp.getLoc()); + return linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), input, init, *permutation).getResult()[0]; + }; + + if (isInsideSpatialComputeRegion(transposeOp.getOperation())) { + rewriter.replaceOp(transposeOp, buildTranspose(adaptor.getData())); + return success(); + } + + auto computeOp = createSpatCompute<1>( + rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {adaptor.getData()}, [&](Value input) { + spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), buildTranspose(input)); + }); + rewriter.replaceOp(transposeOp, computeOp.getResult(0)); return success(); } };