better transpose pattern and cleanup
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -331,14 +331,6 @@ static Value extractDynamicGemmBColumn(
|
||||
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
|
||||
}
|
||||
|
||||
static Value extractTransposedBRow(
|
||||
Value transposedB, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, vectorType, transposedB, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
static Value extractDynamicGemmRowVector(
|
||||
Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
|
||||
@@ -424,7 +416,6 @@ static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
|
||||
RankedTensorType bType,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType outType,
|
||||
bool bAlreadyTransposed,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t numOutRows = outType.getDimSize(0);
|
||||
@@ -446,8 +437,7 @@ static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
|
||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
|
||||
Value bVector = bAlreadyTransposed ? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
|
||||
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
||||
Value bVector = extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||
@@ -739,6 +729,13 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
const int64_t numOutCols = outType.getDimSize(1);
|
||||
const int64_t reductionSize = aType.getDimSize(1);
|
||||
|
||||
if (gemmOpAdaptor.getTransB()) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType(), bType.getEncoding());
|
||||
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})).getResult();
|
||||
bType = transposedType;
|
||||
}
|
||||
|
||||
if (!isCompileTimeComputable(b)) {
|
||||
bool hasC = hasGemmBias(c);
|
||||
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
|
||||
@@ -758,10 +755,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
biasType = *verifiedBiasType;
|
||||
}
|
||||
|
||||
const int64_t expectedBRows = gemmOpAdaptor.getTransB() ? numOutCols : reductionSize;
|
||||
const int64_t expectedBCols = gemmOpAdaptor.getTransB() ? reductionSize : numOutCols;
|
||||
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != expectedBRows
|
||||
|| bType.getDimSize(1) != expectedBCols) {
|
||||
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize
|
||||
|| bType.getDimSize(1) != numOutCols) {
|
||||
gemmOp.emitOpError("has inconsistent A, B, and output shapes");
|
||||
return failure();
|
||||
}
|
||||
@@ -773,8 +768,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
}
|
||||
|
||||
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
|
||||
auto batchOp =
|
||||
createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
|
||||
auto batchOp = createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, rewriter, loc);
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
auto outputCompute = createDynamicGemmOutputCompute(
|
||||
@@ -793,13 +787,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
b = *scaledB;
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
|
||||
if (gemmOpAdaptor.getTransB()) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||
b = transposeMaybeInCompute(b, transposedType, {1, 0}, rewriter, loc);
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
|
||||
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) {
|
||||
gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling");
|
||||
return failure();
|
||||
|
||||
Reference in New Issue
Block a user