Add conv helpers and new validation tests for matmul
This commit is contained in:
@@ -51,23 +51,107 @@ static Value createPaddedRows(Value tensorValue,
|
||||
if (tensorType.getDimSize(0) == paddedRows)
|
||||
return tensorValue;
|
||||
|
||||
auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType());
|
||||
auto paddedType =
|
||||
RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
|
||||
rewriter.getIndexAttr(0)};
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int i = 0; i < 2; i++)
|
||||
for (int i = 0; i < 2; ++i)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()),
|
||||
auto zero = getOrCreateConstant(rewriter,
|
||||
padOp.getOperation(),
|
||||
rewriter.getZeroAttr(tensorType.getElementType()),
|
||||
tensorType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
/// Packs groups of `packFactor` consecutive rows of a 2-D tensor into single,
/// wider rows.
///
/// The input is first passed through createPaddedRows so that its row count
/// becomes `ceil(rows / packFactor) * packFactor`, and is then reshaped
///   [padded, width] -> [padded / packFactor, packFactor, width]
///                   -> [padded / packFactor, packFactor * width].
/// The element type and encoding of `rowsType` are reused for both
/// intermediate types. When `packFactor` is 1 the input is returned untouched.
static Value packRowsForParallelGemm(Value rows,
                                     RankedTensorType rowsType,
                                     int64_t packFactor,
                                     ConversionPatternRewriter& rewriter,
                                     Location loc) {
  if (packFactor == 1)
    return rows;

  auto elementType = rowsType.getElementType();
  auto encoding = rowsType.getEncoding();
  const int64_t width = rowsType.getDimSize(1);
  const int64_t numGroups = ceilIntegerDivide(rowsType.getDimSize(0), packFactor);
  const int64_t numPaddedRows = numGroups * packFactor;

  Value padded = createPaddedRows(rows, rowsType, numPaddedRows, rewriter, loc);

  // Split the row dimension into (group, position inside the group).
  auto groupedType = RankedTensorType::get({numGroups, packFactor, width}, elementType, encoding);
  const SmallVector<ReassociationIndices> splitRows {
    {0, 1},
    {2}
  };
  Value grouped = tensor::ExpandShapeOp::create(rewriter, loc, groupedType, padded, splitRows);

  // Fold the position inside the group into the column dimension.
  auto packedType = RankedTensorType::get({numGroups, packFactor * width}, elementType, encoding);
  const SmallVector<ReassociationIndices> mergeColumns {
    {0},
    {1, 2}
  };
  return tensor::CollapseShapeOp::create(rewriter, loc, packedType, grouped, mergeColumns);
}
|
||||
|
||||
/// Reverses the row packing applied by packRowsForParallelGemm.
///
/// A tensor of shape [groups, packFactor * rowWidth] is reshaped to
/// [groups, packFactor, rowWidth] and then to [groups * packFactor, rowWidth].
/// If that row count differs from `unpackedRows`, the leading `unpackedRows`
/// rows are extracted with a unit-stride slice starting at the origin.
/// The element type and encoding of `packedRowsType` are reused throughout.
/// When `packFactor` is 1 the input is returned untouched.
static Value unpackRowsFromParallelGemm(Value packedRows,
                                        RankedTensorType packedRowsType,
                                        int64_t unpackedRows,
                                        int64_t rowWidth,
                                        int64_t packFactor,
                                        ConversionPatternRewriter& rewriter,
                                        Location loc) {
  if (packFactor == 1)
    return packedRows;

  auto elementType = packedRowsType.getElementType();
  auto encoding = packedRowsType.getEncoding();
  const int64_t numGroups = packedRowsType.getDimSize(0);
  const int64_t numPaddedRows = numGroups * packFactor;

  // Recover the (group, position inside the group) structure from the columns.
  auto groupedType = RankedTensorType::get({numGroups, packFactor, rowWidth}, elementType, encoding);
  const SmallVector<ReassociationIndices> splitColumns {
    {0},
    {1, 2}
  };
  Value grouped = tensor::ExpandShapeOp::create(rewriter, loc, groupedType, packedRows, splitColumns);

  // Flatten group and position back into a single row dimension.
  auto flatType = RankedTensorType::get({numPaddedRows, rowWidth}, elementType, encoding);
  const SmallVector<ReassociationIndices> mergeRows {
    {0, 1},
    {2}
  };
  Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, grouped, mergeRows);

  if (numPaddedRows == unpackedRows)
    return flat;

  // Drop the trailing rows that exceed the requested row count.
  auto trimmedType = RankedTensorType::get({unpackedRows, rowWidth}, elementType, encoding);
  SmallVector<OpFoldResult> sliceOffsets(2, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sliceSizes {rewriter.getIndexAttr(unpackedRows), rewriter.getIndexAttr(rowWidth)};
  SmallVector<OpFoldResult> sliceStrides(2, rewriter.getIndexAttr(1));
  return tensor::ExtractSliceOp::create(rewriter, loc, trimmedType, flat, sliceOffsets, sliceSizes, sliceStrides);
}
|
||||
|
||||
static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
||||
Value wTrans,
|
||||
RankedTensorType wType,
|
||||
@@ -189,7 +273,6 @@ static Value createIm2colRowComputes(Value x,
|
||||
Location loc) {
|
||||
auto elemType = xType.getElementType();
|
||||
constexpr size_t numInputs = 1;
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||
auto im2colComputeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
|
||||
Value paddedInput = xArg;
|
||||
@@ -278,26 +361,7 @@ static Value createIm2colRowComputes(Value x,
|
||||
|
||||
Value gemmInputRows = im2col;
|
||||
if (packFactor != 1) {
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
groupedType,
|
||||
paddedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
packedType,
|
||||
groupedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
|
||||
}
|
||||
|
||||
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
||||
@@ -316,41 +380,15 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
||||
Value gemmOut;
|
||||
if (packFactor == 1) {
|
||||
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||
}
|
||||
else {
|
||||
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
packedOutput,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
paddedType,
|
||||
expandedOutput,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
|
||||
gemmOut = paddedOutput;
|
||||
if (paddedNumPatches != numPatches) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
|
||||
}
|
||||
gemmOut = unpackRowsFromParallelGemm(
|
||||
packedOutput, cast<RankedTensorType>(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc);
|
||||
}
|
||||
|
||||
// Restore to NCHW layout:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
@@ -5,9 +7,6 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
@@ -66,6 +65,26 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt
|
||||
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
|
||||
}
|
||||
|
||||
/// Returns `value` as a rank-3 tensor of shape [batchSize, rows, cols].
///
/// A value that is already rank 3 is returned as is. Otherwise the leading
/// dimension is expanded into two dimensions (reassociation {0, 1}, {2});
/// the expansion is routed through materializeOrComputeUnary. The element
/// type and encoding of the input are kept.
static Value ensureBatchedTensor(
    Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
  auto inputType = cast<RankedTensorType>(value.getType());
  const bool alreadyBatched = inputType.getRank() == 3;
  if (alreadyBatched)
    return value;

  auto batchedType =
      RankedTensorType::get({batchSize, rows, cols}, inputType.getElementType(), inputType.getEncoding());
  auto expandToBatched = [&](Value operand) -> Value {
    const SmallVector<ReassociationIndices> splitLeadingDim {
      {0, 1},
      {2}
    };
    return tensor::ExpandShapeOp::create(rewriter, loc, batchedType, operand, splitLeadingDim);
  };
  return materializeOrComputeUnary(value, batchedType, rewriter, loc, expandToBatched);
}
|
||||
|
||||
static Value extractBatchMatrix(Value value,
|
||||
int64_t batchIndex,
|
||||
int64_t batchSize,
|
||||
@@ -130,164 +149,737 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
/// Pads `value` with zeros at the high end of every dimension so that the
/// result has type `resultType`.
///
/// Low padding is 0 everywhere; the high padding of each dimension is
/// `resultDim - sourceDim`. The padding value is the zero attribute of the
/// source element type, obtained through getOrCreateConstant. On return the
/// rewriter's insertion point is placed right after the created tensor.pad.
///
/// `resultType` must have the same rank as `value`; this is asserted, because
/// a silent mismatch would otherwise produce too few padding entries.
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
  auto sourceType = cast<RankedTensorType>(value.getType());
  const int64_t rank = sourceType.getRank();
  assert(rank == resultType.getRank() && "zero padding must preserve the tensor rank");

  SmallVector<OpFoldResult> lowPads(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highPads;
  highPads.reserve(rank);
  for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
    highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));

  auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);

  // Build the pad body through the rewriter (one index argument per
  // dimension) instead of a raw `new Block()`, so the block insertion is
  // visible to the rewriter. createBlock leaves the insertion point inside
  // the new, still empty, block.
  SmallVector<Type> indexTypes(rank, rewriter.getIndexType());
  SmallVector<Location> indexLocs(rank, loc);
  rewriter.createBlock(&padOp.getRegion(), padOp.getRegion().end(), indexTypes, indexLocs);

  auto zero = getOrCreateConstant(
      rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
  tensor::YieldOp::create(rewriter, loc, zero);

  rewriter.setInsertionPointAfter(padOp);
  return padOp.getResult();
}
|
||||
|
||||
/// Zero-pads `input` to `paddedInputType` inside a dedicated spatial compute.
///
/// If `input` already has exactly `paddedInputType`, it is returned as is.
/// Otherwise a single-input compute op is created whose body pads its
/// argument with createZeroPaddedTensor and yields the padded tensor; the
/// first result of that compute op is returned.
static Value createPaddedBatchedInputCompute(Value input,
                                             RankedTensorType paddedInputType,
                                             PatternRewriter& rewriter,
                                             Location loc) {
  const bool needsPadding = cast<RankedTensorType>(input.getType()) != paddedInputType;
  if (!needsPadding)
    return input;

  auto paddingCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value unpadded) {
        Value padded = createZeroPaddedTensor(unpadded, paddedInputType, rewriter, loc);
        spatial::SpatYieldOp::create(rewriter, loc, padded);
      });
  return paddingCompute.getResult(0);
}
|
||||
|
||||
/// Builds a constant of type `resultType` ([batch, rows, cols]) from a
/// host-constant weight, zero-padding rows/columns and broadcasting the batch.
///
/// `value` is returned unchanged when it already has `resultType`. Otherwise
/// its dense elements are read through getHostConstDenseElementsAttr and
/// copied into the top-left corner of every one of the first `targetBatch`
/// result matrices; all remaining elements are zero. A rank-2 source, or a
/// rank-3 source with `sourceBatch == 1`, is replicated across all batches;
/// any other source supplies batch `i` of the result from its own batch `i`.
///
/// Returns failure() when the value is not a host constant, or when the
/// shapes cannot be handled without indexing out of range: unsupported rank,
/// non-static result shape, a source matrix larger than the target matrix,
/// or a `targetBatch` exceeding the result or (non-broadcast) source batch
/// dimension.
static FailureOr<Value> materializePaddedBatchedWeight(
    Value value, int64_t sourceBatch, int64_t targetBatch, RankedTensorType resultType, PatternRewriter& rewriter) {
  auto sourceType = cast<RankedTensorType>(value.getType());
  if (sourceType == resultType)
    return value;

  auto denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr)
    return failure();

  const int64_t sourceRank = sourceType.getRank();
  if ((sourceRank != 2 && sourceRank != 3) || resultType.getRank() != 3 || !resultType.hasStaticShape())
    return failure();

  const bool sourceIsMatrix = sourceRank == 2;
  const bool broadcastBatch = sourceIsMatrix || sourceBatch == 1;
  const int64_t sourceRows = sourceType.getDimSize(sourceRank - 2);
  const int64_t sourceCols = sourceType.getDimSize(sourceRank - 1);
  const int64_t targetRows = resultType.getDimSize(1);
  const int64_t targetCols = resultType.getDimSize(2);

  // Reject shapes that would read past the source or write past the result.
  if (sourceRows > targetRows || sourceCols > targetCols)
    return failure();
  if (targetBatch > resultType.getDimSize(0))
    return failure();
  if (!broadcastBatch && targetBatch > sourceType.getDimSize(0))
    return failure();

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> resultValues(resultType.getNumElements(), rewriter.getZeroAttr(resultType.getElementType()));

  const int64_t sourceMatrixSize = sourceRows * sourceCols;
  const int64_t targetMatrixSize = targetRows * targetCols;
  for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) {
    const int64_t sourceBatchBase = broadcastBatch ? 0 : batchIdx * sourceMatrixSize;
    const int64_t targetBatchBase = batchIdx * targetMatrixSize;
    for (int64_t row = 0; row < sourceRows; ++row) {
      const int64_t sourceRowBase = sourceBatchBase + row * sourceCols;
      const int64_t targetRowBase = targetBatchBase + row * targetCols;
      for (int64_t col = 0; col < sourceCols; ++col)
        resultValues[targetRowBase + col] = sourceValues[sourceRowBase + col];
    }
  }

  auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
}
|
||||
|
||||
/// Extracts one row tile of the rank-3 operand `a` as a 2-D tensor.
///
/// A [1, 1, K] slice (K = second dimension of `aTileType`) is taken at
/// (batch, row, kOffset) with unit strides and collapsed to `aTileType`.
/// When `sourceBatchCount` is 1 the batch offset is the constant 0 instead of
/// `batch`, so a single-batch operand is shared by every batch.
static Value extractBatchedATile(Value a,
                                 int64_t sourceBatchCount,
                                 Value batch,
                                 Value row,
                                 Value kOffset,
                                 RankedTensorType aTileType,
                                 PatternRewriter& rewriter,
                                 Location loc) {
  const int64_t tileWidth = aTileType.getDimSize(1);
  auto sliceType = RankedTensorType::get({1, 1, tileWidth}, aTileType.getElementType());

  OpFoldResult batchOffset = batch;
  if (sourceBatchCount == 1)
    batchOffset = rewriter.getIndexAttr(0);

  SmallVector<OpFoldResult> sliceOffsets {batchOffset, row, kOffset};
  SmallVector<OpFoldResult> sliceSizes {
    rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileWidth)};
  auto tileSlice =
      tensor::ExtractSliceOp::create(rewriter, loc, sliceType, a, sliceOffsets, sliceSizes, getUnitStrides(rewriter, 3));

  // Fold the unit batch dimension into the row dimension.
  const SmallVector<ReassociationIndices> dropBatchDim {
    {0, 1},
    {2}
  };
  return tensor::CollapseShapeOp::create(rewriter, loc, aTileType, tileSlice, dropBatchDim);
}
|
||||
|
||||
/// Extracts one 2-D tile of the rank-3 operand `b`.
///
/// A [1, R, C] slice (R, C = dimensions of `bTileType`) is taken at
/// (batch, kOffset, hOffset) with unit strides and collapsed to `bTileType`.
/// When `sourceBatchCount` is 1 the batch offset is the constant 0 instead of
/// `batch`, so a single-batch operand is shared by every batch.
static Value extractBatchedBTile(Value b,
                                 int64_t sourceBatchCount,
                                 Value batch,
                                 Value kOffset,
                                 Value hOffset,
                                 RankedTensorType bTileType,
                                 PatternRewriter& rewriter,
                                 Location loc) {
  const int64_t tileRows = bTileType.getDimSize(0);
  const int64_t tileCols = bTileType.getDimSize(1);
  auto sliceType = RankedTensorType::get({1, tileRows, tileCols}, bTileType.getElementType());

  OpFoldResult batchOffset = batch;
  if (sourceBatchCount == 1)
    batchOffset = rewriter.getIndexAttr(0);

  SmallVector<OpFoldResult> sliceOffsets {batchOffset, kOffset, hOffset};
  SmallVector<OpFoldResult> sliceSizes {
    rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileRows), rewriter.getIndexAttr(tileCols)};
  auto tileSlice =
      tensor::ExtractSliceOp::create(rewriter, loc, sliceType, b, sliceOffsets, sliceSizes, getUnitStrides(rewriter, 3));

  // Fold the unit batch dimension into the row dimension.
  const SmallVector<ReassociationIndices> dropBatchDim {
    {0, 1},
    {2}
  };
  return tensor::CollapseShapeOp::create(rewriter, loc, bTileType, tileSlice, dropBatchDim);
}
|
||||
|
||||
/// Returns the batch index of a flat lane id: the lane divided (floor) by the
/// number of lanes per batch, `numOutRows * numKSlices * numOutHSlices`.
static Value getBatchLaneIndex(
    Value lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices, PatternRewriter& rewriter, Location loc) {
  const int64_t lanesPerBatch = numOutRows * numKSlices * numOutHSlices;
  return floorDivIndexByConstant(rewriter, loc, lane, lanesPerBatch);
}
|
||||
|
||||
/// Creates the compute batch that produces one partial VMM piece per lane.
///
/// The lane count is the leading dimension of `partialPiecesType`. `b` is
/// handed to createSpatComputeBatch as the first value range (read in the
/// body as `args.weights`), `a` as the second (read as `args.inputs`).
/// Each lane id is decomposed as
///   row    = lane mod numOutRows
///   batch  = lane / (numOutRows * numKSlices * numOutHSlices)
///   slice  = (lane / numOutRows) mod (numKSlices * numOutHSlices)
///   kSlice = slice mod numKSlices,  hSlice = slice / numKSlices
/// and the k/h slice indices are scaled by the crossbar size to obtain tile
/// offsets. The lane multiplies a [1, crossbar] tile of `a` with a
/// [crossbar, crossbar] tile of `b` and writes the [1, crossbar] result into
/// row `lane` of the batch output.
///
/// Construction is expected to succeed; this is only checked by an assert.
static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
                                                       Value b,
                                                       RankedTensorType aType,
                                                       int64_t aBatchCount,
                                                       RankedTensorType bType,
                                                       int64_t bBatchCount,
                                                       RankedTensorType partialPiecesType,
                                                       int64_t numOutRows,
                                                       int64_t numKSlices,
                                                       int64_t numOutHSlices,
                                                       PatternRewriter& rewriter,
                                                       Location loc) {
  const int64_t numLanes = partialPiecesType.getDimSize(0);
  const int64_t tileExtent = static_cast<int64_t>(crossbarSize.getValue());

  // Tile types do not depend on the lane, so they are built once up front.
  auto aTileType = RankedTensorType::get({1, tileExtent}, aType.getElementType());
  auto bTileType = RankedTensorType::get({tileExtent, tileExtent}, bType.getElementType());
  auto pieceType = RankedTensorType::get({1, tileExtent}, partialPiecesType.getElementType());

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {partialPiecesType},
      numLanes,
      ValueRange {b},
      ValueRange {a},
      [&](detail::SpatComputeBatchBodyArgs args) {
        // Decompose the flat lane id into (batch, hSlice, kSlice, row).
        Value row = modIndexByConstant(rewriter, loc, args.lane, numOutRows);
        Value laneWithoutRow = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows);
        Value batch = getBatchLaneIndex(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
        Value sliceLane = modIndexByConstant(rewriter, loc, laneWithoutRow, numKSlices * numOutHSlices);
        Value kSlice = modIndexByConstant(rewriter, loc, sliceLane, numKSlices);
        Value hSlice = floorDivIndexByConstant(rewriter, loc, sliceLane, numKSlices);

        // Convert slice indices into element offsets.
        Value kOffset = multiplyIndexByConstant(
            rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice, crossbarSize.getValue());
        Value hOffset = multiplyIndexByConstant(
            rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());

        Value aTile =
            extractBatchedATile(args.inputs.front(), aBatchCount, batch, row, kOffset, aTileType, rewriter, loc);
        Value bTile =
            extractBatchedBTile(args.weights.front(), bBatchCount, batch, kOffset, hOffset, bTileType, rewriter, loc);
        Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();

        // Each lane owns exactly one row of the batch output.
        SmallVector<OpFoldResult> insertOffsets {args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> insertSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tileExtent)};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, piece, args.outputs.front(), insertOffsets, insertSizes, getUnitStrides(rewriter, 2));
      });
  assert(succeeded(batchOp) && "expected batched MatMul VMM construction to succeed");
  return *batchOp;
}
|
||||
|
||||
/// Extracts one column of the rank-3 operand `matrix` as a row vector.
///
/// A [1, N, 1] slice (N = second dimension of `vectorType`) is taken at
/// (batch, 0, column) with unit strides, collapsed to a 1-D tensor of N
/// elements and then expanded to `vectorType`. When `sourceBatchCount` is 1
/// the batch offset is the constant 0 instead of `batch`.
static Value extractDynamicBatchedBColumn(Value matrix,
                                          int64_t sourceBatchCount,
                                          Value batch,
                                          Value column,
                                          RankedTensorType vectorType,
                                          PatternRewriter& rewriter,
                                          Location loc) {
  const int64_t length = vectorType.getDimSize(1);
  auto elementType = vectorType.getElementType();

  OpFoldResult batchOffset = batch;
  if (sourceBatchCount == 1)
    batchOffset = rewriter.getIndexAttr(0);

  auto sliceType = RankedTensorType::get({1, length, 1}, elementType);
  SmallVector<OpFoldResult> sliceOffsets {batchOffset, rewriter.getIndexAttr(0), column};
  SmallVector<OpFoldResult> sliceSizes {
    rewriter.getIndexAttr(1), rewriter.getIndexAttr(length), rewriter.getIndexAttr(1)};
  SmallVector<OpFoldResult> sliceStrides(3, rewriter.getIndexAttr(1));
  Value columnSlice =
      tensor::ExtractSliceOp::create(rewriter, loc, sliceType, matrix, sliceOffsets, sliceSizes, sliceStrides);

  // Flatten to 1-D first, then re-expand into the requested 2-D vector type.
  auto flatType = RankedTensorType::get({length}, elementType);
  const SmallVector<ReassociationIndices> flattenAll {
    {0, 1, 2}
  };
  Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, columnSlice, flattenAll).getResult();

  const SmallVector<ReassociationIndices> expandToVector {
    {0, 1}
  };
  return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, flat, expandToVector).getResult();
}
|
||||
|
||||
/// Extracts one row of the rank-3 operand `matrix` as a 2-D vector.
///
/// A [1, 1, N] slice (N = second dimension of `vectorType`) is taken at
/// (batch, row, 0) with unit strides and collapsed to `vectorType`. When
/// `sourceBatchCount` is 1 the batch offset is the constant 0 instead of
/// `batch`.
static Value extractDynamicBatchedBRow(Value matrix,
                                       int64_t sourceBatchCount,
                                       Value batch,
                                       Value row,
                                       RankedTensorType vectorType,
                                       PatternRewriter& rewriter,
                                       Location loc) {
  const int64_t length = vectorType.getDimSize(1);
  auto sliceType = RankedTensorType::get({1, 1, length}, vectorType.getElementType());

  OpFoldResult batchOffset = batch;
  if (sourceBatchCount == 1)
    batchOffset = rewriter.getIndexAttr(0);

  SmallVector<OpFoldResult> sliceOffsets {batchOffset, row, rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> sliceSizes {
    rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(length)};
  auto rowSlice = tensor::ExtractSliceOp::create(
      rewriter, loc, sliceType, matrix, sliceOffsets, sliceSizes, getUnitStrides(rewriter, 3));

  // Fold the unit batch dimension into the row dimension.
  const SmallVector<ReassociationIndices> dropBatchDim {
    {0, 1},
    {2}
  };
  return tensor::CollapseShapeOp::create(rewriter, loc, vectorType, rowSlice, dropBatchDim);
}
|
||||
|
||||
/// Extracts one row of the rank-3 operand `matrix` as a 2-D vector of type
/// `vectorType`.
///
/// The operation is exactly the one performed by extractDynamicBatchedBRow
/// (a [1, 1, N] slice at (batch, row, 0), collapsed to `vectorType`, with the
/// batch offset forced to 0 when `sourceBatchCount` is 1), so this function
/// forwards to it instead of keeping a second copy of the same code.
static Value extractDynamicBatchedRowVector(Value matrix,
                                            int64_t sourceBatchCount,
                                            Value batch,
                                            Value row,
                                            RankedTensorType vectorType,
                                            PatternRewriter& rewriter,
                                            Location loc) {
  return extractDynamicBatchedBRow(matrix, sourceBatchCount, batch, row, vectorType, rewriter, loc);
}
|
||||
|
||||
/// Creates the compute batch that produces one output scalar per lane.
///
/// There is one lane per element of `outType` ([batches, rows, cols]). Both
/// `a` and `b` are passed as the second value range of createSpatComputeBatch
/// and read in the body as `args.inputs[0]` / `args.inputs[1]`. Each lane id
/// is decomposed as
///   batch  = lane / (rows * cols)
///   row    = (lane mod (rows * cols)) / cols
///   column = (lane mod (rows * cols)) mod cols
/// The lane takes row `row` of `a` and, depending on `bAlreadyTransposed`,
/// either row `column` or column `column` of `b`, both as [1, K] vectors
/// (K = third dimension of `aType`), combines them with SpatVVDMulOp into a
/// [1, 1] tensor and writes it into row `lane` of the batch output.
///
/// Construction is expected to succeed; this is only checked by an assert.
static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a,
                                                          int64_t aBatchCount,
                                                          Value b,
                                                          int64_t bBatchCount,
                                                          RankedTensorType aType,
                                                          RankedTensorType bType,
                                                          RankedTensorType scalarPiecesType,
                                                          RankedTensorType outType,
                                                          bool bAlreadyTransposed,
                                                          PatternRewriter& rewriter,
                                                          Location loc) {
  const int64_t batches = outType.getDimSize(0);
  const int64_t rows = outType.getDimSize(1);
  const int64_t cols = outType.getDimSize(2);
  const int64_t lanesPerBatch = rows * cols;
  const int64_t numLanes = batches * lanesPerBatch;

  // Operand and result types are identical for every lane.
  auto vectorType = RankedTensorType::get({1, aType.getDimSize(2)}, aType.getElementType());
  auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {scalarPiecesType},
      numLanes,
      ValueRange {},
      ValueRange {a, b},
      [&](detail::SpatComputeBatchBodyArgs args) {
        // Decompose the flat lane id into (batch, row, column).
        Value batch = floorDivIndexByConstant(rewriter, loc, args.lane, lanesPerBatch);
        Value laneInBatch = modIndexByConstant(rewriter, loc, args.lane, lanesPerBatch);
        Value row = floorDivIndexByConstant(rewriter, loc, laneInBatch, cols);
        Value column = modIndexByConstant(rewriter, loc, laneInBatch, cols);

        Value lhs = extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc);
        Value rhs;
        if (bAlreadyTransposed)
          rhs = extractDynamicBatchedBRow(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
        else
          rhs = extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
        Value product = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, lhs, rhs).getResult();

        // Each lane owns exactly one row of the batch output.
        SmallVector<OpFoldResult> insertOffsets {args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> insertSizes(2, rewriter.getIndexAttr(1));
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, product, args.outputs.front(), insertOffsets, insertSizes, getUnitStrides(rewriter, 2));
      });
  assert(succeeded(batchOp) && "expected batched MatMul VVDMul construction to succeed");
  return *batchOp;
}
|
||||
|
||||
/// Scatters per-lane scalar results into a tensor of type `outType`.
///
/// A single-input compute op is created whose body starts from an empty
/// tensor of `outType`'s shape and runs an scf.for loop over
/// [0, laneCount) with step 1, where laneCount is the leading dimension of
/// `scalarPiecesType`. For each lane the [1, 1] entry at row `lane` of the
/// pieces tensor is expanded to [1, 1, 1] and inserted at
///   (lane / (rows * cols), (lane mod (rows * cols)) / cols,
///    (lane mod (rows * cols)) mod cols)
/// with rows/cols taken from dimensions 1 and 2 of `outType`. The loop result
/// is yielded from the compute op, whose first result is returned.
static Value createBatchedDynamicOutputCompute(Value scalarPieces,
                                               RankedTensorType scalarPiecesType,
                                               RankedTensorType outType,
                                               PatternRewriter& rewriter,
                                               Location loc) {
  const int64_t numLanes = scalarPiecesType.getDimSize(0);
  const int64_t rows = outType.getDimSize(1);
  const int64_t cols = outType.getDimSize(2);
  const int64_t lanesPerBatch = rows * cols;
  auto pieceType = RankedTensorType::get({1, 1}, outType.getElementType());
  auto outputElementType = RankedTensorType::get({1, 1, 1}, outType.getElementType());

  auto scatterCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) {
        Value emptyOutput =
            tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
        Value lowerBound = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
        Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
        Value upperBound = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numLanes);
        auto laneLoop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step, ValueRange {emptyOutput});
        rewriter.setInsertionPointToStart(laneLoop.getBody());

        Value lane = laneLoop.getInductionVar();
        Value accumulated = laneLoop.getRegionIterArgs().front();

        // Decompose the flat lane id into (batch, row, column).
        Value batch = floorDivIndexByConstant(rewriter, loc, lane, lanesPerBatch);
        Value laneInBatch = modIndexByConstant(rewriter, loc, lane, lanesPerBatch);
        Value row = floorDivIndexByConstant(rewriter, loc, laneInBatch, cols);
        Value column = modIndexByConstant(rewriter, loc, laneInBatch, cols);

        // Read this lane's [1, 1] piece and lift it to rank 3.
        SmallVector<OpFoldResult> pieceOffsets {lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> pieceSizes(2, rewriter.getIndexAttr(1));
        Value piece = tensor::ExtractSliceOp::create(
            rewriter, loc, pieceType, pieces, pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2));
        const SmallVector<ReassociationIndices> addTrailingDim {
          {0},
          {1, 2}
        };
        Value element = tensor::ExpandShapeOp::create(rewriter, loc, outputElementType, piece, addTrailingDim);

        // Write the element into the accumulator and carry it to the next iteration.
        SmallVector<OpFoldResult> insertOffsets {batch, row, column};
        SmallVector<OpFoldResult> insertSizes(3, rewriter.getIndexAttr(1));
        Value updated = tensor::InsertSliceOp::create(
                            rewriter, loc, element, accumulated, insertOffsets, insertSizes, getUnitStrides(rewriter, 3))
                            .getResult();
        scf::YieldOp::create(rewriter, loc, updated);

        rewriter.setInsertionPointAfter(laneLoop);
        spatial::SpatYieldOp::create(rewriter, loc, laneLoop.getResult(0));
      });
  return scatterCompute.getResult(0);
}
|
||||
|
||||
static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) {
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, outputType, input, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
/// Extracts the block of partial results belonging to one
/// (batch, hSlice, kSlice) combination.
///
/// The first row of the block is
///   batch * (numOutRows * numKSlices * numOutHSlices)
///     + hSlice * (numKSlices * numOutRows)
///     + kSlice * numOutRows
/// where `batch` and `hSlice` are runtime index values and `kSlice` is a
/// compile-time constant. A slice of `numOutRows` rows by crossbar-size
/// columns is taken from `partialPiecesArg` starting at that row, column 0,
/// with unit strides, and returned with type `pieceType`.
static Value extractBatchedReductionPiece(Value partialPiecesArg,
                                          Value batch,
                                          Value hSlice,
                                          int64_t kSlice,
                                          RankedTensorType pieceType,
                                          int64_t numKSlices,
                                          int64_t numOutHSlices,
                                          int64_t numOutRows,
                                          PatternRewriter& rewriter,
                                          Location loc) {
  const int64_t rowsPerBatch = numOutRows * numKSlices * numOutHSlices;
  const int64_t rowsPerHSlice = numKSlices * numOutRows;

  Value batchRows =
      multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), batch, rowsPerBatch);
  Value hSliceRows =
      multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, rowsPerHSlice);
  Value kSliceRows =
      getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice * numOutRows);

  // firstRow = batchRows + hSliceRows + kSliceRows
  Value batchAndHSliceRows = arith::AddIOp::create(rewriter, loc, batchRows, hSliceRows);
  Value firstRow = arith::AddIOp::create(rewriter, loc, batchAndHSliceRows, kSliceRows);

  SmallVector<OpFoldResult> sliceOffsets {firstRow, rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> sliceSizes {rewriter.getIndexAttr(numOutRows),
                                        rewriter.getIndexAttr(crossbarSize.getValue())};
  return tensor::ExtractSliceOp::create(
      rewriter, loc, pieceType, partialPiecesArg, sliceOffsets, sliceSizes, getUnitStrides(rewriter, 2));
}
|
||||
|
||||
if (llvm::all_of(inputs, isCompileTimeComputable))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
/// Sums the `numKSlices` partial pieces that belong to one (batch, hSlice)
/// pair.
///
/// One piece per K slice is extracted with `extractBatchedReductionPiece`, and
/// the pieces are then combined pairwise with `spatial::SpatVAddOp` until a
/// single value remains. When a round has an odd number of pieces, the last
/// one is carried over to the next round unchanged.
///
/// @return the fully reduced piece, typed as `pieceType`.
static Value reduceBatchedPartialPiecesForHSlice(Value partialPiecesArg,
                                                 Value batch,
                                                 Value hSlice,
                                                 RankedTensorType pieceType,
                                                 int64_t numKSlices,
                                                 int64_t numOutHSlices,
                                                 int64_t numOutRows,
                                                 PatternRewriter& rewriter,
                                                 Location loc) {
  SmallVector<Value> activePieces;
  activePieces.reserve(numKSlices);
  for (int64_t kSlice = 0; kSlice < numKSlices; ++kSlice)
    activePieces.push_back(extractBatchedReductionPiece(
        partialPiecesArg, batch, hSlice, kSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc));

  // Pairwise (tree-shaped) reduction: each round halves the number of pieces.
  while (activePieces.size() > 1) {
    SmallVector<Value> nextPieces;
    nextPieces.reserve((activePieces.size() + 1) / 2);
    for (size_t pieceIndex = 0; pieceIndex + 1 < activePieces.size(); pieceIndex += 2)
      nextPieces.push_back(
          spatial::SpatVAddOp::create(rewriter, loc, pieceType, activePieces[pieceIndex], activePieces[pieceIndex + 1])
              .getResult());
    // An unpaired trailing piece moves on to the next round as-is.
    if (activePieces.size() % 2 != 0)
      nextPieces.push_back(activePieces.back());
    activePieces = std::move(nextPieces);
  }

  // NOTE(review): assumes numKSlices >= 1 so that at least one piece exists.
  return activePieces.front();
}
|
||||
|
||||
/// Wraps the reduction of all partial pieces into one spatial compute op.
///
/// Inside the compute region, a two-level `scf.for` nest iterates over
/// `numBatches` batches and over the output column slices (each
/// `crossbarSize` wide). For every (batch, hSlice) pair the `numKSlices`
/// partial pieces are summed, reshaped to a leading unit batch dimension and
/// inserted into an accumulator shaped like `paddedOutType`. If
/// `paddedOutType` differs from `outType`, the result is finally sliced back
/// to the `outType` shape starting at offset zero.
///
/// @return the single result of the created compute op (typed `outType`).
static Value createBatchedReductionCompute(Value partialPieces,
                                           RankedTensorType partialPiecesType,
                                           RankedTensorType outType,
                                           RankedTensorType paddedOutType,
                                           int64_t numBatches,
                                           int64_t numKSlices,
                                           PatternRewriter& rewriter,
                                           Location loc) {
  auto computeOp = createSpatCompute<1>(
      rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value partialPiecesArg) {
        const int64_t crossbarWidth = static_cast<int64_t>(crossbarSize.getValue());
        const int64_t numOutRows = outType.getDimSize(1);
        const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(2), crossbarSize.getValue());

        auto elementType = partialPiecesType.getElementType();
        auto pieceType = RankedTensorType::get({numOutRows, crossbarWidth}, elementType);
        auto outputSliceType = RankedTensorType::get({1, numOutRows, crossbarWidth}, elementType);

        // Accumulator that the loop nest threads through as an iter_arg.
        Value outputInit =
            tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType())
                .getResult();

        auto indexConstant = [&](int64_t value) -> Value {
          return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), value);
        };
        Value lowerBound = indexConstant(0);
        Value unitStep = indexConstant(1);
        Value batchUpperBound = indexConstant(numBatches);
        Value hSliceUpperBound = indexConstant(numOutHSlices);

        // Outer loop: one iteration per batch.
        auto batchLoop =
            scf::ForOp::create(rewriter, loc, lowerBound, batchUpperBound, unitStep, ValueRange {outputInit});
        rewriter.setInsertionPointToStart(batchLoop.getBody());
        Value batch = batchLoop.getInductionVar();
        Value batchAcc = batchLoop.getRegionIterArgs().front();

        // Inner loop: one iteration per output column slice.
        auto hLoop = scf::ForOp::create(rewriter, loc, lowerBound, hSliceUpperBound, unitStep, ValueRange {batchAcc});
        rewriter.setInsertionPointToStart(hLoop.getBody());
        Value hSlice = hLoop.getInductionVar();
        Value outputAcc = hLoop.getRegionIterArgs().front();

        Value reduced = reduceBatchedPartialPiecesForHSlice(
            partialPiecesArg, batch, hSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc);

        // [rows, cols] -> [1, rows, cols] so the piece can be inserted into the rank-3 accumulator.
        SmallVector<ReassociationIndices> addUnitBatchDim {
            {0, 1},
            {2}
        };
        Value expandedReduced =
            tensor::ExpandShapeOp::create(rewriter, loc, outputSliceType, reduced, addUnitBatchDim);

        Value hOffset = multiplyIndexByConstant(
            rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());

        SmallVector<OpFoldResult> insertOffsets {batch, rewriter.getIndexAttr(0), hOffset};
        SmallVector<OpFoldResult> insertSizes {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarWidth)};
        Value updatedAcc =
            tensor::InsertSliceOp::create(
                rewriter, loc, expandedReduced, outputAcc, insertOffsets, insertSizes, getUnitStrides(rewriter, 3))
                .getResult();
        scf::YieldOp::create(rewriter, loc, updatedAcc);

        // Close the inner loop and forward its result out of the batch loop.
        rewriter.setInsertionPointAfter(hLoop);
        scf::YieldOp::create(rewriter, loc, hLoop.getResult(0));

        rewriter.setInsertionPointAfter(batchLoop);
        Value paddedOutput = batchLoop.getResult(0);
        Value result = paddedOutput;
        if (paddedOutType != outType) {
          // Drop the padding by slicing from the origin down to the `outType` shape.
          SmallVector<OpFoldResult> cropOffsets(3, rewriter.getIndexAttr(0));
          SmallVector<OpFoldResult> cropSizes {rewriter.getIndexAttr(numBatches),
                                               rewriter.getIndexAttr(outType.getDimSize(1)),
                                               rewriter.getIndexAttr(outType.getDimSize(2))};
          result = tensor::ExtractSliceOp::create(
              rewriter, loc, outType, paddedOutput, cropOffsets, cropSizes, getUnitStrides(rewriter, 3));
        }
        spatial::SpatYieldOp::create(rewriter, loc, result);
      });
  return computeOp.getResult(0);
}
|
||||
|
||||
struct MatMulShapeInfo {
|
||||
RankedTensorType lhsType;
|
||||
RankedTensorType rhsType;
|
||||
RankedTensorType outType;
|
||||
SmallVector<int64_t> batchShape;
|
||||
int64_t lhsBatch;
|
||||
int64_t rhsBatch;
|
||||
int64_t batch;
|
||||
int64_t m;
|
||||
int64_t k;
|
||||
int64_t n;
|
||||
};
|
||||
|
||||
/// Validates the operand/result types of `matmulOp` and extracts its shape
/// parameters.
///
/// Fails unless A, B and Y are all ranked tensors with static, positive
/// shapes of rank >= 2, the leading dims of A and B yield a supported batch
/// shape, the contraction dims agree, and Y matches the expected shape:
///  - rank-2 Y: the batch count must be 1 and Y must be [m, n];
///  - otherwise: Y's leading dims must equal the inferred batch shape and its
///    trailing dims must be [m, n].
static FailureOr<MatMulShapeInfo> analyzeMatMulShape(ONNXMatMulOp matmulOp) {
  auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
  auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
  auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
  if (!lhsType || !rhsType || !outType)
    return failure();

  // Each tensor must be fully static, at least a matrix, and have positive dims.
  auto isSupportedTensor = [](RankedTensorType type) {
    return type.hasStaticShape() && type.getRank() >= 2 && hasStaticPositiveShape(type);
  };
  if (!(isSupportedTensor(lhsType) && isSupportedTensor(rhsType) && isSupportedTensor(outType)))
    return failure();

  // Everything in front of the trailing two matrix dims is treated as batch dims.
  SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
  SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
  auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
  if (failed(batchShape))
    return failure();

  // An empty batch shape counts as a single matrix.
  auto elementCountOrOne = [](const auto& dims) -> int64_t {
    return dims.empty() ? 1 : getStaticShapeElementCount(dims);
  };
  const int64_t lhsBatch = elementCountOrOne(lhsBatchShape);
  const int64_t rhsBatch = elementCountOrOne(rhsBatchShape);
  const int64_t batch = elementCountOrOne(*batchShape);

  const int64_t lhsRank = lhsType.getRank();
  const int64_t rhsRank = rhsType.getRank();
  const int64_t m = lhsType.getDimSize(lhsRank - 2);
  const int64_t k = lhsType.getDimSize(lhsRank - 1);
  const int64_t rhsK = rhsType.getDimSize(rhsRank - 2);
  const int64_t n = rhsType.getDimSize(rhsRank - 1);
  if (k != rhsK)
    return failure();

  const int64_t outRank = outType.getRank();
  const bool outMatrixMatches = outType.getDimSize(outRank - 2) == m && outType.getDimSize(outRank - 1) == n;
  if (outRank == 2) {
    if (batch != 1 || !outMatrixMatches)
      return failure();
  }
  else {
    SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
    if (!llvm::equal(outBatchShape, *batchShape) || !outMatrixMatches)
      return failure();
  }

  return MatMulShapeInfo {lhsType, rhsType, outType, *batchShape, lhsBatch, rhsBatch, batch, m, k, n};
}
|
||||
|
||||
/// Rewrites an `ONNXMatMulOp` whose result is rank 2 into a single
/// `ONNXGemmOp`.
///
/// Shape validation is delegated to `analyzeMatMulShape`; results of any other
/// rank are rejected here. When A is compile-time computable and B is not,
/// the product is computed in transposed form (B^T * A^T) and the Gemm result
/// is transposed back inside a spatial compute op.
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
    auto shapeInfo = analyzeMatMulShape(matmulOp);
    if (failed(shapeInfo) || shapeInfo->outType.getRank() != 2)
      return failure();

    Location loc = matmulOp.getLoc();
    bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());

    Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
    Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
    int64_t lhsBatchForGemm = shapeInfo->lhsBatch;
    int64_t rhsBatchForGemm = shapeInfo->rhsBatch;
    int64_t gemmM = shapeInfo->m;
    int64_t gemmK = shapeInfo->k;
    int64_t gemmN = shapeInfo->n;
    if (useTransposedForm) {
      // Swap the operand roles: the Gemm computes B^T * A^T, i.e. an [n, m] result.
      lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
      lhsBatchForGemm = shapeInfo->rhsBatch;
      rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
      rhsBatchForGemm = shapeInfo->lhsBatch;
      gemmM = shapeInfo->n;
      gemmN = shapeInfo->m;
    }

    auto gemmType = RankedTensorType::get({gemmM, gemmN}, shapeInfo->outType.getElementType());
    // The Gemm gets a none value as its third operand (no bias term).
    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

    Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
    Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
    // NOTE(review): the four trailing attributes are presumably alpha, beta,
    // transA and transB, in that order - confirm against the ONNXGemmOp builder.
    Value gemmResult = ONNXGemmOp::create(rewriter,
                                          loc,
                                          gemmType,
                                          lhsMatrix,
                                          rhsMatrix,
                                          none,
                                          rewriter.getF32FloatAttr(1.0f),
                                          rewriter.getF32FloatAttr(1.0f),
                                          rewriter.getBoolAttr(false),
                                          rewriter.getBoolAttr(false))
                           .getY();
    if (useTransposedForm) {
      // Undo the operand swap: transpose the [n, m] Gemm result back to [m, n].
      auto transposeCompute =
          createSpatCompute<1>(rewriter, loc, TypeRange {shapeInfo->outType}, {}, gemmResult, [&](Value input) {
            Value transposed =
                ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, input, rewriter.getI64ArrayAttr({1, 0}));
            spatial::SpatYieldOp::create(rewriter, loc, transposed);
          });
      gemmResult = transposeCompute.getResult(0);
    }
    rewriter.replaceOp(matmulOp, gemmResult);
    return success();
  }
};
|
||||
|
||||
struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
/// Rewrites an `ONNXMatMulOp` whose result has rank > 2 into batched spatial
/// computes.
///
/// Rank-2 results are rejected here. When A is compile-time computable and B
/// is not, the operands are swapped and transposed, and the batched result is
/// transposed back at the end. Two lowering paths exist:
///  - if the (possibly swapped) rhs is compile-time computable and its padded
///    weight can be materialized, a batched VMM over crossbar-sized tiles is
///    emitted, followed by a reduction compute;
///  - otherwise a batched vector-vector dot-product lowering producing one
///    scalar per output element is emitted.
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
  auto shapeInfo = analyzeMatMulShape(matmulOp);
  if (failed(shapeInfo))
    return failure();
  if (shapeInfo->outType.getRank() == 2)
    return failure();

  Location loc = matmulOp.getLoc();
  bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());

  Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
  Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
  int64_t lhsBatchForGemm = shapeInfo->lhsBatch;
  int64_t rhsBatchForGemm = shapeInfo->rhsBatch;
  int64_t gemmM = shapeInfo->m;
  int64_t gemmK = shapeInfo->k;
  int64_t gemmN = shapeInfo->n;
  if (useTransposedForm) {
    // Swap the operand roles: compute B^T * A^T, i.e. an [n, m] matrix per batch.
    lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
    lhsBatchForGemm = shapeInfo->rhsBatch;
    rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
    rhsBatchForGemm = shapeInfo->lhsBatch;
    gemmM = shapeInfo->n;
    gemmN = shapeInfo->m;
  }

  lhs = ensureBatchedTensor(lhs, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
  rhs = ensureBatchedTensor(rhs, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
  auto lhsBatchedType = cast<RankedTensorType>(lhs.getType());
  auto rhsBatchedType = cast<RankedTensorType>(rhs.getType());
  auto directOutType = RankedTensorType::get({shapeInfo->batch, gemmM, gemmN}, shapeInfo->outType.getElementType());

  // Shared tail of both lowering paths: undo the operand swap if needed,
  // restore the original batch dims and replace the MatMul.
  auto finishRewrite = [&](Value result) -> LogicalResult {
    if (useTransposedForm)
      result = transposeBatchedOutput(
          result,
          RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
          rewriter,
          loc);
    result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
    rewriter.replaceOp(matmulOp, result);
    return success();
  };

  if (isCompileTimeComputable(rhs)) {
    // Tile the reduction and output-column dims in crossbar-sized chunks, padding both up to a multiple.
    const int64_t numKSlices = ceilIntegerDivide(gemmK, crossbarSize.getValue());
    const int64_t numOutHSlices = ceilIntegerDivide(gemmN, crossbarSize.getValue());
    const int64_t paddedReductionSize = numKSlices * static_cast<int64_t>(crossbarSize.getValue());
    const int64_t paddedOutCols = numOutHSlices * static_cast<int64_t>(crossbarSize.getValue());
    auto paddedLhsType = RankedTensorType::get(
        {lhsBatchForGemm, gemmM, paddedReductionSize}, lhsBatchedType.getElementType(), lhsBatchedType.getEncoding());
    auto paddedRhsType = RankedTensorType::get({shapeInfo->batch, paddedReductionSize, paddedOutCols},
                                               rhsBatchedType.getElementType(),
                                               rhsBatchedType.getEncoding());
    auto paddedOutType =
        RankedTensorType::get({shapeInfo->batch, gemmM, paddedOutCols}, shapeInfo->outType.getElementType());

    auto paddedRhs = materializePaddedBatchedWeight(rhs, rhsBatchForGemm, shapeInfo->batch, paddedRhsType, rewriter);
    // If the padded weight cannot be materialized, fall through to the dot-product lowering below.
    if (succeeded(paddedRhs)) {
      Value paddedLhs = createPaddedBatchedInputCompute(lhs, paddedLhsType, rewriter, loc);
      const int64_t laneCount = shapeInfo->batch * gemmM * numKSlices * numOutHSlices;
      auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())},
                                                     shapeInfo->outType.getElementType());
      auto batchOp = createBatchedVmmBatch(paddedLhs,
                                           *paddedRhs,
                                           paddedLhsType,
                                           lhsBatchForGemm,
                                           paddedRhsType,
                                           rhsBatchForGemm,
                                           partialPiecesType,
                                           gemmM,
                                           numKSlices,
                                           numOutHSlices,
                                           rewriter,
                                           loc);
      Value result = createBatchedReductionCompute(batchOp.getResult(0),
                                                   partialPiecesType,
                                                   directOutType,
                                                   paddedOutType,
                                                   shapeInfo->batch,
                                                   numKSlices,
                                                   rewriter,
                                                   loc);
      return finishRewrite(result);
    }
  }

  // Dot-product lowering: one lane (and one scalar piece) per output element.
  const int64_t laneCount = shapeInfo->batch * gemmM * gemmN;
  auto scalarPiecesType = RankedTensorType::get({laneCount, 1}, shapeInfo->outType.getElementType());
  auto batchOp = createBatchedVvdmulBatch(lhs,
                                          lhsBatchForGemm,
                                          rhs,
                                          rhsBatchForGemm,
                                          lhsBatchedType,
                                          rhsBatchedType,
                                          scalarPiecesType,
                                          directOutType,
                                          false,
                                          rewriter,
                                          loc);
  Value result =
      createBatchedDynamicOutputCompute(batchOp.getResult(0), scalarPiecesType, directOutType, rewriter, loc);
  return finishRewrite(result);
}
|
||||
@@ -296,7 +888,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
} // namespace
|
||||
|
||||
/// Registers the MatMul rewrite patterns: `MatMulToGemm` for rank-2 results
/// and `MatMulBatchedToSpatialComputes` for results of higher rank. Each
/// pattern is registered exactly once.
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<MatMulToGemm, MatMulBatchedToSpatialComputes>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user