Files
Raptor/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp
T
NiccoloN a34ac223c0
Validate Operations / validate-operations (push) Has been cancelled
fix remaining failing tests
remove unsupported tests
2026-06-05 15:27:11 +02:00

864 lines
40 KiB
C++

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <limits>
#include <utility>
#include "Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Multiplies every element of a constant floating-point tensor by `factor` and
/// materializes the product as a new host constant.
///
/// A factor of exactly 1.0 is a no-op: `value` is returned unchanged and does
/// not have to be a constant. Otherwise `value` must resolve, through
/// `getHostConstDenseElementsAttr`, to a dense floating-point attribute.
///
/// The factor is converted to the floating-point semantics of the tensor's
/// element type before multiplying: `APFloat(float)` is always IEEE single
/// precision, and `APFloat::multiply` requires both operands to share the same
/// semantics.
///
/// Plain rounding is not treated as an error. `opInexact`/`opUnderflow` are
/// reported for almost every factor that is not a power of two, so only
/// invalid-operation, division-by-zero and overflow statuses reject the scaling.
///
/// \param loc currently unused.
/// \returns the scaled constant, or failure when `value` is not a
///          floating-point constant or a product is invalid or overflows.
static FailureOr<Value>
materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewriter& rewriter, Location loc) {
  if (factor == 1.0f)
    return value;

  auto denseAttr = dyn_cast_or_null<DenseFPElementsAttr>(getHostConstDenseElementsAttr(value));
  if (!denseAttr)
    return failure();

  auto denseType = cast<RankedTensorType>(denseAttr.getType());
  constexpr unsigned fatalStatus = APFloat::opInvalidOp | APFloat::opDivByZero | APFloat::opOverflow;

  // Bring the scale into the element type's semantics (f16, bf16, f32, f64, ...).
  APFloat scale(factor);
  bool losesInfo = false;
  const APFloat::opStatus convertStatus = scale.convert(
      cast<FloatType>(denseType.getElementType()).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo);
  if (convertStatus & fatalStatus)
    return failure();

  SmallVector<APFloat> scaledValues;
  scaledValues.reserve(denseAttr.getNumElements());
  for (APFloat element : denseAttr.getValues<APFloat>()) {
    const APFloat::opStatus status = element.multiply(scale, APFloat::rmNearestTiesToEven);
    if (status & fatalStatus)
      return failure();
    scaledValues.push_back(std::move(element));
  }

  auto scaledAttr = DenseFPElementsAttr::get(denseType, scaledValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaledAttr, denseAttr.getType());
}
/// Builds the reduction-dimension (K) element offset of the tile handled by `lane`.
///
/// The offset is `((lane floordiv numOutRows) mod numKSlices) * crossbarSize`.
/// When there is only one K slice the result is the index constant 0.
static Value createGemmBatchKOffset(
    Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  if (numKSlices == 1)
    return getOrCreateIndexConstant(rewriter, anchorOp, 0);

  AffineExpr laneExpr = getAffineDimExpr(0, rewriter.getContext());
  AffineExpr kSliceExpr = laneExpr.floorDiv(numOutRows) % numKSlices;
  return createOrFoldAffineApply(rewriter, loc, kSliceExpr * crossbarSize.getValue(), ValueRange {lane}, anchorOp);
}
/// Builds the output-column (H) element offset of the tile handled by `lane`.
///
/// The offset is `(lane floordiv (numOutRows * numKSlices)) * crossbarSize`.
/// When there is only one output column slice the result is the index constant 0.
static Value createGemmBatchHOffset(Value lane,
                                    int64_t numOutRows,
                                    int64_t numKSlices,
                                    int64_t numOutHSlices,
                                    ConversionPatternRewriter& rewriter,
                                    Location loc) {
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  if (numOutHSlices == 1)
    return getOrCreateIndexConstant(rewriter, anchorOp, 0);

  AffineExpr laneExpr = getAffineDimExpr(0, rewriter.getContext());
  AffineExpr hSliceExpr = laneExpr.floorDiv(numOutRows * numKSlices);
  return createOrFoldAffineApply(rewriter, loc, hSliceExpr * crossbarSize.getValue(), ValueRange {lane}, anchorOp);
}
/// Emits a `tensor.pad` that grows `value` to `resultType` by appending zeros at
/// the high end of every dimension (low padding is always 0).
///
/// The pad region is built by hand: one index block argument per dimension and
/// a single yield of the zero constant. On return the rewriter's insertion
/// point is placed right after the pad op.
static Value
createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
  auto inputType = cast<RankedTensorType>(value.getType());
  const int64_t rank = inputType.getRank();
  Type elementType = inputType.getElementType();

  // High padding per dimension is the difference between padded and original extent.
  SmallVector<OpFoldResult> leadingPads(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> trailingPads;
  trailingPads.reserve(rank);
  for (auto [inputDim, paddedDim] : llvm::zip(inputType.getShape(), resultType.getShape()))
    trailingPads.push_back(rewriter.getIndexAttr(paddedDim - inputDim));

  auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, leadingPads, trailingPads);

  // Populate the padding region; the region takes ownership of the block.
  auto* body = new Block();
  for (int64_t dim = 0; dim < rank; ++dim)
    body->addArgument(rewriter.getIndexType(), loc);
  padOp.getRegion().push_back(body);

  rewriter.setInsertionPointToStart(body);
  auto zeroValue = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(elementType), elementType);
  tensor::YieldOp::create(rewriter, loc, zeroValue);

  rewriter.setInsertionPointAfter(padOp);
  return padOp.getResult();
}
/// Materializes a zero-padded copy of a constant 2-D matrix as a new host constant.
///
/// If `value` already has type `resultType` it is returned unchanged (and need
/// not be constant). Otherwise `value` must resolve to a statically shaped
/// rank-2 dense constant; its elements are copied into the top-left corner of a
/// zero-filled `resultType` matrix.
///
/// \param loc currently unused.
/// \returns the padded constant, or failure when `value` is not a suitable
///          constant or `resultType` is not a statically shaped rank-2 tensor
///          of the same element type that is at least as large as the source
///          in both dimensions.
static FailureOr<Value> materializePaddedConstantMatrix(Value value,
                                                        RankedTensorType resultType,
                                                        ConversionPatternRewriter& rewriter,
                                                        Location loc) {
  auto sourceType = cast<RankedTensorType>(value.getType());
  if (sourceType == resultType)
    return value;

  auto denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr)
    return failure();

  auto denseType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!denseType || denseType.getRank() != 2 || !denseType.hasStaticShape())
    return failure();

  // The copy loop below indexes the result as a static 2-D matrix, so reject
  // anything else instead of reading or writing out of bounds.
  if (resultType.getRank() != 2 || !resultType.hasStaticShape()
      || resultType.getElementType() != denseType.getElementType())
    return failure();

  ArrayRef<int64_t> sourceShape = denseType.getShape();
  ArrayRef<int64_t> resultShape = resultType.getShape();
  if (resultShape[0] < sourceShape[0] || resultShape[1] < sourceShape[1])
    return failure();

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  Attribute zero = rewriter.getZeroAttr(resultType.getElementType());
  SmallVector<Attribute> resultValues(resultType.getNumElements(), zero);

  // Row-major copy; only the row pitch differs between source and result.
  for (int64_t row = 0; row < sourceShape[0]; ++row)
    for (int64_t col = 0; col < sourceShape[1]; ++col)
      resultValues[row * resultShape[1] + col] = sourceValues[row * sourceShape[1] + col];

  auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
}
/// Materializes a constant as a new host constant of type `resultType`, applying
/// trailing-dimension-aligned broadcasting and zero padding of the last dimension.
///
/// The constant behind `value` is broadcast against the *unpadded* result shape,
/// i.e. `resultType`'s shape with its last dimension replaced by
/// `unpaddedColumns`: each source dimension must be 1 or equal to the matching
/// result dimension. Result elements whose last index is `>= unpaddedColumns`
/// are filled with zero.
///
/// \param loc currently unused.
/// \returns the materialized constant, or failure when `value` is not a
///          statically shaped constant, has a higher rank than the result, is
///          not broadcast-compatible, or `resultType` is rank 0 or not
///          statically shaped.
static FailureOr<Value> materializePaddedBroadcastedConstantTensor(Value value,
                                                                   RankedTensorType resultType,
                                                                   int64_t unpaddedColumns,
                                                                   ConversionPatternRewriter& rewriter,
                                                                   Location loc) {
  auto denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr)
    return failure();

  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType || !sourceType.hasStaticShape() || sourceType.getRank() > resultType.getRank())
    return failure();

  // The padded dimension is the last one, so a rank-0 result has nothing to
  // pad; the element count below also needs a static shape.
  if (resultType.getRank() == 0 || !resultType.hasStaticShape())
    return failure();

  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  ArrayRef<int64_t> resultShape = resultType.getShape();
  const int64_t sourceRank = static_cast<int64_t>(sourceShape.size());
  const int64_t resultRank = static_cast<int64_t>(resultShape.size());
  const int64_t rankOffset = resultRank - sourceRank;

  // Check broadcast compatibility against the unpadded result shape.
  for (int64_t sourceIndex = 0; sourceIndex < sourceRank; ++sourceIndex) {
    const int64_t resultIndex = sourceIndex + rankOffset;
    const int64_t sourceDim = sourceShape[sourceIndex];
    const int64_t resultDim = resultIndex == resultRank - 1 ? unpaddedColumns : resultShape[resultIndex];
    if (sourceDim != 1 && sourceDim != resultDim)
      return failure();
  }

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceShape);
  SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultShape);
  // The loops below index one stride per dimension.
  if (sourceStrides.size() < sourceShape.size() || resultStrides.size() < resultShape.size())
    return failure();

  Attribute zero = rewriter.getZeroAttr(resultType.getElementType());
  const int64_t numResultElements = resultType.getNumElements();
  SmallVector<Attribute> resultValues;
  resultValues.reserve(numResultElements);

  // Scratch multi-index, reused for every element.
  SmallVector<int64_t> resultIndices(resultShape.size(), 0);
  for (int64_t flatIndex = 0; flatIndex < numResultElements; ++flatIndex) {
    // Delinearize the row-major flat index.
    int64_t remaining = flatIndex;
    for (int64_t dim = 0; dim < resultRank; ++dim) {
      resultIndices[dim] = remaining / resultStrides[dim];
      remaining %= resultStrides[dim];
    }

    // Columns beyond the unpadded width are padding.
    if (resultIndices.back() >= unpaddedColumns) {
      resultValues.push_back(zero);
      continue;
    }

    // Broadcast dimensions (extent 1) always read source index 0.
    int64_t sourceFlatIndex = 0;
    for (int64_t sourceIndex = 0; sourceIndex < sourceRank; ++sourceIndex)
      if (sourceShape[sourceIndex] != 1)
        sourceFlatIndex += resultIndices[sourceIndex + rankOffset] * sourceStrides[sourceIndex];
    resultValues.push_back(sourceValues[sourceFlatIndex]);
  }

  auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
}
/// Produces the bias operand for the tiled (constant-weight) Gemm lowering.
///
/// A compile-time-computable bias is broadcast and zero-padded into a constant
/// of type `paddedOutType`. Any other bias is passed through, but only if its
/// type is exactly `outType`. Fails for a dynamically shaped bias or a runtime
/// bias of a different type.
static FailureOr<Value> prepareBias(Value c,
                                    RankedTensorType outType,
                                    RankedTensorType paddedOutType,
                                    ConversionPatternRewriter& rewriter,
                                    Location loc) {
  auto biasType = cast<RankedTensorType>(c.getType());
  if (!biasType.hasStaticShape())
    return failure();

  if (isCompileTimeComputable(c)) {
    const int64_t unpaddedColumns = outType.getDimSize(1);
    return materializePaddedBroadcastedConstantTensor(c, paddedOutType, unpaddedColumns, rewriter, loc);
  }

  // Runtime bias: no broadcasting is performed here.
  if (biasType == outType)
    return c;
  return failure();
}
/// Extracts the `1 x crossbarSize` slice of `a` that starts at `[row, kOffset]`.
static Value extractATile(
    Value a, Value row, Value kOffset, RankedTensorType aTileType, ConversionPatternRewriter& rewriter, Location loc) {
  const OpFoldResult one = rewriter.getIndexAttr(1);
  const OpFoldResult tileWidth = rewriter.getIndexAttr(crossbarSize.getValue());

  SmallVector<OpFoldResult> tileOffsets {row, kOffset};
  SmallVector<OpFoldResult> tileSizes {one, tileWidth};
  SmallVector<OpFoldResult> tileStrides {one, one};
  auto sliceOp = tensor::ExtractSliceOp::create(rewriter, loc, aTileType, a, tileOffsets, tileSizes, tileStrides);
  return sliceOp.getResult();
}
/// Returns `input` zero-padded to `paddedInputType`.
///
/// If the input already has the padded type it is returned as is; otherwise the
/// padding is emitted inside a dedicated spatial compute op whose single result
/// is returned.
static Value createPaddedInputCompute(Value input,
                                      RankedTensorType paddedInputType,
                                      ConversionPatternRewriter& rewriter,
                                      Location loc) {
  if (cast<RankedTensorType>(input.getType()) == paddedInputType)
    return input;

  auto paddingCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
        Value padded = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
        spatial::SpatYieldOp::create(rewriter, loc, padded);
      });
  return paddingCompute.getResult(0);
}
/// Creates the batched compute that performs one crossbar-sized vector-matrix
/// multiplication per lane.
///
/// `b` is passed as the batch weight and `a` as the batch input. For every lane:
///   - row     = lane mod numOutRows
///   - kOffset / hOffset come from createGemmBatchKOffset / createGemmBatchHOffset
///   - a `1 x crossbarSize` tile of A at `[row, kOffset]` is multiplied with the
///     `crossbarSize x crossbarSize` tile of B at `[kOffset, hOffset]`
///   - the `1 x crossbarSize` product is inserted at row `lane` of the batch output.
///
/// The lane count is the leading dimension of `partialPiecesType`.
static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
                                                           Value b,
                                                           RankedTensorType aType,
                                                           RankedTensorType paddedBType,
                                                           RankedTensorType partialPiecesType,
                                                           int64_t numOutRows,
                                                           int64_t numKSlices,
                                                           int64_t numOutHSlices,
                                                           ConversionPatternRewriter& rewriter,
                                                           Location loc) {
  const int64_t laneCount = partialPiecesType.getDimSize(0);
  const int64_t tileSize = static_cast<int64_t>(crossbarSize.getValue());

  // Tile types do not depend on the lane, so build them once up front.
  auto aTileType = RankedTensorType::get({1, tileSize}, aType.getElementType());
  auto bTileType = RankedTensorType::get({tileSize, tileSize}, paddedBType.getElementType());
  auto pieceType = RankedTensorType::get({1, tileSize}, partialPiecesType.getElementType());

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {partialPiecesType},
      laneCount,
      ValueRange {b},
      ValueRange {a},
      [&](detail::SpatComputeBatchBodyArgs args) {
        // Decompose the lane id into (row, K slice, H slice) coordinates.
        Value row = onnx_mlir::affineModConst(
            rewriter, loc, args.lane, numOutRows, rewriter.getInsertionBlock()->getParentOp());
        Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
        Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);

        Value aTile = extractATile(args.inputs.front(), row, kOffset, aTileType, rewriter, loc);

        const OpFoldResult tileExtent = rewriter.getIndexAttr(tileSize);
        SmallVector<OpFoldResult> weightOffsets {kOffset, hOffset};
        SmallVector<OpFoldResult> weightSizes {tileExtent, tileExtent};
        SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
        Value bTile = tensor::ExtractSliceOp::create(
                          rewriter, loc, bTileType, args.weights.front(), weightOffsets, weightSizes, unitStrides)
                          .getResult();

        Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();

        // Each lane owns exactly one row of the batch output.
        SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), tileExtent};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides);
      });
  if (failed(batchOp))
    return failure();
  return *batchOp;
}
/// Builds the output row index of `lane` for the dynamic lowering:
/// `lane floordiv numOutCols`. With a single output column the lane is the row.
static Value
createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
  if (numOutCols == 1)
    return lane;

  AffineExpr laneExpr = getAffineDimExpr(0, rewriter.getContext());
  AffineExpr rowExpr = laneExpr.floorDiv(numOutCols);
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  return createOrFoldAffineApply(rewriter, loc, rowExpr, ValueRange {lane}, anchorOp);
}
/// Extracts column `column` of `matrix` and returns it reshaped to `vectorType`.
///
/// With `K = vectorType.getDimSize(1)`, the column is sliced as a `K x 1`
/// tensor, collapsed to `K`, and expanded again to `vectorType`.
static Value extractDynamicGemmBColumn(
    Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
  const int64_t length = vectorType.getDimSize(1);
  Type elementType = vectorType.getElementType();
  const OpFoldResult one = rewriter.getIndexAttr(1);

  // Slice the K x 1 column starting at [0, column].
  SmallVector<OpFoldResult> sliceOffsets {rewriter.getIndexAttr(0), column};
  SmallVector<OpFoldResult> sliceSizes {rewriter.getIndexAttr(length), one};
  SmallVector<OpFoldResult> sliceStrides {one, one};
  auto columnType = RankedTensorType::get({length, 1}, elementType);
  Value columnSlice =
      tensor::ExtractSliceOp::create(rewriter, loc, columnType, matrix, sliceOffsets, sliceSizes, sliceStrides)
          .getResult();

  // Both reshapes group the two dimensions of the 2-D side into one.
  SmallVector<ReassociationIndices> bothDims {
    ReassociationIndices {0, 1}
  };
  auto flatType = RankedTensorType::get({length}, elementType);
  Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, columnSlice, bothDims).getResult();
  return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, flat, bothDims).getResult();
}
/// Extracts row `row` of `matrix` as a `1 x vectorType.getDimSize(1)` slice of
/// type `vectorType`.
static Value extractDynamicGemmRowVector(
    Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
  const OpFoldResult one = rewriter.getIndexAttr(1);
  const OpFoldResult rowLength = rewriter.getIndexAttr(vectorType.getDimSize(1));

  SmallVector<OpFoldResult> rowOffsets {row, rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> rowSizes {one, rowLength};
  SmallVector<OpFoldResult> rowStrides {one, one};
  auto sliceOp = tensor::ExtractSliceOp::create(rewriter, loc, vectorType, matrix, rowOffsets, rowSizes, rowStrides);
  return sliceOp.getResult();
}
/// Checks that a bias of type `cType` can be broadcast to `outType` in the
/// dynamic lowering and returns `cType` on success.
///
/// Accepted: any statically shaped rank-0 bias; a rank-1 bias whose extent is 1
/// or the number of output columns; a rank-2 bias where each extent is 1 or the
/// matching output extent. Everything else fails.
static FailureOr<RankedTensorType> verifyDynamicGemmBiasType(RankedTensorType cType, RankedTensorType outType) {
  if (!cType.hasStaticShape())
    return failure();

  auto broadcastsTo = [](int64_t biasDim, int64_t outDim) { return biasDim == 1 || biasDim == outDim; };

  bool accepted = false;
  switch (cType.getRank()) {
  case 0:
    accepted = true;
    break;
  case 1:
    // A vector bias runs along the output columns.
    accepted = broadcastsTo(cType.getDimSize(0), outType.getDimSize(1));
    break;
  case 2:
    accepted = broadcastsTo(cType.getDimSize(0), outType.getDimSize(0))
            && broadcastsTo(cType.getDimSize(1), outType.getDimSize(1));
    break;
  default:
    break;
  }

  if (!accepted)
    return failure();
  return cType;
}
/// Returns true unless `c` is produced by an `ONNXNoneOp`. Values without a
/// defining op (block arguments) count as a present bias.
static bool hasGemmBias(Value c) {
  return !c.getDefiningOp<ONNXNoneOp>();
}
/// Returns a constant of type `scalarType` filled with `value`, converted to the
/// tensor's element type.
///
/// \param loc currently unused.
static Value createScalarTensorConstant(RankedTensorType scalarType,
                                        float value,
                                        ConversionPatternRewriter& rewriter,
                                        Location loc) {
  auto fillAttr = rewriter.getFloatAttr(scalarType.getElementType(), value);
  auto constantAttr = DenseElementsAttr::get(scalarType, fillAttr);
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  return getOrCreateConstant(rewriter, anchorOp, constantAttr, scalarType);
}
/// Produces the bias element that applies to output position `[row, column]`
/// as a tensor of type `scalarType`.
///
/// - rank 1: slices one element (index 0 if the extent is 1, else `column`) and
///   expands it to `scalarType`.
/// - rank 2: slices one element, using index 0 along every extent-1 dimension
///   and `row` / `column` otherwise.
/// - any other rank: extracts with no indices and splats to `scalarType`.
static Value createBroadcastedBiasScalar(Value bias,
                                         RankedTensorType biasType,
                                         Value row,
                                         Value column,
                                         RankedTensorType scalarType,
                                         ConversionPatternRewriter& rewriter,
                                         Location loc) {
  const int64_t biasRank = biasType.getRank();
  SmallVector<OpFoldResult> unitStrides(biasRank, rewriter.getIndexAttr(1));

  // Broadcast (extent-1) dimensions always read index 0.
  auto offsetAlong = [&](int64_t dim, Value index) -> OpFoldResult {
    if (biasType.getDimSize(dim) == 1)
      return OpFoldResult(rewriter.getIndexAttr(0));
    return OpFoldResult(index);
  };

  switch (biasRank) {
  case 1: {
    SmallVector<OpFoldResult> offsets {offsetAlong(0, column)};
    SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1)};
    auto singleElementType = RankedTensorType::get({1}, scalarType.getElementType());
    Value element =
        tensor::ExtractSliceOp::create(rewriter, loc, singleElementType, bias, offsets, sizes, unitStrides).getResult();
    SmallVector<ReassociationIndices> reassociation {
      ReassociationIndices {0, 1}
    };
    return tensor::ExpandShapeOp::create(rewriter, loc, scalarType, element, reassociation).getResult();
  }
  case 2: {
    SmallVector<OpFoldResult> offsets {offsetAlong(0, row), offsetAlong(1, column)};
    SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    return tensor::ExtractSliceOp::create(rewriter, loc, scalarType, bias, offsets, sizes, unitStrides).getResult();
  }
  default:
    break;
  }

  Value element = tensor::ExtractOp::create(rewriter, loc, bias, ValueRange {}).getResult();
  return tensor::SplatOp::create(rewriter, loc, scalarType, element).getResult();
}
/// Creates the batched compute for the dynamic lowering: one lane per output
/// element, each computing a vector-vector dot product.
///
/// `a` and `b` are both passed as batch inputs (no weights). For every lane:
///   - row    = lane floordiv numOutCols, column = lane mod numOutCols
///   - row `row` of A and column `column` of B are extracted as
///     `1 x reductionSize` vectors
///   - their `SpatVVDMulOp` result (`1 x 1`) is inserted at `[lane, 0]` of the
///     batch output.
///
/// `bType` is currently unused.
static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
                                                              Value b,
                                                              RankedTensorType aType,
                                                              RankedTensorType bType,
                                                              RankedTensorType scalarPiecesType,
                                                              RankedTensorType outType,
                                                              ConversionPatternRewriter& rewriter,
                                                              Location loc) {
  const int64_t outRows = outType.getDimSize(0);
  const int64_t outCols = outType.getDimSize(1);
  const int64_t laneCount = outRows * outCols;
  const int64_t reductionSize = aType.getDimSize(1);

  // Operand and result types are the same for every lane.
  auto operandType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
  auto productType = RankedTensorType::get({1, 1}, outType.getElementType());

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {scalarPiecesType},
      laneCount,
      ValueRange {},
      ValueRange {a, b},
      [&](detail::SpatComputeBatchBodyArgs args) {
        Value row = createDynamicGemmBatchRow(args.lane, outCols, rewriter, loc);
        Value column = onnx_mlir::affineModConst(
            rewriter, loc, args.lane, outCols, rewriter.getInsertionBlock()->getParentOp());

        Value lhsVector = extractDynamicGemmRowVector(args.inputs[0], row, operandType, rewriter, loc);
        Value rhsVector = extractDynamicGemmBColumn(args.inputs[1], column, operandType, rewriter, loc);
        Value product = spatial::SpatVVDMulOp::create(rewriter, loc, productType, lhsVector, rhsVector).getResult();

        const OpFoldResult one = rewriter.getIndexAttr(1);
        SmallVector<OpFoldResult> productOffsets {args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> productSizes {one, one};
        SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, product, args.outputs.front(), productOffsets, productSizes, unitStrides);
      });
  if (failed(batchOp))
    return failure();
  return *batchOp;
}
/// Builds the compute op that assembles the final Gemm result of the dynamic
/// lowering from the per-element scalar products.
///
/// Inputs of the compute are `scalarPieces` and, when non-null, `bias`. Inside
/// it, a loop over `lane` in `[0, laneCount)` (laneCount is the leading
/// dimension of `scalarPiecesType`):
///   1. extracts the `1 x 1` slice at `[lane, 0]` of the pieces,
///   2. multiplies it by `alpha` when `alpha != 1`,
///   3. when a bias is present, adds the bias element for `[row, column]`
///      (multiplied by `beta` when `beta != 1`),
///   4. inserts the result at `[row, column]` of the accumulated output, where
///      `row = lane floordiv numOutCols` and `column = lane mod numOutCols`.
/// The loop result is yielded as the compute result of type `outType`.
///
/// `biasType` is only read when `bias` is non-null.
static FailureOr<spatial::SpatCompute> createDynamicGemmOutputCompute(Value scalarPieces,
Value bias,
RankedTensorType scalarPiecesType,
RankedTensorType biasType,
RankedTensorType outType,
float alpha,
float beta,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t laneCount = scalarPiecesType.getDimSize(0);
const int64_t numOutCols = outType.getDimSize(1);
// The bias is an optional second compute input.
SmallVector<Value> inputs {scalarPieces};
if (bias)
inputs.push_back(bias);
return createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) -> LogicalResult {
Value pieces = blockArgs[0];
Value biasArg = bias ? blockArgs[1] : Value();
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
// Uninitialized output tensor; the loop below writes one element per lane.
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
// Loop from 0 to laneCount with step 1, carrying the output tensor as the only iter arg.
auto loop = buildNormalizedScfFor(
rewriter,
loc,
c0,
cLaneCount,
c1,
ValueRange {outputInit},
[&](OpBuilder&, Location nestedLoc, Value lane, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
Value outputAcc = iterArgs.front();
// Map the flat lane id back to its (row, column) output position.
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, nestedLoc);
Value column =
onnx_mlir::affineModConst(rewriter, nestedLoc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scalar = tensor::ExtractSliceOp::create(
rewriter, nestedLoc, scalarType, pieces, scalarOffsets, scalarSizes, unitStrides)
.getResult();
// Scale the product by alpha; skipped entirely for alpha == 1.
if (alpha != 1.0f) {
Value alphaTensor = createScalarTensorConstant(scalarType, alpha, rewriter, nestedLoc);
scalar = spatial::SpatVMulOp::create(rewriter, nestedLoc, scalarType, scalar, alphaTensor).getResult();
}
// Add the (optionally beta-scaled) bias element for this output position.
if (biasArg) {
Value biasScalar =
createBroadcastedBiasScalar(biasArg, biasType, row, column, scalarType, rewriter, nestedLoc);
if (beta != 1.0f) {
Value betaTensor = createScalarTensorConstant(scalarType, beta, rewriter, nestedLoc);
biasScalar =
spatial::SpatVMulOp::create(rewriter, nestedLoc, scalarType, biasScalar, betaTensor).getResult();
}
scalar = spatial::SpatVAddOp::create(rewriter, nestedLoc, scalarType, scalar, biasScalar).getResult();
}
// Write the finished element into the accumulated output and carry it to the next iteration.
SmallVector<OpFoldResult> outputOffsets {row, column};
Value outputNext =
tensor::InsertSliceOp::create(rewriter, nestedLoc, scalar, outputAcc, outputOffsets, scalarSizes, unitStrides)
.getResult();
yielded.push_back(outputNext);
return success();
});
if (failed(loop))
return failure();
spatial::SpatYieldOp::create(rewriter, loc, loop->results.front());
return success();
});
}
/// Builds the first row, inside the partial-pieces tensor, of the group that
/// belongs to output slice `hSlice` and reduction slice `kSlice`:
/// `hSlice * (numKSlices * numOutRows) + kSlice * numOutRows`.
static Value createPartialGroupOffset(Value hSlice,
                                      int64_t kSlice,
                                      int64_t numKSlices,
                                      int64_t numOutRows,
                                      ConversionPatternRewriter& rewriter,
                                      Location loc) {
  const int64_t rowsPerHSlice = numKSlices * numOutRows;
  const int64_t kSliceBase = kSlice * numOutRows;

  AffineExpr hSliceExpr = getAffineDimExpr(0, rewriter.getContext());
  AffineExpr groupOffsetExpr = hSliceExpr * rowsPerHSlice + kSliceBase;
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  return createOrFoldAffineApply(rewriter, loc, groupOffsetExpr, ValueRange {hSlice}, anchorOp);
}
/// Extracts the `numOutRows x crossbarSize` block of partial results produced
/// for output slice `hSlice` and reduction slice `kSlice`.
static Value extractReductionPiece(Value partialPiecesArg,
                                   Value hSlice,
                                   int64_t kSlice,
                                   RankedTensorType pieceType,
                                   int64_t numKSlices,
                                   int64_t numOutRows,
                                   ConversionPatternRewriter& rewriter,
                                   Location loc) {
  // The block starts at column 0 of the group's first row.
  Value firstRow = createPartialGroupOffset(hSlice, kSlice, numKSlices, numOutRows, rewriter, loc);

  const OpFoldResult one = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> blockOffsets {firstRow, rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> blockSizes {rewriter.getIndexAttr(numOutRows),
                                        rewriter.getIndexAttr(crossbarSize.getValue())};
  SmallVector<OpFoldResult> blockStrides {one, one};
  auto sliceOp = tensor::ExtractSliceOp::create(
      rewriter, loc, pieceType, partialPiecesArg, blockOffsets, blockSizes, blockStrides);
  return sliceOp.getResult();
}
/// Sums the `numKSlices` partial-result blocks of output slice `hSlice` with a
/// pairwise (tree-shaped) chain of `SpatVAddOp`s and returns the total.
///
/// Each round adds neighbouring pairs `(0,1), (2,3), ...`; an unpaired last
/// block is carried into the next round unchanged. `numKSlices` must be at
/// least 1, since the final element is read unconditionally.
static Value reducePartialPiecesForHSlice(Value partialPiecesArg,
                                          Value hSlice,
                                          RankedTensorType pieceType,
                                          int64_t numKSlices,
                                          int64_t numOutRows,
                                          ConversionPatternRewriter& rewriter,
                                          Location loc) {
  SmallVector<Value> level;
  level.reserve(numKSlices);
  for (int64_t kSlice = 0; kSlice < numKSlices; ++kSlice) {
    Value piece =
        extractReductionPiece(partialPiecesArg, hSlice, kSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
    level.push_back(piece);
  }

  // Compact each round in place: the write cursor never passes the read cursor.
  while (level.size() > 1) {
    const size_t count = level.size();
    size_t write = 0;
    for (size_t read = 0; read + 1 < count; read += 2) {
      Value sum = spatial::SpatVAddOp::create(rewriter, loc, pieceType, level[read], level[read + 1]).getResult();
      level[write++] = sum;
    }
    if (count % 2 != 0) {
      Value carried = level[count - 1];
      level[write++] = carried;
    }
    level.resize(write);
  }
  return level.front();
}
/// Builds the compute op that reduces the tiled partial products into the final
/// Gemm result of the constant-weight lowering.
///
/// Inputs of the compute are `partialPieces` and, when non-null, `bias`. Inside
/// it, for every output column slice `hSlice`:
///   1. the `numKSlices` partial blocks of that slice are summed,
///   2. the matching `numOutRows x crossbarSize` slice of the bias (if any) is added,
///   3. the block is inserted at `[0, hSlice * crossbarSize]` of a tensor of
///      type `paddedOutType`.
/// A single slice is emitted inline; several slices are emitted as a loop that
/// carries the output tensor. If `paddedOutType` differs from `outType`, the
/// top-left `outType`-sized region is sliced out before yielding.
///
/// A bias whose type is not `paddedOutType` is zero-padded to it first.
static FailureOr<spatial::SpatCompute> createReductionCompute(Value partialPieces,
Value bias,
RankedTensorType partialPiecesType,
RankedTensorType outType,
RankedTensorType paddedOutType,
int64_t numKSlices,
ConversionPatternRewriter& rewriter,
Location loc) {
// The bias is an optional second compute input.
SmallVector<Value> inputs {partialPieces};
if (bias)
inputs.push_back(bias);
auto computeOp =
createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) -> LogicalResult {
Value partialPiecesArg = blockArgs[0];
Value biasArg = bias ? blockArgs[1] : Value();
// Bring the bias to the padded width so it can be sliced per column tile below.
if (biasArg && cast<RankedTensorType>(biasArg.getType()) != paddedOutType)
biasArg = createZeroPaddedTensor(biasArg, paddedOutType, rewriter, loc);
const int64_t numOutRows = outType.getDimSize(0);
// Number of crossbar-wide column slices covering the unpadded output width.
const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(1), crossbarSize.getValue());
auto pieceType = RankedTensorType::get({numOutRows, static_cast<int64_t>(crossbarSize.getValue())},
partialPiecesType.getElementType());
Value outputInit =
tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult();
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(numOutRows),
rewriter.getIndexAttr(crossbarSize.getValue())};
// Emits the reduction, bias addition and insertion for one column slice and
// returns the updated output tensor.
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
Value reduced =
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
// Column offset of this slice in elements.
Value hOffset = onnx_mlir::affineMulConst(
rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
if (biasArg) {
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
Value biasSlice =
tensor::ExtractSliceOp::create(rewriter, loc, pieceType, biasArg, biasOffsets, pieceSizes, unitStrides)
.getResult();
reduced = spatial::SpatVAddOp::create(rewriter, loc, pieceType, reduced, biasSlice).getResult();
}
SmallVector<OpFoldResult> outputOffsets {rewriter.getIndexAttr(0), hOffset};
return tensor::InsertSliceOp::create(rewriter, loc, reduced, outputAcc, outputOffsets, pieceSizes, unitStrides)
.getResult();
};
Value paddedOutput = outputInit;
// A single slice needs no loop: emit it directly for hSlice == 0.
if (numOutHSlices == 1) {
Value hSlice = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
paddedOutput = buildOutputSlice(outputInit, hSlice);
}
else {
// Loop over all column slices, carrying the output tensor as the only iter arg.
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cOutHSlices =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
auto hLoop = buildNormalizedScfFor(
rewriter,
loc,
c0,
cOutHSlices,
c1,
ValueRange {outputInit},
[&](OpBuilder&, Location, Value hSlice, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
yielded.push_back(buildOutputSlice(iterArgs.front(), hSlice));
return success();
});
if (failed(hLoop))
return failure();
paddedOutput = hLoop->results.front();
}
Value result = paddedOutput;
// Drop the column padding again when the output was padded.
if (paddedOutType != outType) {
SmallVector<OpFoldResult> outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(outType.getDimSize(0)),
rewriter.getIndexAttr(outType.getDimSize(1))};
result =
tensor::ExtractSliceOp::create(rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, unitStrides)
.getResult();
}
spatial::SpatYieldOp::create(rewriter, loc, result);
return success();
});
return computeOp;
}
/// Conversion pattern for `ONNXGemmOp`; the rewrite is defined out of line.
struct GemmToSpatialComputes : public OpConversionPattern<ONNXGemmOp> {
  // Inherit the base pattern's constructors.
  using OpConversionPattern<ONNXGemmOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
                                ONNXGemmOpAdaptor gemmOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;
};
} // namespace
/// Rewrites an `ONNXGemmOp` into Spatial compute ops.
///
/// Requirements (each reported through a diagnostic or a plain failure): A, B
/// and the result are ranked, statically shaped, rank-2 tensors whose shapes
/// agree after the optional transposes.
///
/// Two lowerings exist, selected by whether B (after transposition) is
/// compile-time computable:
///   - B not compile-time computable: one dot product per output element
///     (`createVvdmulBatch`), followed by a compute that applies alpha, beta and
///     the bias and assembles the output (`createDynamicGemmOutputCompute`).
///   - B compile-time computable: alpha is folded into B and beta into C as
///     constants, B and A are padded to multiples of `crossbarSize`, one
///     crossbar-sized multiplication is emitted per lane (`createVmmBatch`), and
///     the partial results are reduced (`createReductionCompute`).
LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = gemmOp.getLoc();
Value a = gemmOpAdaptor.getA();
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
auto aType = dyn_cast<RankedTensorType>(a.getType());
auto bType = dyn_cast<RankedTensorType>(b.getType());
auto outType = dyn_cast<RankedTensorType>(gemmOp.getY().getType());
// Unranked operands or results are rejected without a diagnostic.
if (!aType || !bType || !outType)
return failure();
if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
if (aType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm input A", aType.getRank(), {2});
return failure();
}
if (bType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm input B", bType.getRank(), {2});
return failure();
}
if (outType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm result", outType.getRank(), {2});
return failure();
}
// Materialize transA / transB as explicit ONNX transposes so the rest of the
// lowering only deals with untransposed operands.
if (gemmOpAdaptor.getTransA()) {
auto aShape = aType.getShape();
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType(), aType.getEncoding());
a = ONNXTransposeOp::create(rewriter, loc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})).getResult();
aType = transposedType;
}
if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape();
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType(), bType.getEncoding());
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})).getResult();
bType = transposedType;
}
const int64_t numOutRows = outType.getDimSize(0);
const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1);
// Dynamic path: B is only known at run time, so it cannot be pre-scaled or
// pre-padded as a constant.
if (!isCompileTimeComputable(b)) {
bool hasC = hasGemmBias(c);
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
float beta = gemmOpAdaptor.getBeta().convertToFloat();
// Stays null when there is no bias.
RankedTensorType biasType;
if (hasC) {
auto cType = dyn_cast<RankedTensorType>(c.getType());
if (!cType || !cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
auto verifiedBiasType = verifyDynamicGemmBiasType(cType, outType);
if (failed(verifiedBiasType)) {
gemmOp.emitOpError("requires Gemm bias C to be broadcastable to the output shape");
return failure();
}
biasType = *verifiedBiasType;
}
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize
|| bType.getDimSize(1) != numOutCols) {
gemmOp.emitOpError("has inconsistent A, B, and output shapes");
return failure();
}
// One lane per output element.
const int64_t laneCount64 = numOutRows * numOutCols;
if (laneCount64 > std::numeric_limits<int32_t>::max()) {
gemmOp.emitOpError("requires Gemm dynamic batch lane count to fit in i32");
return failure();
}
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
auto batchOp = createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, rewriter, loc);
if (failed(batchOp))
return failure();
auto outputCompute = createDynamicGemmOutputCompute(
batchOp->getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc);
if (failed(outputCompute))
return failure();
rewriter.replaceOp(gemmOp, outputCompute->getResults());
return success();
}
// Constant-weight path: fold alpha into B as a new constant.
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
if (failed(scaledB)) {
gemmOp.emitOpError("requires constant Gemm input B when alpha is not 1.0");
return failure();
}
b = *scaledB;
bType = cast<RankedTensorType>(b.getType());
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) {
gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling");
return failure();
}
// Tile counts along the reduction (K) and output-column (H) dimensions, and
// the corresponding sizes padded to whole crossbars.
const int64_t numKSlices = ceilIntegerDivide(reductionSize, crossbarSize.getValue());
const int64_t numOutHSlices = ceilIntegerDivide(numOutCols, crossbarSize.getValue());
const int64_t paddedReductionSize = numKSlices * static_cast<int64_t>(crossbarSize.getValue());
const int64_t paddedOutCols = numOutHSlices * static_cast<int64_t>(crossbarSize.getValue());
auto paddedBType = RankedTensorType::get({paddedReductionSize, paddedOutCols}, bType.getElementType());
auto paddedB = materializePaddedConstantMatrix(b, paddedBType, rewriter, loc);
if (failed(paddedB)) {
gemmOp.emitOpError("requires constant Gemm input B so tiled weights can be padded statically");
return failure();
}
b = *paddedB;
// A is a runtime value, so its padding is emitted as IR rather than folded.
auto paddedAType = RankedTensorType::get({numOutRows, paddedReductionSize}, aType.getElementType());
a = createPaddedInputCompute(a, paddedAType, rewriter, loc);
aType = paddedAType;
// Stays null when there is no bias.
Value bias;
bool hasC = hasGemmBias(c);
auto paddedOutType = RankedTensorType::get({numOutRows, paddedOutCols}, outType.getElementType());
if (hasC) {
auto cType = dyn_cast<RankedTensorType>(c.getType());
if (!cType || !cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
// Fold beta into C; with beta == 1 the bias is passed through unchanged.
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
if (failed(scaledC)) {
gemmOp.emitOpError("requires constant Gemm bias C when beta is not 1.0");
return failure();
}
c = *scaledC;
auto preparedBias = prepareBias(c, outType, paddedOutType, rewriter, loc);
if (failed(preparedBias)) {
gemmOp.emitOpError("requires Gemm bias C to be broadcastable to the output shape");
return failure();
}
bias = *preparedBias;
}
// One lane per (output column slice, K slice, output row) triple.
const int64_t laneCount64 = numOutHSlices * numKSlices * numOutRows;
if (laneCount64 > std::numeric_limits<int32_t>::max()) {
gemmOp.emitOpError("requires Gemm tiled batch lane count to fit in i32");
return failure();
}
auto partialPiecesType =
RankedTensorType::get({laneCount64, static_cast<int64_t>(crossbarSize.getValue())}, outType.getElementType());
auto batchOp =
createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
if (failed(batchOp))
return failure();
auto reductionCompute = createReductionCompute(
batchOp->getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc);
if (failed(reductionCompute))
return failure();
rewriter.replaceOp(gemmOp, reductionCompute->getResults());
return success();
}
/// Adds the Gemm-to-Spatial conversion pattern to `patterns`.
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<GemmToSpatialComputes>(ctx);
}
} // namespace onnx_mlir