Dynamic gemm/conv

This commit is contained in:
ilgeco
2026-05-28 18:00:14 +02:00
parent cbf7b235f1
commit 1ab489fe0a
17 changed files with 704 additions and 69 deletions
@@ -1,8 +1,12 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
@@ -12,6 +16,72 @@ using namespace mlir;
namespace onnx_mlir {
static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) {
if (auto attr = dyn_cast<Attribute>(result))
return arith::ConstantIndexOp::create(rewriter, loc, cast<IntegerAttr>(attr).getInt()).getResult();
return cast<Value>(result);
}
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
APInt lhsConst;
if (matchPattern(lhs, m_ConstantInt(&lhsConst)) && lhsConst.isZero())
return rhs;
APInt rhsConst;
if (matchPattern(rhs, m_ConstantInt(&rhsConst)) && rhsConst.isZero())
return lhs;
return arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult();
}
static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatternRewriter& rewriter, Location loc) {
APInt factorConst;
if (auto attr = dyn_cast<Attribute>(factor))
factorConst = cast<IntegerAttr>(attr).getValue();
else if (!matchPattern(cast<Value>(factor), m_ConstantInt(&factorConst)))
return arith::MulIOp::create(rewriter, loc, value, cast<Value>(factor)).getResult();
if (factorConst.isZero())
return arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
if (factorConst.isOne())
return value;
auto factorValue = arith::ConstantIndexOp::create(rewriter, loc, factorConst.getSExtValue()).getResult();
return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult();
}
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank())
return false;
for (OpFoldResult stride : strides) {
APInt strideValue;
if (auto attr = dyn_cast<Attribute>(stride)) {
if (cast<IntegerAttr>(attr).getInt() != 1)
return false;
continue;
}
if (!matchPattern(cast<Value>(stride), m_ConstantInt(&strideValue)) || !strideValue.isOne())
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(resultType.getShape().rbegin(), resultType.getShape().rend()),
llvm::make_range(sourceType.getShape().rbegin(), sourceType.getShape().rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize == sizesAndShape.end())
return true;
++firstDifferentSize;
return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) {
auto [size, _dimension] = sizeAndShape;
return size == 1;
});
}
SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice);
@@ -123,4 +193,87 @@ Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatte
return broadcastCompute.getResult(0);
}
Value materializeContiguousTensorSlice(Value source,
RankedTensorType resultType,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> strides,
ConversionPatternRewriter& rewriter,
Location loc) {
assert(resultType.hasStaticShape() && "expected static result type");
size_t rank = static_cast<size_t>(resultType.getRank());
assert(offsets.size() == rank && "expected rank-matching offsets");
assert(strides.size() == rank && "expected rank-matching strides");
SmallVector<OpFoldResult> sizes;
sizes.reserve(resultType.getRank());
for (int64_t size : resultType.getShape())
sizes.push_back(rewriter.getIndexAttr(size));
if (isContiguousTensorSlice(source, resultType, strides))
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
if (resultType.getRank() == 0)
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
Value init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
SmallVector<Value> zeroIndices(resultType.getRank());
for (Value& zeroIndex : zeroIndices)
zeroIndex = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> resultIndices;
resultIndices.reserve(resultType.getRank());
auto buildLoopNest = [&](auto&& self, unsigned dim, Value accumulator) -> Value {
if (dim == resultType.getRank()) {
SmallVector<Value> sourceIndices;
sourceIndices.reserve(resultType.getRank());
for (unsigned idx = 0; idx < resultType.getRank(); ++idx) {
Value offsetValue = getIndexValue(offsets[idx], rewriter, loc);
Value scaledIndex = multiplyIndexValue(resultIndices[idx], strides[idx], rewriter, loc);
sourceIndices.push_back(addIndexValues(offsetValue, scaledIndex, rewriter, loc));
}
SmallVector<OpFoldResult> sourceOffsets;
SmallVector<OpFoldResult> destinationOffsets;
SmallVector<OpFoldResult> unitSizes;
SmallVector<OpFoldResult> unitStrides;
sourceOffsets.reserve(resultType.getRank());
destinationOffsets.reserve(resultType.getRank());
unitSizes.reserve(resultType.getRank());
unitStrides.reserve(resultType.getRank());
for (Value index : sourceIndices)
sourceOffsets.push_back(index);
for (Value index : resultIndices)
destinationOffsets.push_back(index);
for (int64_t idx = 0; idx < resultType.getRank(); ++idx) {
unitSizes.push_back(rewriter.getIndexAttr(1));
unitStrides.push_back(rewriter.getIndexAttr(1));
}
auto elementTensorType =
RankedTensorType::get(SmallVector<int64_t>(resultType.getRank(), 1), resultType.getElementType());
Value elementSlice =
tensor::ExtractSliceOp::create(rewriter, loc, elementTensorType, source, sourceOffsets, unitSizes, unitStrides)
.getResult();
return tensor::InsertSliceOp::create(
rewriter, loc, elementSlice, accumulator, destinationOffsets, unitSizes, unitStrides)
.getResult();
}
Value lower = zeroIndices[dim];
Value upper = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(dim)).getResult();
Value step = arith::ConstantIndexOp::create(rewriter, loc, 1).getResult();
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator});
rewriter.setInsertionPointToStart(loop.getBody());
resultIndices.push_back(loop.getInductionVar());
Value updated = self(self, dim + 1, loop.getRegionIterArgs().front());
resultIndices.pop_back();
scf::YieldOp::create(rewriter, loc, updated);
rewriter.setInsertionPointAfter(loop);
return loop.getResult(0);
};
return buildLoopNest(buildLoopNest, 0, init);
}
} // namespace onnx_mlir
@@ -141,4 +141,11 @@ mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value materializeContiguousTensorSlice(mlir::Value source,
mlir::RankedTensorType resultType,
llvm::ArrayRef<mlir::OpFoldResult> offsets,
llvm::ArrayRef<mlir::OpFoldResult> strides,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
} // namespace onnx_mlir
@@ -111,6 +111,32 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
}
static Value createConvWeightMatrix(Value w,
RankedTensorType wFlatType,
RankedTensorType wTransType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto buildWeightMatrix = [&](Value weight) -> Value {
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
weight,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
return ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})).getResult();
};
if (isCompileTimeComputable(w))
return buildWeightMatrix(w);
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {wTransType}, {}, ValueRange {w}, [&](Value weight) {
spatial::SpatYieldOp::create(rewriter, loc, buildWeightMatrix(weight));
});
return computeOp.getResult(0);
}
static Value buildPackedBias(bool hasBias,
Value gemmBias,
Value biasMatrix,
@@ -395,15 +421,7 @@ static Value lowerSingleConvGroup(Value x,
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
w,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
Value wTrans = createConvWeightMatrix(w, wFlatType, wTransType, rewriter, loc);
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
@@ -73,38 +73,11 @@ static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t va
return getOrCreateHostIndexConstant(anchorOp, value, rewriter);
}
static std::optional<int64_t> getConstantIndexValue(Value value) {
if (auto constantIndex = value.getDefiningOp<arith::ConstantIndexOp>())
return constantIndex.value();
APInt constantValue;
if (matchPattern(value, m_ConstantInt(&constantValue)))
return constantValue.getSExtValue();
return std::nullopt;
}
static Value
createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
SmallVector<Attribute> operandConstants;
operandConstants.reserve(operands.size());
for (Value operand : operands) {
std::optional<int64_t> constantValue = getConstantIndexValue(operand);
if (!constantValue)
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
operandConstants.push_back(rewriter.getIndexAttr(*constantValue));
}
SmallVector<Attribute> foldedResults;
if (succeeded(map.constantFold(operandConstants, foldedResults))) {
auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front());
if (constantResult)
return createIndexConstant(rewriter, constantResult.getInt());
}
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
}
static Value
@@ -379,6 +352,233 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
return batchOp;
}
static Value createDynamicGemmBatchRow(
Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
if (numOutCols == 1)
return lane;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
}
static Value createDynamicGemmBatchColumn(
Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
return modIndexByConstant(lane, numOutCols, rewriter, loc);
}
static Value
extractDynamicGemmBColumn(Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column};
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType());
Value columnSlice = materializeContiguousTensorSlice(matrix, columnSliceType, offsets, strides, rewriter, loc);
SmallVector<ReassociationIndices> collapseReassociation {ReassociationIndices {0, 1}};
auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType());
Value collapsed =
tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, columnSlice, collapseReassociation).getResult();
SmallVector<ReassociationIndices> expandReassociation {ReassociationIndices {0, 1}};
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
}
static Value extractTransposedBRow(
Value transposedB, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
return tensor::ExtractSliceOp::create(rewriter, loc, vectorType, transposedB, offsets, sizes, strides).getResult();
}
static Value extractDynamicGemmRowVector(
Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
return tensor::ExtractSliceOp::create(rewriter, loc, vectorType, matrix, offsets, sizes, strides).getResult();
}
static FailureOr<RankedTensorType> verifyDynamicGemmBiasType(RankedTensorType cType, RankedTensorType outType) {
if (!cType.hasStaticShape() || cType.getRank() > 2)
return failure();
if (cType.getRank() == 0)
return cType;
int64_t numOutRows = outType.getDimSize(0);
int64_t numOutCols = outType.getDimSize(1);
if (cType.getRank() == 1) {
int64_t cols = cType.getDimSize(0);
if (cols == 1 || cols == numOutCols)
return cType;
return failure();
}
int64_t rows = cType.getDimSize(0);
int64_t cols = cType.getDimSize(1);
if ((rows == 1 || rows == numOutRows) && (cols == 1 || cols == numOutCols))
return cType;
return failure();
}
static bool hasGemmBias(Value c) {
Operation* definingOp = c.getDefiningOp();
return !definingOp || !isa<ONNXNoneOp>(definingOp);
}
static Value createScalarTensorConstant(RankedTensorType scalarType,
float value,
ConversionPatternRewriter& rewriter,
Location loc) {
auto elementType = scalarType.getElementType();
auto scalarAttr = rewriter.getFloatAttr(elementType, value);
auto denseAttr = DenseElementsAttr::get(scalarType, scalarAttr);
return arith::ConstantOp::create(rewriter, loc, scalarType, denseAttr).getResult();
}
static Value createBroadcastedBiasScalar(Value bias,
RankedTensorType biasType,
Value row,
Value column,
RankedTensorType scalarType,
ConversionPatternRewriter& rewriter,
Location loc) {
SmallVector<OpFoldResult> unitStrides(biasType.getRank(), rewriter.getIndexAttr(1));
if (biasType.getRank() == 1) {
SmallVector<OpFoldResult> offsets {
biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(column)};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1)};
auto vectorType = RankedTensorType::get({1}, scalarType.getElementType());
Value vector = tensor::ExtractSliceOp::create(rewriter, loc, vectorType, bias, offsets, sizes, unitStrides)
.getResult();
SmallVector<ReassociationIndices> reassociation {ReassociationIndices {0, 1}};
return tensor::ExpandShapeOp::create(rewriter, loc, scalarType, vector, reassociation).getResult();
}
if (biasType.getRank() == 2) {
SmallVector<OpFoldResult> offsets {
biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(row),
biasType.getDimSize(1) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(column)};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
return tensor::ExtractSliceOp::create(rewriter, loc, scalarType, bias, offsets, sizes, unitStrides).getResult();
}
Value scalar = tensor::ExtractOp::create(rewriter, loc, bias, ValueRange {}).getResult();
return tensor::SplatOp::create(rewriter, loc, scalarType, scalar).getResult();
}
static spatial::SpatComputeBatch createVvdmulBatch(Value a,
Value b,
RankedTensorType aType,
RankedTensorType bType,
RankedTensorType scalarPiecesType,
RankedTensorType outType,
bool bAlreadyTransposed,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t numOutRows = outType.getDimSize(0);
const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1);
const int64_t laneCount = numOutRows * numOutCols;
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
loc,
TypeRange {scalarPiecesType},
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
ValueRange {},
ValueRange {a, b});
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), aType, bType, scalarPiecesType};
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
Block* body =
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(body);
auto lane = batchOp.getLaneArgument();
auto inputA = batchOp.getInputArgument(0);
auto inputB = batchOp.getInputArgument(1);
auto output = batchOp.getOutputArgument(0);
assert(lane && inputA && inputB && output && "malformed dynamic Gemm compute_batch body");
Value row = createDynamicGemmBatchRow(*lane, numOutCols, rewriter, loc);
Value column = createDynamicGemmBatchColumn(*lane, numOutCols, rewriter, loc);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value aVector = extractDynamicGemmRowVector(*inputA, row, vectorType, rewriter, loc);
Value bVector = bAlreadyTransposed
? extractTransposedBRow(*inputB, column, vectorType, rewriter, loc)
: extractDynamicGemmBColumn(*inputB, column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
tensor::ParallelInsertSliceOp::create(rewriter, loc, scalar, *output, outputOffsets, scalarSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp);
return batchOp;
}
static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
Value bias,
RankedTensorType scalarPiecesType,
RankedTensorType biasType,
RankedTensorType outType,
float alpha,
float beta,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t laneCount = scalarPiecesType.getDimSize(0);
const int64_t numOutCols = outType.getDimSize(1);
SmallVector<Value> inputs {scalarPieces};
if (bias)
inputs.push_back(bias);
return createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) {
Value pieces = blockArgs[0];
Value biasArg = bias ? blockArgs[1] : Value();
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
Value c0 = createIndexConstant(rewriter, 0);
Value c1 = createIndexConstant(rewriter, 1);
Value cLaneCount = createIndexConstant(rewriter, laneCount);
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(loop.getBody());
Value lane = loop.getInductionVar();
Value outputAcc = loop.getRegionIterArgs().front();
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc);
Value column = createDynamicGemmBatchColumn(lane, numOutCols, rewriter, loc);
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scalar =
tensor::ExtractSliceOp::create(rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, unitStrides)
.getResult();
if (alpha != 1.0f) {
Value alphaTensor = createScalarTensorConstant(scalarType, alpha, rewriter, loc);
scalar = spatial::SpatVMulOp::create(rewriter, loc, scalarType, scalar, alphaTensor).getResult();
}
if (biasArg) {
Value biasScalar = createBroadcastedBiasScalar(biasArg, biasType, row, column, scalarType, rewriter, loc);
if (beta != 1.0f) {
Value betaTensor = createScalarTensorConstant(scalarType, beta, rewriter, loc);
biasScalar = spatial::SpatVMulOp::create(rewriter, loc, scalarType, biasScalar, betaTensor).getResult();
}
scalar = spatial::SpatVAddOp::create(rewriter, loc, scalarType, scalar, biasScalar).getResult();
}
SmallVector<OpFoldResult> outputOffsets {row, column};
Value outputNext =
tensor::InsertSliceOp::create(rewriter, loc, scalar, outputAcc, outputOffsets, scalarSizes, unitStrides)
.getResult();
scf::YieldOp::create(rewriter, loc, outputNext);
rewriter.setInsertionPointAfter(loop);
spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0));
});
}
static Value createPartialGroupOffset(Value hSlice,
int64_t kSlice,
int64_t numKSlices,
@@ -570,9 +770,50 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
return failure();
}
const int64_t numOutRows = outType.getDimSize(0);
const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1);
if (!isCompileTimeComputable(b)) {
gemmOp.emitOpError("requires Gemm input B to be statically computed from constants");
return failure();
bool hasC = hasGemmBias(c);
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
float beta = gemmOpAdaptor.getBeta().convertToFloat();
RankedTensorType biasType;
if (hasC) {
auto cType = dyn_cast<RankedTensorType>(c.getType());
if (!cType || !cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
auto verifiedBiasType = verifyDynamicGemmBiasType(cType, outType);
if (failed(verifiedBiasType)) {
gemmOp.emitOpError("requires Gemm bias C to be broadcastable to the output shape");
return failure();
}
biasType = *verifiedBiasType;
}
const int64_t expectedBRows = gemmOpAdaptor.getTransB() ? numOutCols : reductionSize;
const int64_t expectedBCols = gemmOpAdaptor.getTransB() ? reductionSize : numOutCols;
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != expectedBRows
|| bType.getDimSize(1) != expectedBCols) {
gemmOp.emitOpError("has inconsistent A, B, and output shapes");
return failure();
}
const int64_t laneCount64 = numOutRows * numOutCols;
if (laneCount64 > std::numeric_limits<int32_t>::max()) {
gemmOp.emitOpError("requires Gemm dynamic batch lane count to fit in i32");
return failure();
}
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
auto batchOp = createVvdmulBatch(
a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
auto outputCompute = createDynamicGemmOutputCompute(
batchOp.getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc);
rewriter.replaceOp(gemmOp, outputCompute.getResults());
return success();
}
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
@@ -590,9 +831,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
bType = cast<RankedTensorType>(b.getType());
}
const int64_t numOutRows = outType.getDimSize(0);
const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1);
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) {
gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling");
return failure();
@@ -615,7 +853,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
aType = paddedAType;
Value bias;
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
bool hasC = hasGemmBias(c);
auto paddedOutType = RankedTensorType::get({numOutRows, paddedOutCols}, outType.getElementType());
if (hasC) {
auto cType = dyn_cast<RankedTensorType>(c.getType());