better transpose pattern and cleanup
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -10,8 +10,6 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -68,14 +66,6 @@ FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<Arr
|
||||
return permutation;
|
||||
}
|
||||
|
||||
Value transposeMaybeInCompute(
|
||||
Value value, RankedTensorType resultType, ArrayRef<int64_t> permutation, PatternRewriter& rewriter, Location loc) {
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
|
||||
};
|
||||
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildTranspose);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
@@ -78,12 +78,6 @@ llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
||||
int64_t rank);
|
||||
|
||||
mlir::Value transposeMaybeInCompute(mlir::Value value,
|
||||
mlir::RankedTensorType resultType,
|
||||
mlir::ArrayRef<int64_t> permutation,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
@@ -114,21 +113,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet matmulPatterns(ctx);
|
||||
populateMatMulRewritePatterns(matmulPatterns, ctx);
|
||||
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
|
||||
|
||||
bool hasUnloweredMatMul = false;
|
||||
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
|
||||
hasUnloweredMatMul = true;
|
||||
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
|
||||
});
|
||||
if (hasUnloweredMatMul) {
|
||||
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
@@ -165,10 +149,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet transposePatterns(ctx);
|
||||
populateTransposePatterns(transposePatterns, ctx);
|
||||
walkAndApplyPatterns(moduleOp, std::move(transposePatterns));
|
||||
|
||||
ConversionTarget earlyPostTarget(*ctx);
|
||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
@@ -206,15 +186,14 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
populateEmptyFunction(*entryFunc);
|
||||
|
||||
dumpModule(moduleOp, "spatial0");
|
||||
|
||||
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||
|
||||
@@ -10,6 +10,7 @@ void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { popula
|
||||
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
populateGeneratedConversionPatterns(patterns, ctx);
|
||||
populateElementwisePatterns(patterns, ctx);
|
||||
populateMatMulRewritePatterns(patterns, ctx);
|
||||
populateGemmPatterns(patterns, ctx);
|
||||
populateConvPatterns(patterns, ctx);
|
||||
populatePoolPatterns(patterns, ctx);
|
||||
|
||||
@@ -331,14 +331,6 @@ static Value extractDynamicGemmBColumn(
|
||||
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
|
||||
}
|
||||
|
||||
static Value extractTransposedBRow(
|
||||
Value transposedB, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, vectorType, transposedB, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
static Value extractDynamicGemmRowVector(
|
||||
Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
|
||||
@@ -424,7 +416,6 @@ static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
|
||||
RankedTensorType bType,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType outType,
|
||||
bool bAlreadyTransposed,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t numOutRows = outType.getDimSize(0);
|
||||
@@ -446,8 +437,7 @@ static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
|
||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
|
||||
Value bVector = bAlreadyTransposed ? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
|
||||
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
||||
Value bVector = extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||
@@ -739,6 +729,13 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
const int64_t numOutCols = outType.getDimSize(1);
|
||||
const int64_t reductionSize = aType.getDimSize(1);
|
||||
|
||||
if (gemmOpAdaptor.getTransB()) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType(), bType.getEncoding());
|
||||
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})).getResult();
|
||||
bType = transposedType;
|
||||
}
|
||||
|
||||
if (!isCompileTimeComputable(b)) {
|
||||
bool hasC = hasGemmBias(c);
|
||||
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
|
||||
@@ -758,10 +755,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
biasType = *verifiedBiasType;
|
||||
}
|
||||
|
||||
const int64_t expectedBRows = gemmOpAdaptor.getTransB() ? numOutCols : reductionSize;
|
||||
const int64_t expectedBCols = gemmOpAdaptor.getTransB() ? reductionSize : numOutCols;
|
||||
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != expectedBRows
|
||||
|| bType.getDimSize(1) != expectedBCols) {
|
||||
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize
|
||||
|| bType.getDimSize(1) != numOutCols) {
|
||||
gemmOp.emitOpError("has inconsistent A, B, and output shapes");
|
||||
return failure();
|
||||
}
|
||||
@@ -773,8 +768,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
}
|
||||
|
||||
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
|
||||
auto batchOp =
|
||||
createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
|
||||
auto batchOp = createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, rewriter, loc);
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
auto outputCompute = createDynamicGemmOutputCompute(
|
||||
@@ -793,13 +787,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
b = *scaledB;
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
|
||||
if (gemmOpAdaptor.getTransB()) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||
b = transposeMaybeInCompute(b, transposedType, {1, 0}, rewriter, loc);
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
|
||||
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) {
|
||||
gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling");
|
||||
return failure();
|
||||
|
||||
@@ -123,39 +123,16 @@ static Value extractBatchMatrix(Value value,
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
auto shape = type.getShape();
|
||||
RankedTensorType transposedType;
|
||||
SmallVector<int64_t> perm;
|
||||
auto createONNXTranspose = [&](RankedTensorType resultType, ArrayRef<int64_t> permutation) {
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation)).getResult();
|
||||
};
|
||||
if (type.getRank() == 2) {
|
||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
perm = {1, 0};
|
||||
}
|
||||
else {
|
||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
perm = {0, 2, 1};
|
||||
auto resultType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType(), type.getEncoding());
|
||||
return createONNXTranspose(resultType, {1, 0});
|
||||
}
|
||||
|
||||
return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
auto shape = type.getShape();
|
||||
RankedTensorType transposedType;
|
||||
SmallVector<int64_t> perm;
|
||||
if (type.getRank() == 2) {
|
||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
perm = {1, 0};
|
||||
}
|
||||
else {
|
||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
auto resultType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType(), type.getEncoding());
|
||||
return createONNXTranspose(resultType, {0, 2, 1});
|
||||
}
|
||||
|
||||
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -372,32 +349,6 @@ static Value extractDynamicBatchedBColumn(Value matrix,
|
||||
.getResult();
|
||||
}
|
||||
|
||||
static Value extractDynamicBatchedBRow(Value matrix,
|
||||
int64_t sourceBatchCount,
|
||||
Value batch,
|
||||
Value row,
|
||||
RankedTensorType vectorType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
|
||||
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
||||
: OpFoldResult(batch),
|
||||
row,
|
||||
rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
||||
auto rowSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3));
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
vectorType,
|
||||
rowSlice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
}
|
||||
|
||||
static Value extractDynamicBatchedRowVector(Value matrix,
|
||||
int64_t sourceBatchCount,
|
||||
Value batch,
|
||||
@@ -432,7 +383,6 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
|
||||
RankedTensorType bType,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType outType,
|
||||
bool bAlreadyTransposed,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t numBatches = outType.getDimSize(0);
|
||||
@@ -459,9 +409,7 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
|
||||
Value aVector =
|
||||
extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc);
|
||||
Value bVector =
|
||||
bAlreadyTransposed
|
||||
? extractDynamicBatchedBRow(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc)
|
||||
: extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
|
||||
extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
|
||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
@@ -537,15 +485,6 @@ static FailureOr<Value> createBatchedDynamicOutputCompute(Value scalarPieces,
|
||||
return computeOp->getResult(0);
|
||||
}
|
||||
|
||||
static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) {
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, outputType, input, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value extractBatchedReductionPiece(Value partialPiecesArg,
|
||||
Value batch,
|
||||
Value hSlice,
|
||||
@@ -764,7 +703,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
int64_t gemmK = shapeInfo->k;
|
||||
int64_t gemmN = shapeInfo->n;
|
||||
if (useTransposedForm) {
|
||||
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
||||
lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
|
||||
lhsBatchForGemm = shapeInfo->rhsBatch;
|
||||
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||
rhsBatchForGemm = shapeInfo->lhsBatch;
|
||||
@@ -787,15 +726,10 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
if (useTransposedForm) {
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {shapeInfo->outType}, {}, gemmResult, [&](Value input) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, input, rewriter.getI64ArrayAttr({1, 0}));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
gemmResult = transposeCompute.getResult(0);
|
||||
}
|
||||
if (useTransposedForm)
|
||||
gemmResult =
|
||||
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}))
|
||||
.getResult();
|
||||
rewriter.replaceOp(matmulOp, gemmResult);
|
||||
return success();
|
||||
}
|
||||
@@ -822,7 +756,7 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
int64_t gemmK = shapeInfo->k;
|
||||
int64_t gemmN = shapeInfo->n;
|
||||
if (useTransposedForm) {
|
||||
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
||||
lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
|
||||
lhsBatchForGemm = shapeInfo->rhsBatch;
|
||||
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||
rhsBatchForGemm = shapeInfo->lhsBatch;
|
||||
@@ -880,12 +814,14 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
if (failed(result))
|
||||
return failure();
|
||||
Value finalResult = *result;
|
||||
if (useTransposedForm)
|
||||
finalResult = transposeBatchedOutput(
|
||||
finalResult,
|
||||
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
|
||||
rewriter,
|
||||
loc);
|
||||
if (useTransposedForm) {
|
||||
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
|
||||
shapeInfo->outType.getElementType(),
|
||||
shapeInfo->outType.getEncoding());
|
||||
finalResult =
|
||||
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
|
||||
.getResult();
|
||||
}
|
||||
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, finalResult);
|
||||
return success();
|
||||
@@ -901,7 +837,6 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
rhsBatchedType,
|
||||
scalarPiecesType,
|
||||
directOutType,
|
||||
false,
|
||||
rewriter,
|
||||
loc);
|
||||
if (failed(batchOp))
|
||||
@@ -911,12 +846,14 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
if (failed(result))
|
||||
return failure();
|
||||
Value finalResult = *result;
|
||||
if (useTransposedForm)
|
||||
finalResult = transposeBatchedOutput(
|
||||
finalResult,
|
||||
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
|
||||
rewriter,
|
||||
loc);
|
||||
if (useTransposedForm) {
|
||||
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
|
||||
shapeInfo->outType.getElementType(),
|
||||
shapeInfo->outType.getEncoding());
|
||||
finalResult =
|
||||
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
|
||||
.getResult();
|
||||
}
|
||||
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, finalResult);
|
||||
return success();
|
||||
|
||||
@@ -137,11 +137,17 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
|
||||
auto transposedType = RankedTensorType::get(
|
||||
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
|
||||
Value transposedInput = transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
|
||||
Value transposedInput =
|
||||
ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), transposedType, input, rewriter.getI64ArrayAttr(permutation))
|
||||
.getResult();
|
||||
auto transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||
if (failed(transposedResult))
|
||||
return failure();
|
||||
result = transposeMaybeInCompute(*transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
|
||||
result =
|
||||
ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), inputType, *transposedResult, rewriter.getI64ArrayAttr(inversePermutation))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(softmaxOp, result);
|
||||
|
||||
@@ -15,6 +15,10 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static bool isInsideSpatialComputeRegion(Operation* op) {
|
||||
return op->getParentOfType<spatial::SpatCompute>() || op->getParentOfType<spatial::SpatComputeBatch>();
|
||||
}
|
||||
|
||||
static Value createTransposeInit(Value input,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
@@ -102,10 +106,22 @@ struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
|
||||
return success();
|
||||
}
|
||||
}
|
||||
Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
|
||||
Value transposed =
|
||||
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation).getResult()[0];
|
||||
rewriter.replaceOp(transposeOp, transposed);
|
||||
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
Value init = createTransposeInit(input, resultType, *permutation, rewriter, transposeOp.getLoc());
|
||||
return linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), input, init, *permutation).getResult()[0];
|
||||
};
|
||||
|
||||
if (isInsideSpatialComputeRegion(transposeOp.getOperation())) {
|
||||
rewriter.replaceOp(transposeOp, buildTranspose(adaptor.getData()));
|
||||
return success();
|
||||
}
|
||||
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {adaptor.getData()}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), buildTranspose(input));
|
||||
});
|
||||
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user