better transpose pattern and cleanup
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -123,39 +123,16 @@ static Value extractBatchMatrix(Value value,
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
auto shape = type.getShape();
|
||||
RankedTensorType transposedType;
|
||||
SmallVector<int64_t> perm;
|
||||
auto createONNXTranspose = [&](RankedTensorType resultType, ArrayRef<int64_t> permutation) {
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation)).getResult();
|
||||
};
|
||||
if (type.getRank() == 2) {
|
||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
perm = {1, 0};
|
||||
}
|
||||
else {
|
||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
perm = {0, 2, 1};
|
||||
auto resultType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType(), type.getEncoding());
|
||||
return createONNXTranspose(resultType, {1, 0});
|
||||
}
|
||||
|
||||
return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
auto shape = type.getShape();
|
||||
RankedTensorType transposedType;
|
||||
SmallVector<int64_t> perm;
|
||||
if (type.getRank() == 2) {
|
||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
perm = {1, 0};
|
||||
}
|
||||
else {
|
||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
auto resultType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType(), type.getEncoding());
|
||||
return createONNXTranspose(resultType, {0, 2, 1});
|
||||
}
|
||||
|
||||
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -372,32 +349,6 @@ static Value extractDynamicBatchedBColumn(Value matrix,
|
||||
.getResult();
|
||||
}
|
||||
|
||||
// Extracts one row of a rank-3 `matrix` and returns it with the leading two
// dimensions collapsed into one, typed as `vectorType`.
//
// Parameters:
//   matrix           - tensor the slice is taken from (indexed with 3 offsets).
//   sourceBatchCount - when equal to 1, the batch offset is the constant 0 and
//                      `batch` is ignored; otherwise `batch` is used as offset.
//   batch, row       - dynamic offsets for dimensions 0 and 1.
//   vectorType       - result type; its dimension 1 gives the slice length.
//                      NOTE(review): presumably a 1xN tensor type - confirm
//                      against callers.
//   rewriter, loc    - builder and location used for the created ops.
//
// Returns the result of the tensor::CollapseShapeOp created on the slice.
static Value extractDynamicBatchedBRow(Value matrix,
|
||||
int64_t sourceBatchCount,
|
||||
Value batch,
|
||||
Value row,
|
||||
RankedTensorType vectorType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
// The slice is 1 x 1 x N, with N taken from dimension 1 of `vectorType`.
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
|
||||
// Offsets are (batch-or-0, row, 0). A source with a single batch is always
// read at batch index 0 (presumably to broadcast it across batches - confirm).
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
||||
: OpFoldResult(batch),
|
||||
row,
|
||||
rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
||||
auto rowSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3));
|
||||
// Merge dimensions 0 and 1 of the slice; dimension 2 is kept as is.
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
vectorType,
|
||||
rowSlice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
}
|
||||
|
||||
static Value extractDynamicBatchedRowVector(Value matrix,
|
||||
int64_t sourceBatchCount,
|
||||
Value batch,
|
||||
@@ -432,7 +383,6 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
|
||||
RankedTensorType bType,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType outType,
|
||||
bool bAlreadyTransposed,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t numBatches = outType.getDimSize(0);
|
||||
@@ -459,9 +409,7 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
|
||||
Value aVector =
|
||||
extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc);
|
||||
Value bVector =
|
||||
bAlreadyTransposed
|
||||
? extractDynamicBatchedBRow(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc)
|
||||
: extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
|
||||
extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
|
||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
@@ -537,15 +485,6 @@ static FailureOr<Value> createBatchedDynamicOutputCompute(Value scalarPieces,
|
||||
return computeOp->getResult(0);
|
||||
}
|
||||
|
||||
// Builds a compute op (via createSpatCompute<1>) that takes `value` as its
// single input, applies an ONNX transpose with permutation {0, 2, 1} to it
// (swapping the last two of three dimensions), and yields the transposed value.
//
// Parameters:
//   value         - tensor to transpose; passed as the compute op's input.
//   outputType    - type used both for the transpose result and for the
//                   compute op's result.
//   rewriter, loc - builder and location used for the created ops.
//
// Returns result 0 of the created compute op.
static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) {
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||
// `input` is the value the callback receives for the compute body.
Value transposed = ONNXTransposeOp::create(rewriter, loc, outputType, input, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value extractBatchedReductionPiece(Value partialPiecesArg,
|
||||
Value batch,
|
||||
Value hSlice,
|
||||
@@ -764,7 +703,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
int64_t gemmK = shapeInfo->k;
|
||||
int64_t gemmN = shapeInfo->n;
|
||||
if (useTransposedForm) {
|
||||
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
||||
lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
|
||||
lhsBatchForGemm = shapeInfo->rhsBatch;
|
||||
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||
rhsBatchForGemm = shapeInfo->lhsBatch;
|
||||
@@ -787,15 +726,10 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
if (useTransposedForm) {
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {shapeInfo->outType}, {}, gemmResult, [&](Value input) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, input, rewriter.getI64ArrayAttr({1, 0}));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
gemmResult = transposeCompute.getResult(0);
|
||||
}
|
||||
if (useTransposedForm)
|
||||
gemmResult =
|
||||
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}))
|
||||
.getResult();
|
||||
rewriter.replaceOp(matmulOp, gemmResult);
|
||||
return success();
|
||||
}
|
||||
@@ -822,7 +756,7 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
int64_t gemmK = shapeInfo->k;
|
||||
int64_t gemmN = shapeInfo->n;
|
||||
if (useTransposedForm) {
|
||||
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
||||
lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
|
||||
lhsBatchForGemm = shapeInfo->rhsBatch;
|
||||
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||
rhsBatchForGemm = shapeInfo->lhsBatch;
|
||||
@@ -880,12 +814,14 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
if (failed(result))
|
||||
return failure();
|
||||
Value finalResult = *result;
|
||||
if (useTransposedForm)
|
||||
finalResult = transposeBatchedOutput(
|
||||
finalResult,
|
||||
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
|
||||
rewriter,
|
||||
loc);
|
||||
if (useTransposedForm) {
|
||||
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
|
||||
shapeInfo->outType.getElementType(),
|
||||
shapeInfo->outType.getEncoding());
|
||||
finalResult =
|
||||
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
|
||||
.getResult();
|
||||
}
|
||||
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, finalResult);
|
||||
return success();
|
||||
@@ -901,7 +837,6 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
rhsBatchedType,
|
||||
scalarPiecesType,
|
||||
directOutType,
|
||||
false,
|
||||
rewriter,
|
||||
loc);
|
||||
if (failed(batchOp))
|
||||
@@ -911,12 +846,14 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
if (failed(result))
|
||||
return failure();
|
||||
Value finalResult = *result;
|
||||
if (useTransposedForm)
|
||||
finalResult = transposeBatchedOutput(
|
||||
finalResult,
|
||||
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
|
||||
rewriter,
|
||||
loc);
|
||||
if (useTransposedForm) {
|
||||
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
|
||||
shapeInfo->outType.getElementType(),
|
||||
shapeInfo->outType.getEncoding());
|
||||
finalResult =
|
||||
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
|
||||
.getResult();
|
||||
}
|
||||
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, finalResult);
|
||||
return success();
|
||||
|
||||
Reference in New Issue
Block a user