This commit is contained in:
@@ -252,7 +252,13 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
||||
Location loc) {
|
||||
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||
rewriter,
|
||||
loc,
|
||||
TypeRange {partialPiecesType},
|
||||
laneCount,
|
||||
ValueRange {b},
|
||||
ValueRange {a},
|
||||
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||
Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows);
|
||||
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
|
||||
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||
@@ -284,8 +290,8 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
||||
return *batchOp;
|
||||
}
|
||||
|
||||
static Value createDynamicGemmBatchRow(
|
||||
Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
static Value
|
||||
createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (numOutCols == 1)
|
||||
return lane;
|
||||
|
||||
@@ -294,17 +300,21 @@ static Value createDynamicGemmBatchRow(
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
||||
}
|
||||
|
||||
static Value
|
||||
extractDynamicGemmBColumn(Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
static Value extractDynamicGemmBColumn(
|
||||
Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column};
|
||||
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType());
|
||||
Value columnSlice = materializeContiguousTensorSlice(matrix, columnSliceType, offsets, strides, rewriter, loc);
|
||||
SmallVector<ReassociationIndices> collapseReassociation {ReassociationIndices {0, 1}};
|
||||
SmallVector<ReassociationIndices> collapseReassociation {
|
||||
ReassociationIndices {0, 1}
|
||||
};
|
||||
auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType());
|
||||
Value collapsed =
|
||||
tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, columnSlice, collapseReassociation).getResult();
|
||||
SmallVector<ReassociationIndices> expandReassociation {ReassociationIndices {0, 1}};
|
||||
SmallVector<ReassociationIndices> expandReassociation {
|
||||
ReassociationIndices {0, 1}
|
||||
};
|
||||
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
|
||||
}
|
||||
|
||||
@@ -371,13 +381,15 @@ static Value createBroadcastedBiasScalar(Value bias,
|
||||
Location loc) {
|
||||
SmallVector<OpFoldResult> unitStrides(biasType.getRank(), rewriter.getIndexAttr(1));
|
||||
if (biasType.getRank() == 1) {
|
||||
SmallVector<OpFoldResult> offsets {
|
||||
biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(column)};
|
||||
SmallVector<OpFoldResult> offsets {biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
||||
: OpFoldResult(column)};
|
||||
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1)};
|
||||
auto vectorType = RankedTensorType::get({1}, scalarType.getElementType());
|
||||
Value vector = tensor::ExtractSliceOp::create(rewriter, loc, vectorType, bias, offsets, sizes, unitStrides)
|
||||
.getResult();
|
||||
SmallVector<ReassociationIndices> reassociation {ReassociationIndices {0, 1}};
|
||||
Value vector =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, vectorType, bias, offsets, sizes, unitStrides).getResult();
|
||||
SmallVector<ReassociationIndices> reassociation {
|
||||
ReassociationIndices {0, 1}
|
||||
};
|
||||
return tensor::ExpandShapeOp::create(rewriter, loc, scalarType, vector, reassociation).getResult();
|
||||
}
|
||||
|
||||
@@ -407,16 +419,21 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
|
||||
const int64_t reductionSize = aType.getDimSize(1);
|
||||
const int64_t laneCount = numOutRows * numOutCols;
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||
rewriter,
|
||||
loc,
|
||||
TypeRange {scalarPiecesType},
|
||||
laneCount,
|
||||
ValueRange {},
|
||||
ValueRange {a, b},
|
||||
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
|
||||
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols);
|
||||
|
||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
|
||||
Value bVector = bAlreadyTransposed
|
||||
? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
|
||||
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
||||
Value bVector = bAlreadyTransposed ? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
|
||||
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||
@@ -578,9 +595,8 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
|
||||
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
|
||||
Value reduced =
|
||||
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
|
||||
Value hOffset =
|
||||
onnx_mlir::multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice,
|
||||
crossbarSize.getValue());
|
||||
Value hOffset = onnx_mlir::multiplyIndexByConstant(
|
||||
rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
|
||||
if (biasArg) {
|
||||
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
|
||||
Value biasSlice =
|
||||
@@ -721,8 +737,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
}
|
||||
|
||||
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
|
||||
auto batchOp = createVvdmulBatch(
|
||||
a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
|
||||
auto batchOp =
|
||||
createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
|
||||
auto outputCompute = createDynamicGemmOutputCompute(
|
||||
batchOp.getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc);
|
||||
rewriter.replaceOp(gemmOp, outputCompute.getResults());
|
||||
|
||||
Reference in New Issue
Block a user