a34ac223c0
Validate Operations / validate-operations (push) Has been cancelled
remove unsupported tests
864 lines
40 KiB
C++
864 lines
40 KiB
C++
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/Location.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include <limits>
|
|
#include <utility>
|
|
|
|
#include "Common/IR/ConstantUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Constant-folds `value * factor` into a fresh host constant.
///
/// Returns `value` itself when `factor` is exactly 1. Otherwise `value` must be
/// backed by a host-side dense floating-point constant. Each element is
/// multiplied by `factor` in the element type's own precision, using
/// round-to-nearest-even. An inexact product (ordinary rounding) is accepted,
/// because a runtime multiply in that precision would round the same way. An
/// overflow or invalid operation (e.g. inf * 0) makes the fold fail instead.
///
/// \returns the scaled constant, or failure when `value` is not a dense FP host
///          constant or the scaled values are not representable.
static FailureOr<Value>
materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewriter& rewriter, Location loc) {
  if (factor == 1.0f)
    return value;

  auto denseAttr = dyn_cast_or_null<DenseFPElementsAttr>(getHostConstDenseElementsAttr(value));
  if (!denseAttr)
    return failure();

  // APFloat arithmetic requires both operands to share semantics, but `factor`
  // is an IEEE single. Convert it to the tensor's element precision first so
  // f16 / bf16 / f64 constants are handled instead of asserting.
  const auto& semantics = cast<FloatType>(denseAttr.getElementType()).getFloatSemantics();
  // Only genuine arithmetic errors abort the fold. opInexact / opUnderflow are
  // the normal by-products of rounding to the element precision.
  const unsigned fatalStatus = APFloat::opInvalidOp | APFloat::opOverflow;

  APFloat scale(factor);
  bool losesInfo = false;
  if (scale.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo) & fatalStatus)
    return failure();

  SmallVector<APFloat> scaledValues;
  scaledValues.reserve(denseAttr.getNumElements());
  for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
    APFloat scaledValue(originalValue);
    if (scaledValue.multiply(scale, APFloat::rmNearestTiesToEven) & fatalStatus)
      return failure();
    scaledValues.push_back(std::move(scaledValue));
  }

  auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaledAttr, denseAttr.getType());
}
|
|
|
|
/// Maps a batch lane to the offset, in elements, of the K-slice that lane works on.
///
/// The K-slice index is `(lane floordiv numOutRows) mod numKSlices`, and the
/// returned offset is that index times the crossbar size. When there is a
/// single K-slice the offset is simply the index constant 0.
static Value createGemmBatchKOffset(
    Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  if (numKSlices == 1)
    return getOrCreateIndexConstant(rewriter, anchorOp, 0);

  AffineExpr laneExpr = getAffineDimExpr(0, rewriter.getContext());
  AffineExpr kSliceIndex = laneExpr.floorDiv(numOutRows) % numKSlices;
  return createOrFoldAffineApply(rewriter, loc, kSliceIndex * crossbarSize.getValue(), ValueRange {lane}, anchorOp);
}
|
|
|
|
/// Maps a batch lane to the offset, in elements, of the output H-slice that lane
/// works on.
///
/// Each H-slice spans `numOutRows * numKSlices` consecutive lanes, so the
/// H-slice index is `lane floordiv (numOutRows * numKSlices)`. The returned
/// offset is that index times the crossbar size. When there is a single H-slice
/// the offset is simply the index constant 0.
static Value createGemmBatchHOffset(Value lane,
                                    int64_t numOutRows,
                                    int64_t numKSlices,
                                    int64_t numOutHSlices,
                                    ConversionPatternRewriter& rewriter,
                                    Location loc) {
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  if (numOutHSlices == 1)
    return getOrCreateIndexConstant(rewriter, anchorOp, 0);

  const int64_t lanesPerHSlice = numOutRows * numKSlices;
  AffineExpr laneExpr = getAffineDimExpr(0, rewriter.getContext());
  AffineExpr hSliceIndex = laneExpr.floorDiv(lanesPerHSlice);
  return createOrFoldAffineApply(rewriter, loc, hSliceIndex * crossbarSize.getValue(), ValueRange {lane}, anchorOp);
}
|
|
|
|
/// Zero-pads `value` at the high end of every dimension up to `resultType`.
///
/// Emits a `tensor.pad` with all-zero low padding and a high padding of
/// `resultDim - sourceDim` per dimension. Its body yields a zero of the element
/// type, obtained through `getOrCreateConstant` anchored on the pad op.
/// NOTE(review): assumes each result dimension is >= the matching source
/// dimension and both shapes are static; confirm at call sites.
/// On return, the rewriter's insertion point is directly after the new pad op.
static Value
createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
  auto sourceType = cast<RankedTensorType>(value.getType());
  const int64_t rank = sourceType.getRank();

  SmallVector<OpFoldResult> lowPads(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highPads;
  highPads.reserve(rank);
  for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
    highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));

  auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);

  // Build the pad body through the rewriter (not a raw `new Block`) so the
  // dialect-conversion driver is notified of the block insertion. The body takes
  // one index argument per source dimension. `createBlock` leaves the insertion
  // point inside the new block.
  SmallVector<Type> blockArgTypes(rank, rewriter.getIndexType());
  SmallVector<Location> blockArgLocs(rank, loc);
  Region& padRegion = padOp.getRegion();
  rewriter.createBlock(&padRegion, padRegion.end(), blockArgTypes, blockArgLocs);

  Type elementType = sourceType.getElementType();
  auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(elementType), elementType);
  tensor::YieldOp::create(rewriter, loc, zero);

  rewriter.setInsertionPointAfter(padOp);
  return padOp.getResult();
}
|
|
|
|
/// Returns a host constant of `resultType` that holds the 2-D constant `value`
/// in its top-left corner, with every other element set to zero.
///
/// Returns `value` unchanged when its type already equals `resultType`. Fails
/// in any of these cases:
/// * `value` is not a host-side dense constant;
/// * either type is not a static rank-2 tensor;
/// * the element types differ;
/// * `resultType` is smaller than the source in either dimension.
static FailureOr<Value> materializePaddedConstantMatrix(Value value,
                                                        RankedTensorType resultType,
                                                        ConversionPatternRewriter& rewriter,
                                                        Location loc) {
  auto sourceType = cast<RankedTensorType>(value.getType());
  if (sourceType == resultType)
    return value;

  auto denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr)
    return failure();

  auto denseType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!denseType || denseType.getRank() != 2 || !denseType.hasStaticShape())
    return failure();

  // The copy below indexes `resultType` as a row-major matrix at least as large
  // as the source in both dimensions, and reuses the source attributes verbatim.
  // Reject anything else rather than write out of bounds or hand
  // DenseElementsAttr::get attributes of the wrong element type.
  if (resultType.getRank() != 2 || !resultType.hasStaticShape()
      || resultType.getElementType() != denseType.getElementType())
    return failure();

  ArrayRef<int64_t> sourceShape = denseType.getShape();
  ArrayRef<int64_t> resultShape = resultType.getShape();
  if (sourceShape[0] > resultShape[0] || sourceShape[1] > resultShape[1])
    return failure();

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  Attribute zero = rewriter.getZeroAttr(resultType.getElementType());
  SmallVector<Attribute> resultValues(resultType.getNumElements(), zero);

  for (int64_t row = 0; row < sourceShape[0]; ++row)
    for (int64_t col = 0; col < sourceShape[1]; ++col)
      resultValues[row * resultShape[1] + col] = sourceValues[row * sourceShape[1] + col];

  auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
}
|
|
|
|
/// Folds the host constant `value` into a constant of `resultType`, broadcast
/// and then zero-padded along the innermost dimension.
///
/// The source's dimensions are aligned with the result's trailing dimensions.
/// Each source dimension must be 1 (repeated along that axis) or equal to the
/// corresponding "unpadded" result dimension. The unpadded shape is
/// `resultType`'s shape with the innermost extent replaced by `unpaddedColumns`.
/// Result elements whose innermost index is >= `unpaddedColumns` are zero.
///
/// Fails in any of these cases:
/// * `value` is not a static host dense constant;
/// * `resultType` is not a static tensor of rank >= 1;
/// * the element types differ;
/// * `unpaddedColumns` is outside `[0, innermost result extent]`;
/// * the source does not broadcast to the unpadded shape.
static FailureOr<Value> materializePaddedBroadcastedConstantTensor(Value value,
                                                                   RankedTensorType resultType,
                                                                   int64_t unpaddedColumns,
                                                                   ConversionPatternRewriter& rewriter,
                                                                   Location loc) {
  auto denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr)
    return failure();

  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType || !sourceType.hasStaticShape() || sourceType.getRank() > resultType.getRank())
    return failure();

  // The innermost result dimension is the padded one, so it must exist. Source
  // attributes are reused verbatim, so the element types must match.
  if (!resultType.hasStaticShape() || resultType.getRank() == 0
      || resultType.getElementType() != sourceType.getElementType() || unpaddedColumns < 0
      || unpaddedColumns > resultType.getShape().back())
    return failure();

  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  ArrayRef<int64_t> resultShape = resultType.getShape();
  const int64_t sourceRank = sourceType.getRank();
  const int64_t resultRank = resultType.getRank();
  // Source dimension i lines up with result dimension i + rankOffset.
  const int64_t rankOffset = resultRank - sourceRank;

  for (int64_t sourceDimIdx = 0; sourceDimIdx < sourceRank; ++sourceDimIdx) {
    const int64_t resultDimIdx = sourceDimIdx + rankOffset;
    const int64_t unpaddedDim = resultDimIdx == resultRank - 1 ? unpaddedColumns : resultShape[resultDimIdx];
    const int64_t sourceDim = sourceShape[sourceDimIdx];
    if (sourceDim != 1 && sourceDim != unpaddedDim)
      return failure();
  }

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceShape);
  Attribute zero = rewriter.getZeroAttr(resultType.getElementType());

  SmallVector<Attribute> resultValues;
  resultValues.reserve(resultType.getNumElements());
  // Multi-dimensional index of the current result element. It is advanced in
  // row-major ("odometer") order in step with the flat position, which avoids
  // re-deriving it by division for every element.
  SmallVector<int64_t> resultIndices(resultRank, 0);
  for (int64_t flatIndex = 0, numElements = resultType.getNumElements(); flatIndex < numElements; ++flatIndex) {
    if (resultIndices.back() >= unpaddedColumns) {
      resultValues.push_back(zero);
    }
    else {
      int64_t sourceFlatIndex = 0;
      for (int64_t sourceDimIdx = 0; sourceDimIdx < sourceRank; ++sourceDimIdx) {
        // Broadcast (size-1) source dimensions always read index 0.
        if (sourceShape[sourceDimIdx] != 1)
          sourceFlatIndex += resultIndices[sourceDimIdx + rankOffset] * sourceStrides[sourceDimIdx];
      }
      resultValues.push_back(sourceValues[sourceFlatIndex]);
    }

    for (int64_t dim = resultRank - 1; dim >= 0; --dim) {
      if (++resultIndices[dim] < resultShape[dim])
        break;
      resultIndices[dim] = 0;
    }
  }

  auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
}
|
|
|
|
/// Prepares the Gemm bias `c` for use by the reduction compute.
///
/// * If `c` is compile-time computable, it is folded into a host constant of
///   `paddedOutType`. The constant is broadcast to the output shape and
///   zero-padded beyond `outType`'s column count; the call fails if that
///   folding fails.
/// * Otherwise `c` is returned as-is. Its type must then equal `outType`
///   exactly; no runtime broadcasting is done here.
///
/// Fails (instead of asserting) for a null, non-ranked-tensor (e.g. `none`),
/// or dynamically shaped `c`.
static FailureOr<Value> prepareBias(Value c,
                                    RankedTensorType outType,
                                    RankedTensorType paddedOutType,
                                    ConversionPatternRewriter& rewriter,
                                    Location loc) {
  if (!c)
    return failure();

  auto cType = dyn_cast<RankedTensorType>(c.getType());
  if (!cType || !cType.hasStaticShape())
    return failure();

  if (isCompileTimeComputable(c))
    return materializePaddedBroadcastedConstantTensor(c, paddedOutType, outType.getDimSize(1), rewriter, loc);

  if (cType != outType)
    return failure();

  return c;
}
|
|
|
|
/// Extracts the `1 x crossbarSize` tile of `a` whose top-left element is at
/// (`row`, `kOffset`), using unit strides. `aTileType` is the slice result type.
static Value extractATile(
    Value a, Value row, Value kOffset, RankedTensorType aTileType, ConversionPatternRewriter& rewriter, Location loc) {
  OpFoldResult one = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> tileOffsets {row, kOffset};
  SmallVector<OpFoldResult> tileSizes {one, rewriter.getIndexAttr(crossbarSize.getValue())};
  SmallVector<OpFoldResult> tileStrides {one, one};

  auto sliceOp = tensor::ExtractSliceOp::create(rewriter, loc, aTileType, a, tileOffsets, tileSizes, tileStrides);
  return sliceOp.getResult();
}
|
|
|
|
/// Returns `input` zero-padded up to `paddedInputType`.
///
/// When `input` already has `paddedInputType`, it is returned unchanged and
/// nothing is emitted. Otherwise the padding is performed inside a dedicated
/// one-input spatial compute op, whose single result is returned.
static Value createPaddedInputCompute(Value input,
                                      RankedTensorType paddedInputType,
                                      ConversionPatternRewriter& rewriter,
                                      Location loc) {
  if (cast<RankedTensorType>(input.getType()) == paddedInputType)
    return input;

  auto padCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value innerInput) {
        Value padded = createZeroPaddedTensor(innerInput, paddedInputType, rewriter, loc);
        spatial::SpatYieldOp::create(rewriter, loc, padded);
      });
  return padCompute.getResult(0);
}
|
|
|
|
/// Emits the compute batch that forms all crossbar-tile partial products of A * B.
///
/// The batch has `partialPiecesType.getDimSize(0)` lanes. B is passed as the
/// batch's weight operand and A as its regular input. For each lane:
/// * its row (`lane mod numOutRows`), K offset and H offset are derived from
///   the lane index;
/// * a `1 x crossbarSize` tile of A is multiplied by a
///   `crossbarSize x crossbarSize` tile of B via SpatVMMOp;
/// * the resulting `1 x crossbarSize` piece is written into row `lane` of the
///   batch output.
static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
                                                           Value b,
                                                           RankedTensorType aType,
                                                           RankedTensorType paddedBType,
                                                           RankedTensorType partialPiecesType,
                                                           int64_t numOutRows,
                                                           int64_t numKSlices,
                                                           int64_t numOutHSlices,
                                                           ConversionPatternRewriter& rewriter,
                                                           Location loc) {
  const int64_t laneCount = partialPiecesType.getDimSize(0);
  const int64_t tileSize = static_cast<int64_t>(crossbarSize.getValue());

  // Per-lane tile types: a row tile of A, a square tile of B, and the partial
  // product they produce.
  auto aTileType = RankedTensorType::get({1, tileSize}, aType.getElementType());
  auto bTileType = RankedTensorType::get({tileSize, tileSize}, paddedBType.getElementType());
  auto pieceType = RankedTensorType::get({1, tileSize}, partialPiecesType.getElementType());

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {partialPiecesType},
      laneCount,
      ValueRange {b},
      ValueRange {a},
      [&](detail::SpatComputeBatchBodyArgs args) {
        Value lane = args.lane;
        Value row =
            onnx_mlir::affineModConst(rewriter, loc, lane, numOutRows, rewriter.getInsertionBlock()->getParentOp());
        Value kOffset = createGemmBatchKOffset(lane, numOutRows, numKSlices, rewriter, loc);
        Value hOffset = createGemmBatchHOffset(lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);

        Value aTile = extractATile(args.inputs.front(), row, kOffset, aTileType, rewriter, loc);

        OpFoldResult tileExtent = rewriter.getIndexAttr(tileSize);
        SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
        SmallVector<OpFoldResult> bTileOffsets {kOffset, hOffset};
        SmallVector<OpFoldResult> bTileSizes {tileExtent, tileExtent};
        auto bSliceOp = tensor::ExtractSliceOp::create(
            rewriter, loc, bTileType, args.weights.front(), bTileOffsets, bTileSizes, unitStrides);
        Value bTile = bSliceOp.getResult();

        Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();

        SmallVector<OpFoldResult> pieceOffsets {lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), tileExtent};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides);
      });

  if (failed(batchOp))
    return failure();
  return *batchOp;
}
|
|
|
|
/// Returns the output row handled by dynamic-Gemm lane `lane`, i.e.
/// `lane floordiv numOutCols`. With a single output column each lane is its own
/// row, so `lane` is returned directly and nothing is emitted.
static Value
createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
  if (numOutCols == 1)
    return lane;

  AffineExpr laneExpr = getAffineDimExpr(0, rewriter.getContext());
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  return createOrFoldAffineApply(rewriter, loc, laneExpr.floorDiv(numOutCols), ValueRange {lane}, anchorOp);
}
|
|
|
|
/// Extracts column `column` of `matrix` and reshapes it to `vectorType`
/// (`1 x length`, where `length = vectorType.getDimSize(1)`).
///
/// This takes three steps:
/// 1. Slice out the `length x 1` column with unit strides.
/// 2. Collapse it to rank 1 (`[length]`).
/// 3. Expand it back to rank 2 as `[1, length]`.
static Value extractDynamicGemmBColumn(
    Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
  const int64_t length = vectorType.getDimSize(1);
  Type elementType = vectorType.getElementType();
  OpFoldResult one = rewriter.getIndexAttr(1);

  auto columnType = RankedTensorType::get({length, 1}, elementType);
  SmallVector<OpFoldResult> sliceOffsets {rewriter.getIndexAttr(0), column};
  SmallVector<OpFoldResult> sliceSizes {rewriter.getIndexAttr(length), one};
  SmallVector<OpFoldResult> sliceStrides {one, one};
  auto sliceOp = tensor::ExtractSliceOp::create(rewriter, loc, columnType, matrix, sliceOffsets, sliceSizes, sliceStrides);

  // Both reshapes group the two rank-2 dimensions into one rank-1 dimension.
  SmallVector<ReassociationIndices> bothDims {
      ReassociationIndices {0, 1}
  };
  auto flatType = RankedTensorType::get({length}, elementType);
  auto collapseOp = tensor::CollapseShapeOp::create(rewriter, loc, flatType, sliceOp.getResult(), bothDims);
  auto expandOp = tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapseOp.getResult(), bothDims);
  return expandOp.getResult();
}
|
|
|
|
/// Extracts row `row` of `matrix` as a `vectorType` slice
/// (`1 x vectorType.getDimSize(1)`), starting at column 0 with unit strides.
static Value extractDynamicGemmRowVector(
    Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
  OpFoldResult one = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> rowOffsets {row, rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> rowSizes {one, rewriter.getIndexAttr(vectorType.getDimSize(1))};
  SmallVector<OpFoldResult> rowStrides {one, one};

  auto sliceOp = tensor::ExtractSliceOp::create(rewriter, loc, vectorType, matrix, rowOffsets, rowSizes, rowStrides);
  return sliceOp.getResult();
}
|
|
|
|
/// Checks that bias type `cType` broadcasts to the rank-2 output `outType`.
///
/// The bias must be static with rank <= 2. A rank-0 bias always broadcasts. A
/// rank-1 bias is matched against the output columns. A rank-2 bias is matched
/// dimension by dimension. A bias dimension matches when it is 1 or equals the
/// output dimension.
///
/// \returns `cType` on success, failure otherwise.
static FailureOr<RankedTensorType> verifyDynamicGemmBiasType(RankedTensorType cType, RankedTensorType outType) {
  if (!cType.hasStaticShape() || cType.getRank() > 2)
    return failure();

  auto broadcastsTo = [](int64_t biasDim, int64_t outDim) { return biasDim == 1 || biasDim == outDim; };

  bool compatible = true;
  switch (cType.getRank()) {
  case 0: break;
  case 1: compatible = broadcastsTo(cType.getDimSize(0), outType.getDimSize(1)); break;
  default:
    compatible = broadcastsTo(cType.getDimSize(0), outType.getDimSize(0))
              && broadcastsTo(cType.getDimSize(1), outType.getDimSize(1));
    break;
  }

  if (!compatible)
    return failure();
  return cType;
}
|
|
|
|
/// Returns true when the Gemm bias operand `c` is actually present.
///
/// These count as "no bias":
/// * a null value;
/// * a `none`-typed value (how onnx-mlir represents an omitted optional input);
/// * a value produced by `ONNXNoneOp`.
///
/// Any other value, including a block argument, counts as a bias.
static bool hasGemmBias(Value c) {
  if (!c || isa<NoneType>(c.getType()))
    return false;

  Operation* definingOp = c.getDefiningOp();
  return !definingOp || !isa<ONNXNoneOp>(definingOp);
}
|
|
|
|
/// Returns a constant of `scalarType` whose every element is `value`, converted
/// to the element type via a float attribute. The constant is obtained through
/// `getOrCreateConstant`, anchored on the current insertion block's parent op.
static Value createScalarTensorConstant(RankedTensorType scalarType,
                                        float value,
                                        ConversionPatternRewriter& rewriter,
                                        Location loc) {
  FloatAttr splatValue = rewriter.getFloatAttr(scalarType.getElementType(), value);
  auto splatAttr = DenseElementsAttr::get(scalarType, splatValue);
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  return getOrCreateConstant(rewriter, anchorOp, splatAttr, scalarType);
}
|
|
|
|
/// Returns the `scalarType` (1x1) bias element that applies at output position
/// (`row`, `column`), following broadcasting of `bias`.
///
/// * Rank 2: slices element (row, column). A size-1 dimension always reads
///   index 0.
/// * Rank 1: indexed by `column` (or 0 when its size is 1). The `[1]` slice is
///   expanded to `scalarType`.
/// * Any other rank (rank 0 expected): extracts the single element and splats
///   it to `scalarType`.
static Value createBroadcastedBiasScalar(Value bias,
                                         RankedTensorType biasType,
                                         Value row,
                                         Value column,
                                         RankedTensorType scalarType,
                                         ConversionPatternRewriter& rewriter,
                                         Location loc) {
  const int64_t biasRank = biasType.getRank();
  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(biasRank, oneIndex);

  // Offset into bias dimension `dim`: broadcast (size-1) dimensions pin to 0.
  auto biasOffset = [&](int64_t dim, Value outputIndex) -> OpFoldResult {
    if (biasType.getDimSize(dim) == 1)
      return zeroIndex;
    return outputIndex;
  };

  if (biasRank == 2) {
    SmallVector<OpFoldResult> offsets {biasOffset(0, row), biasOffset(1, column)};
    SmallVector<OpFoldResult> sizes {oneIndex, oneIndex};
    return tensor::ExtractSliceOp::create(rewriter, loc, scalarType, bias, offsets, sizes, strides).getResult();
  }

  if (biasRank == 1) {
    auto elementSliceType = RankedTensorType::get({1}, scalarType.getElementType());
    SmallVector<OpFoldResult> offsets {biasOffset(0, column)};
    SmallVector<OpFoldResult> sizes {oneIndex};
    Value elementSlice =
        tensor::ExtractSliceOp::create(rewriter, loc, elementSliceType, bias, offsets, sizes, strides).getResult();
    SmallVector<ReassociationIndices> toMatrix {
        ReassociationIndices {0, 1}
    };
    return tensor::ExpandShapeOp::create(rewriter, loc, scalarType, elementSlice, toMatrix).getResult();
  }

  Value element = tensor::ExtractOp::create(rewriter, loc, bias, ValueRange {}).getResult();
  return tensor::SplatOp::create(rewriter, loc, scalarType, element).getResult();
}
|
|
|
|
/// Emits the compute batch that forms every element of A * B as its own dot product.
///
/// The batch has `numOutRows * numOutCols` lanes, with A and B as regular
/// inputs (no weights). For each lane:
/// * row = `lane floordiv numOutCols` and column = `lane mod numOutCols`;
/// * row `row` of A and column `column` of B are both shaped `1 x K`;
/// * SpatVVDMulOp combines them into a `1 x 1` result, written to row `lane`
///   of the `scalarPiecesType` output.
/// `bType` is currently unused.
static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
                                                              Value b,
                                                              RankedTensorType aType,
                                                              RankedTensorType bType,
                                                              RankedTensorType scalarPiecesType,
                                                              RankedTensorType outType,
                                                              ConversionPatternRewriter& rewriter,
                                                              Location loc) {
  const int64_t numOutCols = outType.getDimSize(1);
  const int64_t laneCount = outType.getDimSize(0) * numOutCols;
  auto operandVectorType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
  auto resultScalarType = RankedTensorType::get({1, 1}, outType.getElementType());

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {scalarPiecesType},
      laneCount,
      ValueRange {},
      ValueRange {a, b},
      [&](detail::SpatComputeBatchBodyArgs args) {
        Value lane = args.lane;
        Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc);
        Value column =
            onnx_mlir::affineModConst(rewriter, loc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());

        Value lhs = extractDynamicGemmRowVector(args.inputs[0], row, operandVectorType, rewriter, loc);
        Value rhs = extractDynamicGemmBColumn(args.inputs[1], column, operandVectorType, rewriter, loc);
        Value dot = spatial::SpatVVDMulOp::create(rewriter, loc, resultScalarType, lhs, rhs).getResult();

        OpFoldResult one = rewriter.getIndexAttr(1);
        SmallVector<OpFoldResult> dstOffsets {lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> dstSizes {one, one};
        SmallVector<OpFoldResult> dstStrides = getUnitStrides(rewriter, 2);
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, dot, args.outputs.front(), dstOffsets, dstSizes, dstStrides);
      });

  if (failed(batchOp))
    return failure();
  return *batchOp;
}
|
|
|
|
/// Builds the spatial compute that assembles the dynamic-Gemm output from its
/// per-element dot products.
///
/// `scalarPieces` holds one `1 x 1` value per lane, where row =
/// `lane floordiv numOutCols` and column = `lane mod numOutCols`. An `scf.for`
/// over all lanes does the following for each lane:
/// * extracts the piece;
/// * multiplies it by `alpha` when alpha != 1;
/// * when `bias` is given, adds the broadcast bias element, first multiplied by
///   `beta` when beta != 1;
/// * inserts the result at (row, column) of an `outType` accumulator.
/// The final accumulator is yielded.
static FailureOr<spatial::SpatCompute> createDynamicGemmOutputCompute(Value scalarPieces,
                                                                      Value bias,
                                                                      RankedTensorType scalarPiecesType,
                                                                      RankedTensorType biasType,
                                                                      RankedTensorType outType,
                                                                      float alpha,
                                                                      float beta,
                                                                      ConversionPatternRewriter& rewriter,
                                                                      Location loc) {
  const int64_t totalLanes = scalarPiecesType.getDimSize(0);
  const int64_t numOutCols = outType.getDimSize(1);
  const bool withBias = static_cast<bool>(bias);

  SmallVector<Value> computeInputs {scalarPieces};
  if (withBias)
    computeInputs.push_back(bias);

  // Parent op of the current insertion block, re-queried at every use.
  auto currentParentOp = [&]() { return rewriter.getInsertionBlock()->getParentOp(); };

  return createSpatCompute(
      rewriter, loc, TypeRange {outType}, {}, computeInputs, [&](ValueRange blockArgs) -> LogicalResult {
        Value piecesArg = blockArgs[0];
        Value biasArg = withBias ? blockArgs[1] : Value();
        auto elementType = RankedTensorType::get({1, 1}, outType.getElementType());

        // Multiplies a 1x1 tensor by a constant factor.
        auto scaleBy = [&](Value operand, float factor, Location opLoc) -> Value {
          Value factorTensor = createScalarTensorConstant(elementType, factor, rewriter, opLoc);
          return spatial::SpatVMulOp::create(rewriter, opLoc, elementType, operand, factorTensor).getResult();
        };

        Value emptyOutput =
            tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
        Value lowerBound = getOrCreateIndexConstant(rewriter, currentParentOp(), 0);
        Value step = getOrCreateIndexConstant(rewriter, currentParentOp(), 1);
        Value upperBound = getOrCreateIndexConstant(rewriter, currentParentOp(), totalLanes);

        auto laneLoop = buildNormalizedScfFor(
            rewriter,
            loc,
            lowerBound,
            upperBound,
            step,
            ValueRange {emptyOutput},
            [&](OpBuilder&, Location bodyLoc, Value lane, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
              Value accumulator = iterArgs.front();
              Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, bodyLoc);
              Value column = onnx_mlir::affineModConst(rewriter, bodyLoc, lane, numOutCols, currentParentOp());

              OpFoldResult oneIndex = rewriter.getIndexAttr(1);
              SmallVector<OpFoldResult> pieceOffsets {lane, rewriter.getIndexAttr(0)};
              SmallVector<OpFoldResult> unitSizes {oneIndex, oneIndex};
              SmallVector<OpFoldResult> unitStrides {oneIndex, oneIndex};
              Value element = tensor::ExtractSliceOp::create(
                                  rewriter, bodyLoc, elementType, piecesArg, pieceOffsets, unitSizes, unitStrides)
                                  .getResult();

              if (alpha != 1.0f)
                element = scaleBy(element, alpha, bodyLoc);

              if (biasArg) {
                Value biasElement =
                    createBroadcastedBiasScalar(biasArg, biasType, row, column, elementType, rewriter, bodyLoc);
                if (beta != 1.0f)
                  biasElement = scaleBy(biasElement, beta, bodyLoc);
                element = spatial::SpatVAddOp::create(rewriter, bodyLoc, elementType, element, biasElement).getResult();
              }

              SmallVector<OpFoldResult> outputOffsets {row, column};
              auto insertOp = tensor::InsertSliceOp::create(
                  rewriter, bodyLoc, element, accumulator, outputOffsets, unitSizes, unitStrides);
              yielded.push_back(insertOp.getResult());
              return success();
            });

        if (failed(laneLoop))
          return failure();

        spatial::SpatYieldOp::create(rewriter, loc, laneLoop->results.front());
        return success();
      });
}
|
|
|
|
/// Returns the first row, within the partial-products tensor, of the
/// (`hSlice`, `kSlice`) group: `hSlice * (numKSlices * numOutRows) +
/// kSlice * numOutRows`.
static Value createPartialGroupOffset(Value hSlice,
                                      int64_t kSlice,
                                      int64_t numKSlices,
                                      int64_t numOutRows,
                                      ConversionPatternRewriter& rewriter,
                                      Location loc) {
  const int64_t rowsPerHSlice = numKSlices * numOutRows;
  const int64_t kSliceBase = kSlice * numOutRows;
  AffineExpr hSliceExpr = getAffineDimExpr(0, rewriter.getContext());
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  return createOrFoldAffineApply(
      rewriter, loc, hSliceExpr * rowsPerHSlice + kSliceBase, ValueRange {hSlice}, anchorOp);
}
|
|
|
|
/// Extracts the `numOutRows x crossbarSize` block of partial products that
/// belongs to output H-slice `hSlice` and K-slice `kSlice`. The block starts at
/// the row given by `createPartialGroupOffset`, at column 0, with unit strides.
static Value extractReductionPiece(Value partialPiecesArg,
                                   Value hSlice,
                                   int64_t kSlice,
                                   RankedTensorType pieceType,
                                   int64_t numKSlices,
                                   int64_t numOutRows,
                                   ConversionPatternRewriter& rewriter,
                                   Location loc) {
  Value groupRow = createPartialGroupOffset(hSlice, kSlice, numKSlices, numOutRows, rewriter, loc);

  OpFoldResult one = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> blockOffsets {groupRow, rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> blockSizes {rewriter.getIndexAttr(numOutRows),
                                        rewriter.getIndexAttr(crossbarSize.getValue())};
  SmallVector<OpFoldResult> blockStrides {one, one};

  auto sliceOp = tensor::ExtractSliceOp::create(
      rewriter, loc, pieceType, partialPiecesArg, blockOffsets, blockSizes, blockStrides);
  return sliceOp.getResult();
}
|
|
|
|
/// Sums the `numKSlices` partial-product blocks of output H-slice `hSlice`.
///
/// All K-slice blocks are extracted first. They are then combined by a pairwise
/// tree of SpatVAddOps. Each level adds neighbours (0+1, 2+3, ...) and carries
/// an odd trailing block to the next level unchanged.
/// NOTE(review): assumes `numKSlices >= 1`; the final `front()` needs at least
/// one block.
static Value reducePartialPiecesForHSlice(Value partialPiecesArg,
                                          Value hSlice,
                                          RankedTensorType pieceType,
                                          int64_t numKSlices,
                                          int64_t numOutRows,
                                          ConversionPatternRewriter& rewriter,
                                          Location loc) {
  SmallVector<Value> level;
  level.reserve(numKSlices);
  for (int64_t kSlice = 0; kSlice < numKSlices; ++kSlice)
    level.push_back(
        extractReductionPiece(partialPiecesArg, hSlice, kSlice, pieceType, numKSlices, numOutRows, rewriter, loc));

  while (level.size() > 1) {
    // Reduce in place: slot `writeIdx` (<= readIdx) is only written after both
    // of its operands have been read.
    size_t writeIdx = 0;
    size_t readIdx = 0;
    for (; readIdx + 1 < level.size(); readIdx += 2)
      level[writeIdx++] =
          spatial::SpatVAddOp::create(rewriter, loc, pieceType, level[readIdx], level[readIdx + 1]).getResult();
    if (readIdx < level.size())
      level[writeIdx++] = level[readIdx];
    level.resize(writeIdx);
  }

  return level.front();
}
|
|
|
|
/// Builds the spatial compute that reduces the VMM batch's partial products
/// into the Gemm output.
///
/// The bias is handled first: if given and not already `paddedOutType`-shaped,
/// it is zero-padded to `paddedOutType` inside the compute. Then, for each
/// output H-slice (one `scf.for` iteration, or straight-line code when there is
/// exactly one slice):
/// * the slice's `numKSlices` partial-product blocks are summed;
/// * the matching `numOutRows x crossbarSize` column block of the bias is
///   added, if present;
/// * the block is inserted into a `paddedOutType` accumulator.
/// Finally the accumulator is sliced back to `outType` when the two types
/// differ, and the result is yielded.
static FailureOr<spatial::SpatCompute> createReductionCompute(Value partialPieces,
                                                              Value bias,
                                                              RankedTensorType partialPiecesType,
                                                              RankedTensorType outType,
                                                              RankedTensorType paddedOutType,
                                                              int64_t numKSlices,
                                                              ConversionPatternRewriter& rewriter,
                                                              Location loc) {
  SmallVector<Value> computeInputs {partialPieces};
  if (bias)
    computeInputs.push_back(bias);

  // Parent op of the current insertion block, re-queried at every use.
  auto currentParentOp = [&]() { return rewriter.getInsertionBlock()->getParentOp(); };

  return createSpatCompute(
      rewriter, loc, TypeRange {outType}, {}, computeInputs, [&](ValueRange blockArgs) -> LogicalResult {
        Value piecesArg = blockArgs[0];
        Value biasArg;
        if (bias) {
          biasArg = blockArgs[1];
          if (cast<RankedTensorType>(biasArg.getType()) != paddedOutType)
            biasArg = createZeroPaddedTensor(biasArg, paddedOutType, rewriter, loc);
        }

        const int64_t numOutRows = outType.getDimSize(0);
        const int64_t tileSize = static_cast<int64_t>(crossbarSize.getValue());
        const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(1), crossbarSize.getValue());
        auto blockType = RankedTensorType::get({numOutRows, tileSize}, partialPiecesType.getElementType());

        Value accumulatorInit =
            tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType())
                .getResult();
        OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
        OpFoldResult oneIndex = rewriter.getIndexAttr(1);
        SmallVector<OpFoldResult> unitStrides {oneIndex, oneIndex};
        SmallVector<OpFoldResult> blockSizes {rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(tileSize)};

        // Reduces H-slice `hSlice`, adds its bias block, and inserts it into `acc`.
        auto emitHSlice = [&](Value acc, Value hSlice) -> Value {
          Value sum =
              reducePartialPiecesForHSlice(piecesArg, hSlice, blockType, numKSlices, numOutRows, rewriter, loc);
          Value columnOffset =
              onnx_mlir::affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), currentParentOp());
          SmallVector<OpFoldResult> blockOffsets {zeroIndex, columnOffset};

          if (biasArg) {
            Value biasBlock = tensor::ExtractSliceOp::create(
                                  rewriter, loc, blockType, biasArg, blockOffsets, blockSizes, unitStrides)
                                  .getResult();
            sum = spatial::SpatVAddOp::create(rewriter, loc, blockType, sum, biasBlock).getResult();
          }

          return tensor::InsertSliceOp::create(rewriter, loc, sum, acc, blockOffsets, blockSizes, unitStrides)
              .getResult();
        };

        Value paddedResult;
        if (numOutHSlices == 1) {
          paddedResult = emitHSlice(accumulatorInit, getOrCreateIndexConstant(rewriter, currentParentOp(), 0));
        }
        else {
          Value lowerBound = getOrCreateIndexConstant(rewriter, currentParentOp(), 0);
          Value step = getOrCreateIndexConstant(rewriter, currentParentOp(), 1);
          Value upperBound = getOrCreateIndexConstant(rewriter, currentParentOp(), numOutHSlices);
          auto hSliceLoop = buildNormalizedScfFor(
              rewriter,
              loc,
              lowerBound,
              upperBound,
              step,
              ValueRange {accumulatorInit},
              [&](OpBuilder&, Location, Value hSlice, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
                yielded.push_back(emitHSlice(iterArgs.front(), hSlice));
                return success();
              });
          if (failed(hSliceLoop))
            return failure();
          paddedResult = hSliceLoop->results.front();
        }

        Value finalResult = paddedResult;
        if (paddedOutType != outType) {
          SmallVector<OpFoldResult> cropOffsets {zeroIndex, zeroIndex};
          SmallVector<OpFoldResult> cropSizes {rewriter.getIndexAttr(outType.getDimSize(0)),
                                               rewriter.getIndexAttr(outType.getDimSize(1))};
          finalResult = tensor::ExtractSliceOp::create(
                            rewriter, loc, outType, paddedResult, cropOffsets, cropSizes, unitStrides)
                            .getResult();
        }

        spatial::SpatYieldOp::create(rewriter, loc, finalResult);
        return success();
      });
}
|
|
|
|
/// Conversion pattern that rewrites `onnx.Gemm` into Spatial compute ops.
///
/// Only the declaration lives here. `matchAndRewrite` is defined out of line,
/// after the anonymous namespace closes.
struct GemmToSpatialComputes : OpConversionPattern<ONNXGemmOp> {
  // Reuse OpConversionPattern's constructors as-is.
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
                                ONNXGemmOpAdaptor gemmOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;
};
|
|
|
|
} // namespace
|
|
|
|
/// Lowers onnx.Gemm (Y = alpha * op(A) * op(B) + beta * C, where op() is an
/// optional transpose selected by transA/transB) to Spatial compute ops.
///
/// A, B and Y must be statically shaped rank-2 tensors; an optional bias C must
/// be statically shaped. After validation, transA/transB are materialized as
/// onnx.Transpose ops, and one of two strategies is chosen depending on whether
/// the (possibly transposed) B is compile-time computable:
///  - B not compile-time computable: a vvdmul batch produces one scalar piece
///    per output element (numOutRows * numOutCols lanes); alpha, beta and the
///    optional bias are passed to the dynamic output compute.
///  - B compile-time computable: alpha is folded into B and beta into C, B is
///    padded so both of its dimensions are multiples of crossbarSize, A is
///    padded along the reduction dimension, and a tiled vmm batch followed by a
///    reduction compute produces Y.
///
/// Every check that needs no new IR (types, static shapes, ranks, logical
/// A/B/Y dimensions, bias static shape, lane-count limits) runs before the
/// corresponding ops are created. This means the common rejection cases leave
/// the IR untouched.
/// NOTE(review): the remaining strategy-specific failure points (bias
/// verification, constant materialization, batch/compute helpers) still run
/// after the Transpose ops exist. They rely on the conversion driver rolling
/// back the partial rewrite.
LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
                                                     ONNXGemmOpAdaptor gemmOpAdaptor,
                                                     ConversionPatternRewriter& rewriter) const {
  Location loc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();

  auto aType = dyn_cast<RankedTensorType>(a.getType());
  auto bType = dyn_cast<RankedTensorType>(b.getType());
  auto outType = dyn_cast<RankedTensorType>(gemmOp.getY().getType());
  if (!aType || !bType || !outType)
    return failure();

  if (!aType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
    return failure();
  }
  if (!bType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
    return failure();
  }
  if (!outType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
    return failure();
  }
  if (aType.getRank() != 2) {
    pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm input A", aType.getRank(), {2});
    return failure();
  }
  if (bType.getRank() != 2) {
    pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm input B", bType.getRank(), {2});
    return failure();
  }
  if (outType.getRank() != 2) {
    pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm result", outType.getRank(), {2});
    return failure();
  }

  // Logical (post-transpose) dimensions, read from the transA/transB attributes
  // so they can be validated before any Transpose op is emitted.
  // After this check, op(A) is [numOutRows x reductionSize] and op(B) is
  // [reductionSize x numOutCols].
  const bool transA = static_cast<bool>(gemmOpAdaptor.getTransA());
  const bool transB = static_cast<bool>(gemmOpAdaptor.getTransB());
  const int64_t numOutRows = outType.getDimSize(0);
  const int64_t numOutCols = outType.getDimSize(1);
  const int64_t aRows = aType.getDimSize(transA ? 1 : 0);
  const int64_t reductionSize = aType.getDimSize(transA ? 0 : 1);
  const int64_t bRows = bType.getDimSize(transB ? 1 : 0);
  const int64_t bCols = bType.getDimSize(transB ? 0 : 1);
  if (aRows != numOutRows || bRows != reductionSize || bCols != numOutCols) {
    gemmOp.emitOpError("has inconsistent A, B, and output shapes");
    return failure();
  }

  // Both lowering strategies require a statically shaped bias, so check it up front.
  const bool hasC = hasGemmBias(c);
  RankedTensorType cType;
  if (hasC) {
    cType = dyn_cast<RankedTensorType>(c.getType());
    if (!cType || !cType.hasStaticShape()) {
      pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
      return failure();
    }
  }

  float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
  float beta = gemmOpAdaptor.getBeta().convertToFloat();

  // Emits an onnx.Transpose swapping the two dimensions of a rank-2 value and
  // updates `matrixType` to the transposed type. The encoding is kept.
  auto transposeMatrix = [&](Value matrix, RankedTensorType& matrixType) -> Value {
    ArrayRef<int64_t> shape = matrixType.getShape();
    auto transposedType =
        RankedTensorType::get({shape[1], shape[0]}, matrixType.getElementType(), matrixType.getEncoding());
    matrixType = transposedType;
    return ONNXTransposeOp::create(rewriter, loc, transposedType, matrix, rewriter.getI64ArrayAttr({1, 0}))
        .getResult();
  };
  if (transA)
    a = transposeMatrix(a, aType);
  if (transB)
    b = transposeMatrix(b, bType);

  // ---- Strategy 1: B only known at runtime ---------------------------------
  if (!isCompileTimeComputable(b)) {
    RankedTensorType biasType;
    if (hasC) {
      auto verifiedBiasType = verifyDynamicGemmBiasType(cType, outType);
      if (failed(verifiedBiasType)) {
        gemmOp.emitOpError("requires Gemm bias C to be broadcastable to the output shape");
        return failure();
      }
      biasType = *verifiedBiasType;
    }

    // One lane per output element; lane indices must fit in i32.
    const int64_t laneCount64 = numOutRows * numOutCols;
    if (laneCount64 > std::numeric_limits<int32_t>::max()) {
      gemmOp.emitOpError("requires Gemm dynamic batch lane count to fit in i32");
      return failure();
    }

    auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
    auto batchOp = createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, rewriter, loc);
    if (failed(batchOp))
      return failure();

    auto outputCompute = createDynamicGemmOutputCompute(
        batchOp->getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc);
    if (failed(outputCompute))
      return failure();

    rewriter.replaceOp(gemmOp, outputCompute->getResults());
    return success();
  }

  // ---- Strategy 2: B compile-time computable, tiled onto crossbars ---------
  const int64_t numKSlices = ceilIntegerDivide(reductionSize, crossbarSize.getValue());
  const int64_t numOutHSlices = ceilIntegerDivide(numOutCols, crossbarSize.getValue());
  const int64_t paddedReductionSize = numKSlices * static_cast<int64_t>(crossbarSize.getValue());
  const int64_t paddedOutCols = numOutHSlices * static_cast<int64_t>(crossbarSize.getValue());

  // Pure arithmetic: reject oversized tilings before materializing any constant.
  const int64_t tiledLaneCount = numOutHSlices * numKSlices * numOutRows;
  if (tiledLaneCount > std::numeric_limits<int32_t>::max()) {
    gemmOp.emitOpError("requires Gemm tiled batch lane count to fit in i32");
    return failure();
  }

  // Fold alpha into B (no-op when alpha == 1.0).
  auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, loc);
  if (failed(scaledB)) {
    gemmOp.emitOpError("requires constant Gemm input B when alpha is not 1.0");
    return failure();
  }
  b = *scaledB;
  bType = cast<RankedTensorType>(b.getType());

  auto paddedBType = RankedTensorType::get({paddedReductionSize, paddedOutCols}, bType.getElementType());
  auto paddedB = materializePaddedConstantMatrix(b, paddedBType, rewriter, loc);
  if (failed(paddedB)) {
    gemmOp.emitOpError("requires constant Gemm input B so tiled weights can be padded statically");
    return failure();
  }
  b = *paddedB;

  auto paddedAType = RankedTensorType::get({numOutRows, paddedReductionSize}, aType.getElementType());
  a = createPaddedInputCompute(a, paddedAType, rewriter, loc);
  aType = paddedAType;

  Value bias;
  auto paddedOutType = RankedTensorType::get({numOutRows, paddedOutCols}, outType.getElementType());
  if (hasC) {
    // Fold beta into C (no-op when beta == 1.0), then shape it for the padded output.
    auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, loc);
    if (failed(scaledC)) {
      gemmOp.emitOpError("requires constant Gemm bias C when beta is not 1.0");
      return failure();
    }
    c = *scaledC;

    auto preparedBias = prepareBias(c, outType, paddedOutType, rewriter, loc);
    if (failed(preparedBias)) {
      gemmOp.emitOpError("requires Gemm bias C to be broadcastable to the output shape");
      return failure();
    }
    bias = *preparedBias;
  }

  auto partialPiecesType =
      RankedTensorType::get({tiledLaneCount, static_cast<int64_t>(crossbarSize.getValue())}, outType.getElementType());
  auto batchOp =
      createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
  if (failed(batchOp))
    return failure();

  auto reductionCompute = createReductionCompute(
      batchOp->getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc);
  if (failed(reductionCompute))
    return failure();

  rewriter.replaceOp(gemmOp, reductionCompute->getResults());
  return success();
}
|
|
|
|
/// Registers the Gemm -> Spatial conversion pattern (GemmToSpatialComputes)
/// into `patterns`, constructing it with `ctx`.
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<GemmToSpatialComputes>(ctx);
}
|
|
|
|
} // namespace onnx_mlir
|