a34ac223c0
Validate Operations / validate-operations (push) Has been cancelled
remove unsupported tests
1127 lines
58 KiB
C++
1127 lines
58 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Computes the broadcasted batch shape of two operands.
///
/// Both shapes are aligned on their trailing dimension; a dimension missing
/// from the shorter shape is treated as 1.
///
/// \returns the broadcasted shape, or failure() when two aligned dimensions
///          differ and neither of them equals 1.
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
                                                                ArrayRef<int64_t> rhsBatchShape) {
  const int64_t lhsRank = static_cast<int64_t>(lhsBatchShape.size());
  const int64_t rhsRank = static_cast<int64_t>(rhsBatchShape.size());
  const int64_t resultRank = std::max(lhsRank, rhsRank);

  SmallVector<int64_t> broadcasted(resultRank, 1);
  // Walk from the innermost dimension outwards; `fromBack` is the 1-based
  // distance from the end of each shape.
  for (int64_t fromBack = 1; fromBack <= resultRank; ++fromBack) {
    const int64_t lhsDim = fromBack <= lhsRank ? lhsBatchShape[lhsRank - fromBack] : 1;
    const int64_t rhsDim = fromBack <= rhsRank ? rhsBatchShape[rhsRank - fromBack] : 1;

    const bool compatible = lhsDim == rhsDim || lhsDim == 1 || rhsDim == 1;
    if (!compatible)
      return failure();

    broadcasted[resultRank - fromBack] = std::max(lhsDim, rhsDim);
  }
  return broadcasted;
}
|
|
|
|
static int64_t mapStaticBroadcastedBatchIndex(int64_t outputBatchIndex,
|
|
ArrayRef<int64_t> sourceBatchShape,
|
|
ArrayRef<int64_t> outputBatchShape) {
|
|
if (sourceBatchShape.empty() || getStaticShapeElementCount(sourceBatchShape) == 1)
|
|
return 0;
|
|
if (llvm::equal(sourceBatchShape, outputBatchShape))
|
|
return outputBatchIndex;
|
|
|
|
SmallVector<int64_t> outputStrides = computeRowMajorStrides(outputBatchShape);
|
|
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceBatchShape);
|
|
int64_t sourceFlatIndex = 0;
|
|
for (int64_t sourceDimIndex = 0; sourceDimIndex < static_cast<int64_t>(sourceBatchShape.size()); ++sourceDimIndex) {
|
|
if (sourceBatchShape[sourceDimIndex] == 1)
|
|
continue;
|
|
const int64_t outputDimIndex = outputBatchShape.size() - sourceBatchShape.size() + sourceDimIndex;
|
|
const int64_t outputDimStride = outputStrides.empty() ? 1 : outputStrides[outputDimIndex];
|
|
const int64_t outputDimIndexValue = outputDimStride == 1
|
|
? outputBatchIndex % outputBatchShape[outputDimIndex]
|
|
: (outputBatchIndex / outputDimStride) % outputBatchShape[outputDimIndex];
|
|
sourceFlatIndex += outputDimIndexValue * sourceStrides[sourceDimIndex];
|
|
}
|
|
return sourceFlatIndex;
|
|
}
|
|
|
|
/// Emits IR computing the coordinate along `dimIndex` of a flat batch index
/// over `batchShape`, i.e. (flatBatchIndex floordiv stride) mod extent.
///
/// An extent-1 dimension yields the index constant 0. The floordiv is skipped
/// when the stride is 1. The affine/constant helpers are defined elsewhere and
/// receive the parent op of the current insertion block as anchor.
static Value computeFlatBatchIndexCoordinate(
    Value flatBatchIndex, ArrayRef<int64_t> batchShape, int64_t dimIndex, PatternRewriter& rewriter, Location loc) {
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  const int64_t dimExtent = batchShape[dimIndex];
  if (dimExtent == 1)
    return getOrCreateIndexConstant(rewriter, anchorOp, 0);

  // The stride of a dimension is the element count of all dimensions after it.
  int64_t dimStride = 1;
  const bool isInnermost = dimIndex + 1 == static_cast<int64_t>(batchShape.size());
  if (!isInnermost)
    dimStride = getStaticShapeElementCount(batchShape.drop_front(dimIndex + 1));

  Value quotient = flatBatchIndex;
  if (dimStride != 1)
    quotient = affineFloorDivConst(rewriter, loc, flatBatchIndex, dimStride, anchorOp);
  return affineModConst(rewriter, loc, quotient, dimExtent, anchorOp);
}
|
|
|
|
/// Emits IR mapping a flat output batch index to the flat batch index of a
/// (possibly broadcasted) source operand.
///
/// Returns the index constant 0 when the source has an empty batch shape or a
/// single batch, and `outputBatchIndex` itself when both shapes are equal.
/// Otherwise the result is accumulated with arith.addi, starting from 0, over
/// `coordinate * sourceStride` for every source dimension whose extent is not 1.
static Value mapOutputBatchIndexToSourceBatchIndex(Value outputBatchIndex,
                                                   ArrayRef<int64_t> sourceBatchShape,
                                                   ArrayRef<int64_t> outputBatchShape,
                                                   PatternRewriter& rewriter,
                                                   Location loc) {
  // The anchor is re-read at every use, as the original code does.
  auto currentAnchor = [&rewriter]() -> Operation* { return rewriter.getInsertionBlock()->getParentOp(); };

  const bool sourceIsSingleBatch = sourceBatchShape.empty() || getStaticShapeElementCount(sourceBatchShape) == 1;
  if (sourceIsSingleBatch)
    return getOrCreateIndexConstant(rewriter, currentAnchor(), 0);
  if (llvm::equal(sourceBatchShape, outputBatchShape))
    return outputBatchIndex;

  Value flatSourceIndex = getOrCreateIndexConstant(rewriter, currentAnchor(), 0);
  const SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceBatchShape);

  const int64_t sourceRank = static_cast<int64_t>(sourceBatchShape.size());
  // Source dimensions are aligned with the trailing output dimensions.
  const int64_t rankDelta = static_cast<int64_t>(outputBatchShape.size()) - sourceRank;

  for (int64_t srcDim = 0; srcDim < sourceRank; ++srcDim) {
    // Broadcast (extent-1) dimensions contribute nothing.
    if (sourceBatchShape[srcDim] == 1)
      continue;

    Value term =
        computeFlatBatchIndexCoordinate(outputBatchIndex, outputBatchShape, rankDelta + srcDim, rewriter, loc);
    const int64_t srcStride = sourceStrides[srcDim];
    if (srcStride != 1)
      term = affineMulConst(rewriter, loc, term, srcStride, currentAnchor());

    flatSourceIndex = arith::AddIOp::create(rewriter, loc, flatSourceIndex, term);
  }
  return flatSourceIndex;
}
|
|
|
|
/// Flattens all leading batch dimensions of `value` into a single one,
/// producing a `batchSize x rows x cols` tensor. Rank-2 and rank-3 inputs are
/// returned unchanged.
///
/// The tensor.collapse_shape is created by a callback handed to
/// materializeOrComputeUnary (defined elsewhere).
/// NOTE(review): ranks below 2 are not guarded here; callers presumably only
/// pass rank >= 2 — confirm.
static Value
collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
  auto tensorType = cast<RankedTensorType>(value.getType());
  const int64_t rank = tensorType.getRank();
  if (rank == 2 || rank == 3)
    return value;

  auto collapsedType =
      RankedTensorType::get({batchSize, rows, cols}, tensorType.getElementType(), tensorType.getEncoding());

  // Group 0 gathers every batch dimension; the two matrix dimensions keep
  // their own groups.
  const int64_t batchRank = rank - 2;
  ReassociationIndices batchGroup;
  for (int64_t dim = 0; dim < batchRank; ++dim)
    batchGroup.push_back(dim);

  SmallVector<ReassociationIndices> reassociation;
  reassociation.push_back(batchGroup);
  reassociation.push_back(ReassociationIndices {batchRank});
  reassociation.push_back(ReassociationIndices {batchRank + 1});

  auto emitCollapse = [&](Value input) -> Value {
    return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
  };
  return materializeOrComputeUnary(value, collapsedType, rewriter, loc, emitCollapse);
}
|
|
|
|
/// Reshapes `value` to `outputType` by splitting its leading dimension into
/// `batchRank` batch dimensions; the two trailing dimensions keep their own
/// reassociation groups. Returns `value` unchanged when it already has
/// `outputType`.
///
/// The tensor.expand_shape is created by a callback handed to
/// materializeOrComputeUnary (defined elsewhere).
static Value
expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) {
  if (cast<RankedTensorType>(value.getType()) == outputType)
    return value;

  // Output dims [0, batchRank) come from source dim 0; the next two output
  // dims map one-to-one onto the source's row and column dims.
  const int64_t rowDim = static_cast<int64_t>(batchRank);
  ReassociationIndices batchGroup;
  for (int64_t dim = 0; dim < rowDim; ++dim)
    batchGroup.push_back(dim);

  SmallVector<ReassociationIndices> reassociation;
  reassociation.push_back(batchGroup);
  reassociation.push_back(ReassociationIndices {rowDim});
  reassociation.push_back(ReassociationIndices {rowDim + 1});

  auto emitExpand = [&](Value input) -> Value {
    return tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation).getResult();
  };
  return materializeOrComputeUnary(value, outputType, rewriter, loc, emitExpand);
}
|
|
|
|
/// Reshapes `value` into the two-dimensional `resultType` with a
/// tensor.expand_shape whose single reassociation group is {0, 1}.
///
/// The op is created by a callback handed to materializeOrComputeUnary
/// (defined elsewhere).
static Value createMatrixFromVector(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
  // The only source dimension expands into both result dimensions.
  const SmallVector<ReassociationIndices> reassociation {
      {0, 1}
  };

  auto emitExpand = [&](Value input) -> Value {
    return tensor::ExpandShapeOp::create(rewriter, loc, resultType, input, reassociation);
  };
  return materializeOrComputeUnary(value, resultType, rewriter, loc, emitExpand);
}
|
|
|
|
/// Builds collapse_shape reassociation groups from a per-axis "removed" mask.
///
/// Every kept axis closes a group containing itself and all removed axes that
/// immediately precede it. Removed axes trailing the last kept axis are
/// appended to the last group; if no axis is kept, all axes form one group.
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> removedAxes) {
  SmallVector<ReassociationIndices> groups;
  ReassociationIndices pending;

  const int64_t rank = static_cast<int64_t>(removedAxes.size());
  for (int64_t axis = 0; axis < rank; ++axis) {
    pending.push_back(axis);
    if (removedAxes[axis])
      continue;
    // A kept axis terminates the current group.
    groups.push_back(pending);
    pending.clear();
  }

  if (pending.empty())
    return groups;

  // Left-over removed axes have no kept axis after them.
  if (groups.empty())
    groups.push_back(std::move(pending));
  else
    groups.back().append(pending.begin(), pending.end());
  return groups;
}
|
|
|
|
/// Collapses `value` to `resultType`, dropping the axes flagged in
/// `removedAxes`. Returns `value` unchanged when it already has `resultType`.
///
/// A rank-0 result uses an empty reassociation; otherwise the groups come from
/// buildCollapseReassociation. The tensor.collapse_shape is created by a
/// callback handed to materializeOrComputeUnary (defined elsewhere).
static Value squeezeUnitDims(
    Value value, RankedTensorType resultType, ArrayRef<bool> removedAxes, PatternRewriter& rewriter, Location loc) {
  if (cast<RankedTensorType>(value.getType()) == resultType)
    return value;

  SmallVector<ReassociationIndices> reassociation;
  if (resultType.getRank() != 0)
    reassociation = buildCollapseReassociation(removedAxes);

  auto emitCollapse = [&](Value input) -> Value {
    return tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation).getResult();
  };
  return materializeOrComputeUnary(value, resultType, rewriter, loc, emitCollapse);
}
|
|
|
|
/// Returns `value` as a rank-3 `batchSize x rows x cols` tensor.
///
/// Rank-3 inputs are returned unchanged. Any other input is expanded with the
/// reassociation {{0, 1}, {2}}, i.e. its first dimension is split into the
/// batch and row dimensions.
/// NOTE(review): that reassociation presumes a rank-2 input — confirm callers
/// never pass other ranks.
static Value ensureBatchedTensor(
    Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
  auto inputType = cast<RankedTensorType>(value.getType());
  if (inputType.getRank() == 3)
    return value;

  auto batchedType = RankedTensorType::get({batchSize, rows, cols}, inputType.getElementType(), inputType.getEncoding());
  const SmallVector<ReassociationIndices> reassociation {
      {0, 1},
      {2}
  };

  auto emitExpand = [&](Value input) -> Value {
    return tensor::ExpandShapeOp::create(rewriter, loc, batchedType, input, reassociation);
  };
  return materializeOrComputeUnary(value, batchedType, rewriter, loc, emitExpand);
}
|
|
|
|
/// Extracts the `rows x cols` matrix of one batch from a batched tensor.
///
/// Rank-2 inputs are returned unchanged. Otherwise a `1 x rows x cols` slice
/// is taken at the batch offset and collapsed to `rows x cols`. The offset is
/// 0 when `batchSize` is 1, and `batchIndex` otherwise. The slice and matrix
/// types are built from the element type only (no encoding). Both ops are
/// created by a callback handed to materializeOrComputeUnary (defined
/// elsewhere).
static Value extractBatchMatrix(Value value,
                                int64_t batchIndex,
                                int64_t batchSize,
                                int64_t rows,
                                int64_t cols,
                                PatternRewriter& rewriter,
                                Location loc) {
  auto inputType = cast<RankedTensorType>(value.getType());
  if (inputType.getRank() == 2)
    return value;

  auto sliceType = RankedTensorType::get({1, rows, cols}, inputType.getElementType());
  // A single-batch operand is shared by every output batch.
  const int64_t batchOffset = batchSize == 1 ? 0 : batchIndex;
  SmallVector<OpFoldResult> offsets = {
      rewriter.getIndexAttr(batchOffset), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> sizes = {
      rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
  SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, 3);
  auto matrixType = RankedTensorType::get({rows, cols}, inputType.getElementType());

  // Merge the unit batch dimension into the row dimension.
  const SmallVector<ReassociationIndices> dropBatchDim {
      {0, 1},
      {2}
  };

  auto emitMatrix = [&](Value input) -> Value {
    Value batchSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
    return tensor::CollapseShapeOp::create(rewriter, loc, matrixType, batchSlice, dropBatchDim);
  };
  return materializeOrComputeUnary(value, matrixType, rewriter, loc, emitMatrix);
}
|
|
|
|
/// Emits an ONNX transpose swapping the last two dimensions of `value`.
///
/// Rank-2 inputs use permutation {1, 0}; every other input is handled as
/// rank 3 with permutation {0, 2, 1}. The result type keeps the element type
/// and encoding of the input.
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
  auto inputType = cast<RankedTensorType>(value.getType());
  ArrayRef<int64_t> dims = inputType.getShape();

  SmallVector<int64_t> permutation;
  SmallVector<int64_t> transposedShape;
  if (inputType.getRank() == 2) {
    permutation = {1, 0};
    transposedShape = {dims[1], dims[0]};
  }
  else {
    permutation = {0, 2, 1};
    transposedShape = {dims[0], dims[2], dims[1]};
  }

  auto resultType = RankedTensorType::get(transposedShape, inputType.getElementType(), inputType.getEncoding());
  return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation)).getResult();
}
|
|
|
|
/// Pads `value` with zeros at the high end of every dimension so that it
/// reaches the shape of `resultType`.
///
/// \param value       ranked tensor to pad.
/// \param resultType  padded tensor type; each high pad is the difference
///                    between the result and source extent of that dimension.
/// \returns the result of the created tensor.pad op. On return the insertion
///          point is right after that op.
///
/// The pad body block is created through the rewriter instead of a raw
/// `new Block()` pushed into the region, so that any listener attached to the
/// rewriter is notified of the block insertion.
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
  auto sourceType = cast<RankedTensorType>(value.getType());
  const int64_t rank = sourceType.getRank();
  Type elementType = sourceType.getElementType();

  SmallVector<OpFoldResult> lowPads(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highPads;
  highPads.reserve(rank);
  for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
    highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));

  auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);

  // The pad region takes one index argument per source dimension. createBlock
  // also moves the insertion point into the new (empty) block.
  SmallVector<Type> blockArgTypes(rank, rewriter.getIndexType());
  SmallVector<Location> blockArgLocs(rank, loc);
  Region& padRegion = padOp.getRegion();
  rewriter.createBlock(&padRegion, padRegion.end(), blockArgTypes, blockArgLocs);

  auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(elementType), elementType);
  tensor::YieldOp::create(rewriter, loc, zero);

  // Leave the builder where callers expect it: right after the pad op.
  rewriter.setInsertionPointAfter(padOp);
  return padOp.getResult();
}
|
|
|
|
/// Returns `input` zero-padded to `paddedInputType`.
///
/// When the input already has the requested type it is returned as is.
/// Otherwise the padding is emitted inside a compute op built with
/// createSpatCompute (defined elsewhere), whose body yields the padded tensor,
/// and result 0 of that op is returned.
static Value createPaddedBatchedInputCompute(Value input,
                                             RankedTensorType paddedInputType,
                                             PatternRewriter& rewriter,
                                             Location loc) {
  const bool alreadyPadded = cast<RankedTensorType>(input.getType()) == paddedInputType;
  if (alreadyPadded)
    return input;

  auto paddingCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
        Value padded = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
        spatial::SpatYieldOp::create(rewriter, loc, padded);
      });
  return paddingCompute.getResult(0);
}
|
|
|
|
/// Materializes a constant weight as a new constant of `resultType`, with the
/// batch dimension broadcast to `targetBatchShape` and every matrix
/// zero-padded to the result's row/column extents.
///
/// \param value             weight value; must be backed by a host constant
///                          (as reported by getHostConstDenseElementsAttr).
/// \param sourceBatchShape  batch shape of the source weight.
/// \param targetBatchShape  broadcasted batch shape; empty means one batch.
/// \param resultType        rank-3 `batch x rows x cols` target type.
/// \returns `value` itself when it already has `resultType`, the new constant
///          otherwise, or failure() when the weight is not a host constant or
///          the shapes cannot be handled.
///
/// Unsupported ranks, dynamic shapes, mismatching element types and extents
/// that would index outside the source or result buffers are rejected with
/// failure() rather than read or written out of bounds.
static FailureOr<Value> materializePaddedBatchedWeight(Value value,
                                                       ArrayRef<int64_t> sourceBatchShape,
                                                       ArrayRef<int64_t> targetBatchShape,
                                                       RankedTensorType resultType,
                                                       PatternRewriter& rewriter) {
  auto sourceType = cast<RankedTensorType>(value.getType());
  if (sourceType == resultType)
    return value;

  auto denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr)
    return failure();

  // The index arithmetic below only makes sense for a 2-D or 3-D static source
  // and a 3-D static result with the same element type.
  const bool sourceIsMatrix = sourceType.getRank() == 2;
  if ((!sourceIsMatrix && sourceType.getRank() != 3) || resultType.getRank() != 3)
    return failure();
  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();
  if (sourceType.getElementType() != resultType.getElementType())
    return failure();

  const int64_t sourceRows = sourceType.getDimSize(sourceIsMatrix ? 0 : 1);
  const int64_t sourceCols = sourceType.getDimSize(sourceIsMatrix ? 1 : 2);
  const int64_t targetBatch = targetBatchShape.empty() ? 1 : getStaticShapeElementCount(targetBatchShape);
  const int64_t targetRows = resultType.getDimSize(1);
  const int64_t targetCols = resultType.getDimSize(2);

  // Padding can only grow a matrix, and every target batch must fit the result.
  if (sourceRows > targetRows || sourceCols > targetCols || targetBatch > resultType.getDimSize(0))
    return failure();

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> resultValues(resultType.getNumElements(), rewriter.getZeroAttr(resultType.getElementType()));

  const int64_t sourceMatrixSize = sourceRows * sourceCols;
  const int64_t targetMatrixSize = targetRows * targetCols;
  const int64_t sourceElementCount = static_cast<int64_t>(sourceValues.size());

  for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) {
    // A plain matrix is shared by every batch; otherwise undo the broadcast.
    const int64_t sourceBatchIdx =
        sourceIsMatrix ? 0 : mapStaticBroadcastedBatchIndex(batchIdx, sourceBatchShape, targetBatchShape);
    const int64_t sourceBatchBase = sourceBatchIdx * sourceMatrixSize;
    if (sourceBatchIdx < 0 || sourceBatchBase + sourceMatrixSize > sourceElementCount)
      return failure();

    const int64_t targetBatchBase = batchIdx * targetMatrixSize;
    // Copy into the top-left corner; the remaining elements keep their zero.
    for (int64_t row = 0; row < sourceRows; ++row)
      for (int64_t col = 0; col < sourceCols; ++col)
        resultValues[targetBatchBase + row * targetCols + col] = sourceValues[sourceBatchBase + row * sourceCols + col];
  }

  auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
}
|
|
|
|
/// Extracts one row tile of the batched left operand `a`.
///
/// The output batch index is first mapped to the source batch index. A
/// `1 x 1 x W` slice is then taken at (sourceBatch, row, kOffset), with W the
/// second extent of `aTileType`, and collapsed to `aTileType` by merging the
/// batch and row dimensions.
static Value extractBatchedATile(Value a,
                                 ArrayRef<int64_t> sourceBatchShape,
                                 ArrayRef<int64_t> outputBatchShape,
                                 Value outputBatchIndex,
                                 Value row,
                                 Value kOffset,
                                 RankedTensorType aTileType,
                                 PatternRewriter& rewriter,
                                 Location loc) {
  const int64_t tileWidth = aTileType.getDimSize(1);
  auto sliceType = RankedTensorType::get({1, 1, tileWidth}, aTileType.getElementType());

  Value sourceBatchIndex =
      mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);

  SmallVector<OpFoldResult> sliceOffsets {OpFoldResult(sourceBatchIndex), row, kOffset};
  SmallVector<OpFoldResult> sliceSizes {
      rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileWidth)};
  auto rowSlice =
      tensor::ExtractSliceOp::create(rewriter, loc, sliceType, a, sliceOffsets, sliceSizes, getUnitStrides(rewriter, 3));

  // Merge the unit batch dimension into the (unit) row dimension.
  const SmallVector<ReassociationIndices> dropBatchDim {
      {0, 1},
      {2}
  };
  return tensor::CollapseShapeOp::create(rewriter, loc, aTileType, rowSlice, dropBatchDim).getResult();
}
|
|
|
|
/// Extracts one two-dimensional tile of the batched right operand `b`.
///
/// The output batch index is first mapped to the source batch index. A
/// `1 x R x C` slice is then taken at (sourceBatch, kOffset, hOffset), with
/// R and C the extents of `bTileType`, and collapsed to `bTileType` by merging
/// the batch dimension into the first tile dimension.
static Value extractBatchedBTile(Value b,
                                 ArrayRef<int64_t> sourceBatchShape,
                                 ArrayRef<int64_t> outputBatchShape,
                                 Value outputBatchIndex,
                                 Value kOffset,
                                 Value hOffset,
                                 RankedTensorType bTileType,
                                 PatternRewriter& rewriter,
                                 Location loc) {
  const int64_t tileRows = bTileType.getDimSize(0);
  const int64_t tileCols = bTileType.getDimSize(1);
  auto sliceType = RankedTensorType::get({1, tileRows, tileCols}, bTileType.getElementType());

  Value sourceBatchIndex =
      mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);

  SmallVector<OpFoldResult> sliceOffsets {OpFoldResult(sourceBatchIndex), kOffset, hOffset};
  SmallVector<OpFoldResult> sliceSizes {
      rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileRows), rewriter.getIndexAttr(tileCols)};
  auto tileSlice =
      tensor::ExtractSliceOp::create(rewriter, loc, sliceType, b, sliceOffsets, sliceSizes, getUnitStrides(rewriter, 3));

  // Merge the unit batch dimension into the first tile dimension.
  const SmallVector<ReassociationIndices> dropBatchDim {
      {0, 1},
      {2}
  };
  return tensor::CollapseShapeOp::create(rewriter, loc, bTileType, tileSlice, dropBatchDim).getResult();
}
|
|
|
|
/// Emits IR computing the batch a lane belongs to:
/// lane floordiv (numOutRows * numKSlices * numOutHSlices).
static Value getBatchLaneIndex(
    Value lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices, PatternRewriter& rewriter, Location loc) {
  const int64_t lanesPerBatch = numOutRows * numKSlices * numOutHSlices;
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  return affineFloorDivConst(rewriter, loc, lane, lanesPerBatch, anchorOp);
}
|
|
|
|
/// Builds the batched compute op producing every partial VMM piece of a
/// batched matrix multiplication.
///
/// One lane is created per row of `partialPiecesType`. The body decomposes
/// its lane index as
///   row    = lane mod numOutRows
///   batch  = lane floordiv (numOutRows * numKSlices * numOutHSlices)
///   kSlice = ((lane floordiv numOutRows) mod (numKSlices * numOutHSlices)) mod numKSlices
///   hSlice = ((lane floordiv numOutRows) mod (numKSlices * numOutHSlices)) floordiv numKSlices
/// extracts the matching tiles of `a` (passed as input) and `b` (passed as
/// weight), multiplies them with a SpatVMMOp and inserts the piece into row
/// `lane` of the batch output. Tiles are `crossbarSize` wide.
///
/// \returns the created op, or failure() if createSpatComputeBatch (defined
///          elsewhere) fails.
static FailureOr<spatial::SpatComputeBatch> createBatchedVmmBatch(Value a,
                                                                  Value b,
                                                                  RankedTensorType aType,
                                                                  ArrayRef<int64_t> aBatchShape,
                                                                  RankedTensorType bType,
                                                                  ArrayRef<int64_t> bBatchShape,
                                                                  ArrayRef<int64_t> outputBatchShape,
                                                                  RankedTensorType partialPiecesType,
                                                                  int64_t numOutRows,
                                                                  int64_t numKSlices,
                                                                  int64_t numOutHSlices,
                                                                  PatternRewriter& rewriter,
                                                                  Location loc) {
  const int64_t laneCount = partialPiecesType.getDimSize(0);
  const int64_t slicesPerBatch = numKSlices * numOutHSlices;
  const int64_t tileSize = static_cast<int64_t>(crossbarSize.getValue());

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {partialPiecesType},
      laneCount,
      ValueRange {b},
      ValueRange {a},
      [&](detail::SpatComputeBatchBodyArgs args) {
        Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();

        // Decompose the lane index; rows vary fastest, then K slices, then
        // H slices, then batches.
        Value row = affineModConst(rewriter, loc, args.lane, numOutRows, anchorOp);
        Value laneWithoutRow = affineFloorDivConst(rewriter, loc, args.lane, numOutRows, anchorOp);
        Value batch = getBatchLaneIndex(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
        Value sliceInBatch = affineModConst(rewriter, loc, laneWithoutRow, slicesPerBatch, anchorOp);
        Value kSlice = affineModConst(rewriter, loc, sliceInBatch, numKSlices, anchorOp);
        Value hSlice = affineFloorDivConst(rewriter, loc, sliceInBatch, numKSlices, anchorOp);
        Value kOffset = affineMulConst(rewriter, loc, kSlice, tileSize, anchorOp);
        Value hOffset = affineMulConst(rewriter, loc, hSlice, tileSize, anchorOp);

        auto aTileType = RankedTensorType::get({1, tileSize}, aType.getElementType());
        auto bTileType = RankedTensorType::get({tileSize, tileSize}, bType.getElementType());
        auto pieceType = RankedTensorType::get({1, tileSize}, partialPiecesType.getElementType());

        Value aTile = extractBatchedATile(
            args.inputs.front(), aBatchShape, outputBatchShape, batch, row, kOffset, aTileType, rewriter, loc);
        Value bTile = extractBatchedBTile(
            args.weights.front(), bBatchShape, outputBatchShape, batch, kOffset, hOffset, bTileType, rewriter, loc);
        // Operand order is the weight tile first, then the input tile.
        Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();

        // Each lane writes exactly its own row of the partial-pieces output.
        SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileSize)};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2));
      });

  if (failed(batchOp))
    return failure();
  return *batchOp;
}
|
|
|
|
/// Extracts one column of a batched matrix as a tensor of `vectorType`.
///
/// With K the second extent of `vectorType`, a `1 x K x 1` slice is taken at
/// (sourceBatch, 0, column), collapsed to a one-dimensional tensor of K
/// elements and expanded again to `vectorType`. The source batch index is
/// derived from `outputBatchIndex` through the broadcast mapping.
static Value extractDynamicBatchedBColumn(Value matrix,
                                          ArrayRef<int64_t> sourceBatchShape,
                                          ArrayRef<int64_t> outputBatchShape,
                                          Value outputBatchIndex,
                                          Value column,
                                          RankedTensorType vectorType,
                                          PatternRewriter& rewriter,
                                          Location loc) {
  const int64_t vectorLength = vectorType.getDimSize(1);
  auto columnSliceType = RankedTensorType::get({1, vectorLength, 1}, vectorType.getElementType());

  Value sourceBatchIndex =
      mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);

  SmallVector<OpFoldResult> sliceOffsets {OpFoldResult(sourceBatchIndex), rewriter.getIndexAttr(0), column};
  SmallVector<OpFoldResult> sliceSizes {
      rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorLength), rewriter.getIndexAttr(1)};
  SmallVector<OpFoldResult> sliceStrides(3, rewriter.getIndexAttr(1));
  Value columnSlice =
      tensor::ExtractSliceOp::create(rewriter, loc, columnSliceType, matrix, sliceOffsets, sliceSizes, sliceStrides);

  // Flatten the slice completely, then re-expand it to the requested shape.
  auto flatType = RankedTensorType::get({vectorLength}, vectorType.getElementType());
  const SmallVector<ReassociationIndices> flattenAll {
      {0, 1, 2}
  };
  Value flatColumn = tensor::CollapseShapeOp::create(rewriter, loc, flatType, columnSlice, flattenAll).getResult();

  const SmallVector<ReassociationIndices> splitIntoVector {
      {0, 1}
  };
  return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, flatColumn, splitIntoVector).getResult();
}
|
|
|
|
/// Extracts one full row of a batched matrix as a tensor of `vectorType`.
///
/// With K the second extent of `vectorType`, a `1 x 1 x K` slice is taken at
/// (sourceBatch, row, 0) and collapsed to `vectorType` by merging the batch
/// and row dimensions. The source batch index is derived from
/// `outputBatchIndex` through the broadcast mapping.
static Value extractDynamicBatchedRowVector(Value matrix,
                                            ArrayRef<int64_t> sourceBatchShape,
                                            ArrayRef<int64_t> outputBatchShape,
                                            Value outputBatchIndex,
                                            Value row,
                                            RankedTensorType vectorType,
                                            PatternRewriter& rewriter,
                                            Location loc) {
  const int64_t vectorLength = vectorType.getDimSize(1);
  auto rowSliceType = RankedTensorType::get({1, 1, vectorLength}, vectorType.getElementType());

  Value sourceBatchIndex =
      mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);

  SmallVector<OpFoldResult> sliceOffsets {OpFoldResult(sourceBatchIndex), row, rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> sliceSizes {
      rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorLength)};
  auto rowSlice = tensor::ExtractSliceOp::create(
      rewriter, loc, rowSliceType, matrix, sliceOffsets, sliceSizes, getUnitStrides(rewriter, 3));

  // Merge the unit batch dimension into the (unit) row dimension.
  const SmallVector<ReassociationIndices> dropBatchDim {
      {0, 1},
      {2}
  };
  return tensor::CollapseShapeOp::create(rewriter, loc, vectorType, rowSlice, dropBatchDim).getResult();
}
|
|
|
|
/// Builds the batched compute op producing every output scalar of a batched
/// matrix multiplication whose two operands are both passed as inputs.
///
/// One lane is created per output element
/// (batches * rows * cols of `outType`). The body decomposes its lane index as
///   batch  = lane floordiv (rows * cols)
///   row    = (lane mod (rows * cols)) floordiv cols
///   column = (lane mod (rows * cols)) mod cols
/// multiplies row `row` of `a` by column `column` of `b` with a SpatVVDMulOp
/// and inserts the `1 x 1` result into row `lane` of the batch output. The
/// reduction length is the third extent of `aType`.
///
/// `bType` is not used by this function.
///
/// \returns the created op, or failure() if createSpatComputeBatch (defined
///          elsewhere) fails.
static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
                                                                     ArrayRef<int64_t> aBatchShape,
                                                                     Value b,
                                                                     ArrayRef<int64_t> bBatchShape,
                                                                     ArrayRef<int64_t> outputBatchShape,
                                                                     RankedTensorType aType,
                                                                     RankedTensorType bType,
                                                                     RankedTensorType scalarPiecesType,
                                                                     RankedTensorType outType,
                                                                     PatternRewriter& rewriter,
                                                                     Location loc) {
  const int64_t numBatches = outType.getDimSize(0);
  const int64_t numOutRows = outType.getDimSize(1);
  const int64_t numOutCols = outType.getDimSize(2);
  const int64_t reductionSize = aType.getDimSize(2);
  const int64_t lanesPerBatch = numOutRows * numOutCols;
  const int64_t laneCount = numBatches * numOutRows * numOutCols;

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {scalarPiecesType},
      laneCount,
      ValueRange {},
      ValueRange {a, b},
      [&](detail::SpatComputeBatchBodyArgs args) {
        Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();

        // Decompose the lane index; columns vary fastest, then rows, then batches.
        Value batch = affineFloorDivConst(rewriter, loc, args.lane, lanesPerBatch, anchorOp);
        Value laneInBatch = affineModConst(rewriter, loc, args.lane, lanesPerBatch, anchorOp);
        Value row = affineFloorDivConst(rewriter, loc, laneInBatch, numOutCols, anchorOp);
        Value column = affineModConst(rewriter, loc, laneInBatch, numOutCols, anchorOp);

        auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
        auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());

        Value lhsRow = extractDynamicBatchedRowVector(
            args.inputs[0], aBatchShape, outputBatchShape, batch, row, vectorType, rewriter, loc);
        Value rhsColumn = extractDynamicBatchedBColumn(
            args.inputs[1], bBatchShape, outputBatchShape, batch, column, vectorType, rewriter, loc);
        Value dotProduct = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, lhsRow, rhsColumn).getResult();

        // Each lane writes exactly its own row of the scalar-pieces output.
        SmallVector<OpFoldResult> scalarOffsets {args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, dotProduct, args.outputs.front(), scalarOffsets, scalarSizes, getUnitStrides(rewriter, 2));
      });

  if (failed(batchOp))
    return failure();
  return *batchOp;
}
|
|
|
|
/// Builds the compute op that scatters per-lane scalars into the
/// `batch x rows x cols` output tensor.
///
/// Inside the compute body an scf.for loop (built by buildNormalizedScfFor,
/// defined elsewhere) runs over every lane of `scalarPiecesType`. For each
/// lane the `1 x 1` scalar at row `lane` is extracted, expanded to
/// `1 x 1 x 1` and inserted at
///   (lane floordiv (rows * cols),
///    (lane mod (rows * cols)) floordiv cols,
///    (lane mod (rows * cols)) mod cols)
/// of the accumulated output, which starts as a tensor.empty of `outType`.
///
/// \returns result 0 of the compute op, or failure() if building the loop or
///          the compute op fails.
static FailureOr<Value> createBatchedDynamicOutputCompute(Value scalarPieces,
                                                          RankedTensorType scalarPiecesType,
                                                          RankedTensorType outType,
                                                          PatternRewriter& rewriter,
                                                          Location loc) {
  const int64_t laneCount = scalarPiecesType.getDimSize(0);
  const int64_t numOutRows = outType.getDimSize(1);
  const int64_t numOutCols = outType.getDimSize(2);
  const int64_t lanesPerBatch = numOutRows * numOutCols;
  auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
  auto outputScalarType = RankedTensorType::get({1, 1, 1}, outType.getElementType());

  auto computeOp = createSpatCompute<1>(
      rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) -> LogicalResult {
        Value emptyOutput =
            tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
        Value lowerBound = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
        Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
        Value upperBound = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);

        auto laneLoop = buildNormalizedScfFor(
            rewriter,
            loc,
            lowerBound,
            upperBound,
            step,
            ValueRange {emptyOutput},
            [&](OpBuilder&, Location bodyLoc, Value lane, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
              Value accumulated = iterArgs.front();
              Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();

              // Decompose the lane index; columns vary fastest, then rows, then batches.
              Value batch = affineFloorDivConst(rewriter, bodyLoc, lane, lanesPerBatch, anchorOp);
              Value laneInBatch = affineModConst(rewriter, bodyLoc, lane, lanesPerBatch, anchorOp);
              Value row = affineFloorDivConst(rewriter, bodyLoc, laneInBatch, numOutCols, anchorOp);
              Value column = affineModConst(rewriter, bodyLoc, laneInBatch, numOutCols, anchorOp);

              // Read this lane's 1 x 1 scalar.
              SmallVector<OpFoldResult> pieceOffsets {lane, rewriter.getIndexAttr(0)};
              SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
              Value scalar = tensor::ExtractSliceOp::create(
                  rewriter, bodyLoc, scalarType, pieces, pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2));

              // Give the scalar the rank of the output before inserting it.
              const SmallVector<ReassociationIndices> toRankThree {
                  {0},
                  {1, 2}
              };
              Value rankThreeScalar =
                  tensor::ExpandShapeOp::create(rewriter, bodyLoc, outputScalarType, scalar, toRankThree);

              SmallVector<OpFoldResult> outputOffsets {batch, row, column};
              SmallVector<OpFoldResult> outputSizes = {
                  rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
              Value updated = tensor::InsertSliceOp::create(rewriter,
                                                            bodyLoc,
                                                            rankThreeScalar,
                                                            accumulated,
                                                            outputOffsets,
                                                            outputSizes,
                                                            getUnitStrides(rewriter, 3))
                                  .getResult();
              yielded.push_back(updated);
              return success();
            });
        if (failed(laneLoop))
          return failure();

        spatial::SpatYieldOp::create(rewriter, loc, laneLoop->results.front());
        return success();
      });

  if (failed(computeOp))
    return failure();
  return computeOp->getResult(0);
}
|
|
|
|
/// Extracts the block of partial results belonging to one
/// (batch, hSlice, kSlice) triple from the flat partial-pieces tensor.
///
/// The first row of the block is
///   batch * (numOutRows * numKSlices * numOutHSlices)
///     + hSlice * (numKSlices * numOutRows)
///     + kSlice * numOutRows
/// and the block spans `numOutRows` rows and `crossbarSize` columns.
/// `batch` and `hSlice` are runtime values; `kSlice` is a compile-time index.
static Value extractBatchedReductionPiece(Value partialPiecesArg,
                                          Value batch,
                                          Value hSlice,
                                          int64_t kSlice,
                                          RankedTensorType pieceType,
                                          int64_t numKSlices,
                                          int64_t numOutHSlices,
                                          int64_t numOutRows,
                                          PatternRewriter& rewriter,
                                          Location loc) {
  const int64_t rowsPerBatch = numOutRows * numKSlices * numOutHSlices;
  const int64_t rowsPerHSlice = numKSlices * numOutRows;
  const int64_t rowsBeforeKSlice = kSlice * numOutRows;

  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  Value batchBase = affineMulConst(rewriter, loc, batch, rowsPerBatch, anchorOp);
  Value hSliceBase = affineMulConst(rewriter, loc, hSlice, rowsPerHSlice, anchorOp);
  Value kSliceBase =
      getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), rowsBeforeKSlice);

  Value batchPlusHSlice = arith::AddIOp::create(rewriter, loc, batchBase, hSliceBase);
  Value firstRow = arith::AddIOp::create(rewriter, loc, batchPlusHSlice, kSliceBase);

  SmallVector<OpFoldResult> sliceOffsets {firstRow, rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> sliceSizes {
      rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
  return tensor::ExtractSliceOp::create(
      rewriter, loc, pieceType, partialPiecesArg, sliceOffsets, sliceSizes, getUnitStrides(rewriter, 2));
}
|
|
|
|
/// Sums the `numKSlices` partial pieces of one (batch, hSlice) pair.
///
/// All pieces are extracted first, in increasing K-slice order. They are then
/// added pairwise with SpatVAddOp, level by level: neighbours (0,1), (2,3), …
/// are combined and an unpaired last piece is carried to the next level
/// unchanged, until a single value remains.
/// NOTE(review): `numKSlices` must be at least 1; an empty piece list is not
/// guarded against here.
static Value reduceBatchedPartialPiecesForHSlice(Value partialPiecesArg,
                                                 Value batch,
                                                 Value hSlice,
                                                 RankedTensorType pieceType,
                                                 int64_t numKSlices,
                                                 int64_t numOutHSlices,
                                                 int64_t numOutRows,
                                                 PatternRewriter& rewriter,
                                                 Location loc) {
  SmallVector<Value> pieces;
  pieces.reserve(numKSlices);
  for (int64_t kSlice = 0; kSlice < numKSlices; ++kSlice) {
    Value piece = extractBatchedReductionPiece(
        partialPiecesArg, batch, hSlice, kSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc);
    pieces.push_back(piece);
  }

  // Compact each reduction level into the front of `pieces`. The write cursor
  // never overtakes the read cursor, so no live value is overwritten.
  size_t liveCount = pieces.size();
  while (liveCount > 1) {
    size_t writeIndex = 0;
    for (size_t readIndex = 0; readIndex + 1 < liveCount; readIndex += 2) {
      Value pairSum =
          spatial::SpatVAddOp::create(rewriter, loc, pieceType, pieces[readIndex], pieces[readIndex + 1]).getResult();
      pieces[writeIndex++] = pairSum;
    }
    // An odd piece out is carried over to the next level as is.
    if (liveCount % 2 != 0)
      pieces[writeIndex++] = pieces[liveCount - 1];
    liveCount = writeIndex;
  }

  return pieces.front();
}
|
|
|
|
// Builds a single-input spatial compute that folds the partial pieces held in
// `partialPieces` into the final `outType` tensor.
//
// For every (batch, output column slice) pair the pieces are reduced with
// reduceBatchedPartialPiecesForHSlice and the reduced tile is inserted into a
// `paddedOutType` accumulator at column offset `hSlice * crossbarSize`. When
// `paddedOutType` differs from `outType`, the accumulator is cropped to
// `outType` before it is yielded.
//
// `outType` is indexed as [batch, rows, cols]; tiles take their element type
// from `partialPiecesType`. Returns failure when the compute or either loop
// cannot be built, otherwise the single result of the compute.
static FailureOr<Value> createBatchedReductionCompute(Value partialPieces,
                                                      RankedTensorType partialPiecesType,
                                                      RankedTensorType outType,
                                                      RankedTensorType paddedOutType,
                                                      int64_t numBatches,
                                                      int64_t numKSlices,
                                                      PatternRewriter& rewriter,
                                                      Location loc) {
  auto computeOp = createSpatCompute<1>(
      rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value piecesArg) -> LogicalResult {
        const int64_t crossbarWidth = static_cast<int64_t>(crossbarSize.getValue());
        const int64_t rowCount = outType.getDimSize(1);
        const int64_t hSliceCount = ceilIntegerDivide(outType.getDimSize(2), crossbarSize.getValue());
        auto elementType = partialPiecesType.getElementType();
        // One reduced tile, and the same tile with a leading unit batch axis.
        auto tileType = RankedTensorType::get({rowCount, crossbarWidth}, elementType);
        auto batchedTileType = RankedTensorType::get({1, rowCount, crossbarWidth}, elementType);

        Value accumulatorInit =
            tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult();

        // Index constants are anchored on the op that owns the current insertion block.
        auto indexConstant = [&](int64_t value) -> Value {
          return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), value);
        };
        Value zero = indexConstant(0);
        Value one = indexConstant(1);
        Value batchUpperBound = indexConstant(numBatches);
        Value hSliceUpperBound = indexConstant(hSliceCount);

        auto batchLoop = buildNormalizedScfFor(
            rewriter,
            loc,
            zero,
            batchUpperBound,
            one,
            ValueRange {accumulatorInit},
            [&](OpBuilder&, Location batchLoc, Value batch, ValueRange batchCarried, SmallVectorImpl<Value>& batchOut) {
              auto hLoop = buildNormalizedScfFor(
                  rewriter,
                  batchLoc,
                  zero,
                  hSliceUpperBound,
                  one,
                  ValueRange {batchCarried.front()},
                  [&](OpBuilder&, Location hLoc, Value hSlice, ValueRange hCarried, SmallVectorImpl<Value>& hOut) {
                    Value accumulator = hCarried.front();
                    Value reducedTile = reduceBatchedPartialPiecesForHSlice(
                        piecesArg, batch, hSlice, tileType, numKSlices, hSliceCount, rowCount, rewriter, hLoc);
                    // [rows, width] -> [1, rows, width] so the tile can be inserted as one batch entry.
                    Value batchedTile = tensor::ExpandShapeOp::create(rewriter,
                                                                      hLoc,
                                                                      batchedTileType,
                                                                      reducedTile,
                                                                      SmallVector<ReassociationIndices> {
                                                                        {0, 1},
                                                                        {2}
                    });
                    Value columnOffset = affineMulConst(
                        rewriter, hLoc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
                    SmallVector<OpFoldResult> insertOffsets {batch, rewriter.getIndexAttr(0), columnOffset};
                    SmallVector<OpFoldResult> insertSizes {rewriter.getIndexAttr(1),
                                                           rewriter.getIndexAttr(rowCount),
                                                           rewriter.getIndexAttr(crossbarWidth)};
                    hOut.push_back(
                        tensor::InsertSliceOp::create(
                            rewriter, hLoc, batchedTile, accumulator, insertOffsets, insertSizes, getUnitStrides(rewriter, 3))
                            .getResult());
                    return success();
                  });
              if (failed(hLoop))
                return failure();
              batchOut.push_back(hLoop->results.front());
              return success();
            });
        if (failed(batchLoop))
          return failure();

        Value finalOutput = batchLoop->results.front();
        if (paddedOutType != outType) {
          // Crop the padded accumulator back to the requested result type.
          SmallVector<OpFoldResult> cropOffsets {
              rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
          SmallVector<OpFoldResult> cropSizes {rewriter.getIndexAttr(numBatches),
                                               rewriter.getIndexAttr(outType.getDimSize(1)),
                                               rewriter.getIndexAttr(outType.getDimSize(2))};
          finalOutput = tensor::ExtractSliceOp::create(
              rewriter, loc, outType, finalOutput, cropOffsets, cropSizes, getUnitStrides(rewriter, 3));
        }
        spatial::SpatYieldOp::create(rewriter, loc, finalOutput);
        return success();
      });
  if (failed(computeOp))
    return failure();
  return computeOp->getResult(0);
}
|
|
|
|
struct NormalizedMatMulInfo {
|
|
RankedTensorType lhsType;
|
|
RankedTensorType rhsType;
|
|
RankedTensorType outType;
|
|
RankedTensorType normalizedLhsType;
|
|
RankedTensorType normalizedRhsType;
|
|
SmallVector<int64_t> lhsBatchShape;
|
|
SmallVector<int64_t> rhsBatchShape;
|
|
SmallVector<int64_t> outputBatchShape;
|
|
bool lhsWasVector;
|
|
bool rhsWasVector;
|
|
int64_t lhsBatch;
|
|
int64_t rhsBatch;
|
|
int64_t batch;
|
|
int64_t m;
|
|
int64_t k;
|
|
int64_t n;
|
|
};
|
|
|
|
struct MatMulLoweringPlan {
|
|
Value lhs;
|
|
Value rhs;
|
|
RankedTensorType lhsType;
|
|
RankedTensorType rhsType;
|
|
SmallVector<int64_t> lhsBatchShape;
|
|
SmallVector<int64_t> rhsBatchShape;
|
|
SmallVector<int64_t> outputBatchShape;
|
|
int64_t lhsBatch;
|
|
int64_t rhsBatch;
|
|
int64_t batch;
|
|
int64_t m;
|
|
int64_t k;
|
|
int64_t n;
|
|
bool transposedResult;
|
|
};
|
|
|
|
// Returns the result shape expected for a MatMul: `batchShape` followed by `m`
// unless the LHS was a vector, followed by `n` unless the RHS was a vector.
// When both operands were vectors the result is `batchShape` alone.
static SmallVector<int64_t> computeExpectedMatMulOutputShape(
    ArrayRef<int64_t> batchShape, int64_t m, int64_t n, bool lhsWasVector, bool rhsWasVector) {
  SmallVector<int64_t> expectedShape(batchShape.begin(), batchShape.end());
  // Each matrix axis survives only when its operand was a genuine matrix.
  if (!lhsWasVector)
    expectedShape.push_back(m);
  if (!rhsWasVector)
    expectedShape.push_back(n);
  return expectedShape;
}
|
|
|
|
// Validates the operand and result shapes of `matmulOp` and gathers the
// normalized shape facts used by the MatMul lowerings.
//
// Fails when any of A, B, Y is not a ranked tensor with a static, positive
// shape, when an operand has rank 0, when the batch shapes are rejected by
// inferSupportedBatchShape, when the reduction dimensions disagree, or when the
// declared result shape is not the one computeExpectedMatMulOutputShape
// predicts. Rank-1 operands are treated as [1, K] (LHS) and [K, 1] (RHS).
static FailureOr<NormalizedMatMulInfo> analyzeMatMulShape(ONNXMatMulOp matmulOp) {
  auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
  auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
  auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
  if (!lhsType || !rhsType || !outType)
    return failure();
  if (!(lhsType.hasStaticShape() && rhsType.hasStaticShape() && outType.hasStaticShape()))
    return failure();
  if (lhsType.getRank() < 1 || rhsType.getRank() < 1)
    return failure();
  if (!(hasStaticPositiveShape(lhsType) && hasStaticPositiveShape(rhsType) && hasStaticPositiveShape(outType)))
    return failure();

  // Promote vectors to matrices so both operands have at least rank 2.
  const bool lhsWasVector = lhsType.getRank() == 1;
  const bool rhsWasVector = rhsType.getRank() == 1;
  RankedTensorType normalizedLhsType = lhsType;
  if (lhsWasVector)
    normalizedLhsType =
        RankedTensorType::get({1, lhsType.getDimSize(0)}, lhsType.getElementType(), lhsType.getEncoding());
  RankedTensorType normalizedRhsType = rhsType;
  if (rhsWasVector)
    normalizedRhsType =
        RankedTensorType::get({rhsType.getDimSize(0), 1}, rhsType.getElementType(), rhsType.getEncoding());

  // Everything in front of the trailing two matrix axes is batch.
  ArrayRef<int64_t> lhsDims = normalizedLhsType.getShape();
  ArrayRef<int64_t> rhsDims = normalizedRhsType.getShape();
  SmallVector<int64_t> lhsBatchShape(lhsDims.begin(), lhsDims.end() - 2);
  SmallVector<int64_t> rhsBatchShape(rhsDims.begin(), rhsDims.end() - 2);
  auto outputBatchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
  if (failed(outputBatchShape))
    return failure();

  // An empty batch shape stands for exactly one matrix.
  auto flatBatchCount = [](ArrayRef<int64_t> batchShape) -> int64_t {
    return batchShape.empty() ? 1 : getStaticShapeElementCount(batchShape);
  };
  const int64_t lhsBatch = flatBatchCount(lhsBatchShape);
  const int64_t rhsBatch = flatBatchCount(rhsBatchShape);
  const int64_t batch = flatBatchCount(*outputBatchShape);

  const int64_t m = lhsDims[lhsDims.size() - 2];
  const int64_t k = lhsDims.back();
  const int64_t rhsK = rhsDims[rhsDims.size() - 2];
  const int64_t n = rhsDims.back();
  if (k != rhsK)
    return failure();

  const SmallVector<int64_t> expectedOutShape =
      computeExpectedMatMulOutputShape(*outputBatchShape, m, n, lhsWasVector, rhsWasVector);
  if (outType.getShape() != ArrayRef<int64_t>(expectedOutShape))
    return failure();

  return NormalizedMatMulInfo {lhsType,
                               rhsType,
                               outType,
                               normalizedLhsType,
                               normalizedRhsType,
                               lhsBatchShape,
                               rhsBatchShape,
                               *outputBatchShape,
                               lhsWasVector,
                               rhsWasVector,
                               lhsBatch,
                               rhsBatch,
                               batch,
                               m,
                               k,
                               n};
}
|
|
|
|
// Assembles the lowering plan for an already normalized operand pair.
//
// In the direct form the plan mirrors `info`. With `useTransposedForm` the plan
// describes the product transpose(rhs) x transpose(lhs) instead: the operands
// are transposed on their last two axes and exchanged, the per-operand batch
// data is exchanged, m and n trade places, and `transposedResult` is set so the
// caller knows to transpose the result back.
static MatMulLoweringPlan buildLoweringPlan(Value normalizedLhs,
                                            Value normalizedRhs,
                                            const NormalizedMatMulInfo& info,
                                            bool useTransposedForm,
                                            PatternRewriter& rewriter,
                                            Location loc) {
  MatMulLoweringPlan plan {normalizedLhs,
                           normalizedRhs,
                           cast<RankedTensorType>(normalizedLhs.getType()),
                           cast<RankedTensorType>(normalizedRhs.getType()),
                           info.lhsBatchShape,
                           info.rhsBatchShape,
                           info.outputBatchShape,
                           info.lhsBatch,
                           info.rhsBatch,
                           info.batch,
                           info.m,
                           info.k,
                           info.n,
                           /*transposedResult=*/false};

  if (useTransposedForm) {
    // The original RHS becomes the new LHS; it is transposed first.
    plan.lhs = transposeLastTwoDims(normalizedRhs, rewriter, loc);
    plan.rhs = transposeLastTwoDims(normalizedLhs, rewriter, loc);
    plan.lhsType = cast<RankedTensorType>(plan.lhs.getType());
    plan.rhsType = cast<RankedTensorType>(plan.rhs.getType());
    plan.lhsBatchShape = info.rhsBatchShape;
    plan.rhsBatchShape = info.lhsBatchShape;
    plan.lhsBatch = info.rhsBatch;
    plan.rhsBatch = info.lhsBatch;
    plan.m = info.n;
    plan.n = info.m;
    plan.transposedResult = true;
  }
  return plan;
}
|
|
|
|
// Returns `value` unchanged for matrix operands; for operands that were
// vectors, returns createMatrixFromVector's result typed as `normalizedType`.
static Value normalizeMatMulOperand(
    Value value, RankedTensorType normalizedType, bool wasVector, PatternRewriter& rewriter, Location loc) {
  if (wasVector)
    return createMatrixFromVector(value, normalizedType, rewriter, loc);
  return value;
}
|
|
|
|
// Converts the lowered [flatBatch, normalizedM, normalizedN] result back to the
// rank and shape of the original ONNX MatMul result (`info.outType`).
//
// A multi-dimensional batch shape is first re-expanded in front of the two
// matrix axes. Then the synthetic unit axes are squeezed away: the batch axis
// when there was no batch at all, the M axis when the LHS was a vector, and the
// N axis when the RHS was a vector.
static Value finalizeNormalizedMatMulResult(Value value,
                                            RankedTensorType directOutType,
                                            const NormalizedMatMulInfo& info,
                                            PatternRewriter& rewriter,
                                            Location loc) {
  Value restored = value;
  int64_t restoredRank = directOutType.getRank();

  if (info.outputBatchShape.size() > 1) {
    SmallVector<int64_t> fullShape(info.outputBatchShape.begin(), info.outputBatchShape.end());
    fullShape.push_back(info.m);
    fullShape.push_back(info.n);
    auto fullType = RankedTensorType::get(fullShape, info.outType.getElementType(), info.outType.getEncoding());
    restored = expandBatchDims(restored, fullType, info.outputBatchShape.size(), rewriter, loc);
    restoredRank = fullType.getRank();
  }

  // Mark every axis that exists only because of normalization.
  SmallVector<bool> removedAxes(restoredRank, false);
  if (info.outputBatchShape.empty())
    removedAxes[0] = true;
  if (info.lhsWasVector)
    removedAxes[restoredRank - 2] = true;
  if (info.rhsWasVector)
    removedAxes[restoredRank - 1] = true;
  return squeezeUnitDims(restored, info.outType, removedAxes, rewriter, loc);
}
|
|
|
|
// Rewrites an unbatched matrix-by-matrix ONNX MatMul into an ONNX Gemm with
// alpha = beta = 1, no bias and no operand transposition.
//
// The pattern does not apply when either operand is a vector or when the
// product has batch dimensions. When A is compile-time computable and B is not,
// the Gemm is built for transpose(B) x transpose(A) and its result is
// transposed back to the MatMul result type.
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
    auto shapeInfo = analyzeMatMulShape(matmulOp);
    if (failed(shapeInfo))
      return failure();
    const NormalizedMatMulInfo& info = *shapeInfo;
    if (info.lhsWasVector || info.rhsWasVector || !info.outputBatchShape.empty())
      return failure();

    Location loc = matmulOp.getLoc();
    const bool useTransposedForm =
        isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());

    Value gemmLhs = collapseBatchDims(matmulOp.getA(), info.lhsBatch, info.m, info.k, rewriter, loc);
    Value gemmRhs = collapseBatchDims(matmulOp.getB(), info.rhsBatch, info.k, info.n, rewriter, loc);
    int64_t gemmLhsBatch = info.lhsBatch;
    int64_t gemmRhsBatch = info.rhsBatch;
    int64_t gemmM = info.m;
    const int64_t gemmK = info.k;
    int64_t gemmN = info.n;
    if (useTransposedForm) {
      // NOTE(review): the collapsed operands computed above are discarded here
      // and the transposes are taken from the original MatMul operands.
      gemmLhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
      gemmRhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
      std::swap(gemmLhsBatch, gemmRhsBatch);
      std::swap(gemmM, gemmN);
    }

    auto gemmType = RankedTensorType::get({gemmM, gemmN}, info.outType.getElementType());
    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
    Value lhsMatrix = extractBatchMatrix(gemmLhs, /*batchIndex=*/0, gemmLhsBatch, gemmM, gemmK, rewriter, loc);
    Value rhsMatrix = extractBatchMatrix(gemmRhs, /*batchIndex=*/0, gemmRhsBatch, gemmK, gemmN, rewriter, loc);
    Value replacement = ONNXGemmOp::create(rewriter,
                                           loc,
                                           gemmType,
                                           lhsMatrix,
                                           rhsMatrix,
                                           none,
                                           rewriter.getF32FloatAttr(1.0f),
                                           rewriter.getF32FloatAttr(1.0f),
                                           rewriter.getBoolAttr(false),
                                           rewriter.getBoolAttr(false))
                            .getY();
    if (useTransposedForm) {
      // Undo the operand exchange: the Gemm produced the transposed product.
      replacement =
          ONNXTransposeOp::create(rewriter, loc, info.outType, replacement, rewriter.getI64ArrayAttr({1, 0}))
              .getResult();
    }
    rewriter.replaceOp(matmulOp, replacement);
    return success();
  }
};
|
|
|
|
// Lowers every ONNX MatMul that MatMulToGemm leaves alone (a vector operand
// and/or batch dimensions) to batched spatial computes.
//
// Operands are normalized to matrices, their batch dimensions are collapsed,
// and a lowering plan is built (optionally in transposed form, see
// buildLoweringPlan). Two lowerings exist:
//   * when the planned RHS is compile-time computable and can be materialized
//     as a padded batched weight, the product is split into crossbar-sized
//     pieces with createBatchedVmmBatch and summed by
//     createBatchedReductionCompute;
//   * otherwise the product goes through createBatchedVvdmulBatch and
//     createBatchedDynamicOutputCompute.
// Either way the lowered [batch, m, n] value is transposed back if the plan was
// transposed and then reshaped to the original MatMul result type.
struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
    auto shapeInfo = analyzeMatMulShape(matmulOp);
    if (failed(shapeInfo))
      return failure();
    const NormalizedMatMulInfo& info = *shapeInfo;
    const bool hasVectorOperand = info.lhsWasVector || info.rhsWasVector;
    // Unbatched matrix-by-matrix products are not handled by this pattern.
    if (!hasVectorOperand && info.outputBatchShape.empty())
      return failure();

    Location loc = matmulOp.getLoc();
    const bool useTransposedForm = !hasVectorOperand && isCompileTimeComputable(matmulOp.getA())
                                && !isCompileTimeComputable(matmulOp.getB());

    Value lhs = normalizeMatMulOperand(matmulOp.getA(), info.normalizedLhsType, info.lhsWasVector, rewriter, loc);
    Value rhs = normalizeMatMulOperand(matmulOp.getB(), info.normalizedRhsType, info.rhsWasVector, rewriter, loc);
    lhs = collapseBatchDims(lhs, info.lhsBatch, info.m, info.k, rewriter, loc);
    rhs = collapseBatchDims(rhs, info.rhsBatch, info.k, info.n, rewriter, loc);
    MatMulLoweringPlan plan = buildLoweringPlan(lhs, rhs, info, useTransposedForm, rewriter, loc);

    plan.lhs = ensureBatchedTensor(plan.lhs, plan.lhsBatch, plan.m, plan.k, rewriter, loc);
    plan.rhs = ensureBatchedTensor(plan.rhs, plan.rhsBatch, plan.k, plan.n, rewriter, loc);
    plan.lhsType = cast<RankedTensorType>(plan.lhs.getType());
    plan.rhsType = cast<RankedTensorType>(plan.rhs.getType());

    auto outElementType = info.outType.getElementType();
    auto outEncoding = info.outType.getEncoding();
    auto directOutType = RankedTensorType::get({plan.batch, plan.m, plan.n}, outElementType, outEncoding);

    // Shared tail of both lowerings: undo the plan transposition, restore the
    // ONNX result shape and replace the MatMul.
    auto replaceWithLoweredResult = [&](Value lowered) -> LogicalResult {
      if (plan.transposedResult) {
        auto transposedOutType = RankedTensorType::get({plan.batch, info.m, info.n}, outElementType, outEncoding);
        lowered =
            ONNXTransposeOp::create(rewriter, loc, transposedOutType, lowered, rewriter.getI64ArrayAttr({0, 2, 1}))
                .getResult();
      }
      Value finalResult = finalizeNormalizedMatMulResult(lowered, directOutType, info, rewriter, loc);
      rewriter.replaceOp(matmulOp, finalResult);
      return success();
    };

    if (isCompileTimeComputable(plan.rhs)) {
      // Pad the reduction and output-column extents up to whole crossbars.
      const int64_t crossbarWidth = static_cast<int64_t>(crossbarSize.getValue());
      const int64_t numKSlices = ceilIntegerDivide(plan.k, crossbarSize.getValue());
      const int64_t numOutHSlices = ceilIntegerDivide(plan.n, crossbarSize.getValue());
      const int64_t paddedReductionSize = numKSlices * crossbarWidth;
      const int64_t paddedOutCols = numOutHSlices * crossbarWidth;
      auto paddedLhsType = RankedTensorType::get(
          {plan.lhsBatch, plan.m, paddedReductionSize}, plan.lhsType.getElementType(), plan.lhsType.getEncoding());
      auto paddedRhsType = RankedTensorType::get(
          {plan.batch, paddedReductionSize, paddedOutCols}, plan.rhsType.getElementType(), plan.rhsType.getEncoding());
      auto paddedOutType = RankedTensorType::get({plan.batch, plan.m, paddedOutCols}, outElementType);

      auto paddedRhs =
          materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter);
      // If the weight cannot be materialized, fall through to the generic lowering below.
      if (succeeded(paddedRhs)) {
        Value paddedLhs = createPaddedBatchedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
        const int64_t pieceCount = plan.batch * plan.m * numKSlices * numOutHSlices;
        auto partialPiecesType = RankedTensorType::get({pieceCount, crossbarWidth}, outElementType);
        auto vmmBatch = createBatchedVmmBatch(paddedLhs,
                                              *paddedRhs,
                                              paddedLhsType,
                                              plan.lhsBatchShape,
                                              paddedRhsType,
                                              plan.rhsBatchShape,
                                              plan.outputBatchShape,
                                              partialPiecesType,
                                              plan.m,
                                              numKSlices,
                                              numOutHSlices,
                                              rewriter,
                                              loc);
        if (failed(vmmBatch))
          return failure();
        auto reduced = createBatchedReductionCompute(vmmBatch->getResult(0),
                                                     partialPiecesType,
                                                     directOutType,
                                                     paddedOutType,
                                                     plan.batch,
                                                     numKSlices,
                                                     rewriter,
                                                     loc);
        if (failed(reduced))
          return failure();
        return replaceWithLoweredResult(*reduced);
      }
    }

    // Generic lowering: one scalar piece per output element.
    const int64_t scalarCount = plan.batch * plan.m * plan.n;
    auto scalarPiecesType = RankedTensorType::get({scalarCount, 1}, outElementType);
    auto vvdmulBatch = createBatchedVvdmulBatch(plan.lhs,
                                                plan.lhsBatchShape,
                                                plan.rhs,
                                                plan.rhsBatchShape,
                                                plan.outputBatchShape,
                                                plan.lhsType,
                                                plan.rhsType,
                                                scalarPiecesType,
                                                directOutType,
                                                rewriter,
                                                loc);
    if (failed(vvdmulBatch))
      return failure();
    auto assembled =
        createBatchedDynamicOutputCompute(vvdmulBatch->getResult(0), scalarPiecesType, directOutType, rewriter, loc);
    if (failed(assembled))
      return failure();
    return replaceWithLoweredResult(*assembled);
  }
};
|
|
|
|
} // namespace
|
|
|
|
// Registers the MatMul lowerings on `patterns`: the Gemm rewrite first,
// followed by the batched spatial-compute rewrite.
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<MatMulToGemm>(ctx);
  patterns.add<MatMulBatchedToSpatialComputes>(ctx);
}
|
|
|
|
} // namespace onnx_mlir
|