add conv helpers new validation tests for matmul
This commit is contained in:
@@ -51,23 +51,107 @@ static Value createPaddedRows(Value tensorValue,
|
|||||||
if (tensorType.getDimSize(0) == paddedRows)
|
if (tensorType.getDimSize(0) == paddedRows)
|
||||||
return tensorValue;
|
return tensorValue;
|
||||||
|
|
||||||
auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType());
|
auto paddedType =
|
||||||
|
RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
|
||||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
|
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
|
||||||
rewriter.getIndexAttr(0)};
|
rewriter.getIndexAttr(0)};
|
||||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
|
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
|
||||||
auto* padBlock = new Block();
|
auto* padBlock = new Block();
|
||||||
for (int i = 0; i < 2; i++)
|
for (int i = 0; i < 2; ++i)
|
||||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||||
padOp.getRegion().push_back(padBlock);
|
padOp.getRegion().push_back(padBlock);
|
||||||
rewriter.setInsertionPointToStart(padBlock);
|
rewriter.setInsertionPointToStart(padBlock);
|
||||||
auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()),
|
auto zero = getOrCreateConstant(rewriter,
|
||||||
|
padOp.getOperation(),
|
||||||
|
rewriter.getZeroAttr(tensorType.getElementType()),
|
||||||
tensorType.getElementType());
|
tensorType.getElementType());
|
||||||
tensor::YieldOp::create(rewriter, loc, zero);
|
tensor::YieldOp::create(rewriter, loc, zero);
|
||||||
rewriter.setInsertionPointAfter(padOp);
|
rewriter.setInsertionPointAfter(padOp);
|
||||||
return padOp.getResult();
|
return padOp.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value packRowsForParallelGemm(Value rows,
|
||||||
|
RankedTensorType rowsType,
|
||||||
|
int64_t packFactor,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (packFactor == 1)
|
||||||
|
return rows;
|
||||||
|
|
||||||
|
const int64_t packedNumRows = ceilIntegerDivide(rowsType.getDimSize(0), packFactor);
|
||||||
|
const int64_t paddedNumRows = packedNumRows * packFactor;
|
||||||
|
const int64_t rowWidth = rowsType.getDimSize(1);
|
||||||
|
auto groupedType =
|
||||||
|
RankedTensorType::get({packedNumRows, packFactor, rowWidth}, rowsType.getElementType(), rowsType.getEncoding());
|
||||||
|
auto packedType =
|
||||||
|
RankedTensorType::get({packedNumRows, packFactor * rowWidth}, rowsType.getElementType(), rowsType.getEncoding());
|
||||||
|
|
||||||
|
Value paddedRows = createPaddedRows(rows, rowsType, paddedNumRows, rewriter, loc);
|
||||||
|
Value groupedRows = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
groupedType,
|
||||||
|
paddedRows,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
packedType,
|
||||||
|
groupedRows,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value unpackRowsFromParallelGemm(Value packedRows,
|
||||||
|
RankedTensorType packedRowsType,
|
||||||
|
int64_t unpackedRows,
|
||||||
|
int64_t rowWidth,
|
||||||
|
int64_t packFactor,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (packFactor == 1)
|
||||||
|
return packedRows;
|
||||||
|
|
||||||
|
const int64_t packedNumRows = packedRowsType.getDimSize(0);
|
||||||
|
const int64_t paddedNumRows = packedNumRows * packFactor;
|
||||||
|
auto expandedType =
|
||||||
|
RankedTensorType::get({packedNumRows, packFactor, rowWidth},
|
||||||
|
packedRowsType.getElementType(),
|
||||||
|
packedRowsType.getEncoding());
|
||||||
|
auto paddedType =
|
||||||
|
RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
|
||||||
|
auto unpackedType =
|
||||||
|
RankedTensorType::get({unpackedRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
|
||||||
|
|
||||||
|
Value expandedRows = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
expandedType,
|
||||||
|
packedRows,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2}
|
||||||
|
});
|
||||||
|
Value paddedRows = tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
paddedType,
|
||||||
|
expandedRows,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
if (paddedNumRows == unpackedRows)
|
||||||
|
return paddedRows;
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(unpackedRows), rewriter.getIndexAttr(rowWidth)};
|
||||||
|
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
return tensor::ExtractSliceOp::create(rewriter, loc, unpackedType, paddedRows, offsets, sizes, strides);
|
||||||
|
}
|
||||||
|
|
||||||
static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
||||||
Value wTrans,
|
Value wTrans,
|
||||||
RankedTensorType wType,
|
RankedTensorType wType,
|
||||||
@@ -189,7 +273,6 @@ static Value createIm2colRowComputes(Value x,
|
|||||||
Location loc) {
|
Location loc) {
|
||||||
auto elemType = xType.getElementType();
|
auto elemType = xType.getElementType();
|
||||||
constexpr size_t numInputs = 1;
|
constexpr size_t numInputs = 1;
|
||||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
|
||||||
auto im2colComputeOp =
|
auto im2colComputeOp =
|
||||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
|
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
|
||||||
Value paddedInput = xArg;
|
Value paddedInput = xArg;
|
||||||
@@ -278,26 +361,7 @@ static Value createIm2colRowComputes(Value x,
|
|||||||
|
|
||||||
Value gemmInputRows = im2col;
|
Value gemmInputRows = im2col;
|
||||||
if (packFactor != 1) {
|
if (packFactor != 1) {
|
||||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
|
||||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
|
||||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
|
||||||
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
|
||||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
groupedType,
|
|
||||||
paddedIm2col,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1},
|
|
||||||
{2}
|
|
||||||
});
|
|
||||||
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
packedType,
|
|
||||||
groupedIm2col,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0},
|
|
||||||
{1, 2}
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
||||||
@@ -316,41 +380,15 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
|||||||
int64_t packFactor,
|
int64_t packFactor,
|
||||||
ConversionPatternRewriter& rewriter,
|
ConversionPatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
|
||||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
|
||||||
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
||||||
Value gemmOut;
|
Value gemmOut;
|
||||||
if (packFactor == 1) {
|
if (packFactor == 1) {
|
||||||
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
|
||||||
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
|
||||||
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||||
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
gemmOut = unpackRowsFromParallelGemm(
|
||||||
loc,
|
packedOutput, cast<RankedTensorType>(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc);
|
||||||
expandedType,
|
|
||||||
packedOutput,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0},
|
|
||||||
{1, 2}
|
|
||||||
});
|
|
||||||
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
paddedType,
|
|
||||||
expandedOutput,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1},
|
|
||||||
{2}
|
|
||||||
});
|
|
||||||
|
|
||||||
gemmOut = paddedOutput;
|
|
||||||
if (paddedNumPatches != numPatches) {
|
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
|
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restore to NCHW layout:
|
// Restore to NCHW layout:
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
@@ -5,9 +7,6 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <functional>
|
|
||||||
#include <numeric>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
@@ -66,6 +65,26 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt
|
|||||||
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
|
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value ensureBatchedTensor(
|
||||||
|
Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
|
||||||
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
|
if (type.getRank() == 3)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
auto batchedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||||
|
auto buildExpanded = [&](Value input) -> Value {
|
||||||
|
return tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
batchedType,
|
||||||
|
input,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
return materializeOrComputeUnary(value, batchedType, rewriter, loc, buildExpanded);
|
||||||
|
}
|
||||||
|
|
||||||
static Value extractBatchMatrix(Value value,
|
static Value extractBatchMatrix(Value value,
|
||||||
int64_t batchIndex,
|
int64_t batchIndex,
|
||||||
int64_t batchSize,
|
int64_t batchSize,
|
||||||
@@ -130,36 +149,533 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
|
|||||||
perm = {0, 2, 1};
|
perm = {0, 2, 1};
|
||||||
}
|
}
|
||||||
|
|
||||||
auto transposeCompute =
|
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||||
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
|
||||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||||
});
|
});
|
||||||
return transposeCompute.getResult(0);
|
return transposeCompute.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
|
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
|
||||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
|
||||||
int64_t concatDimSize = 0;
|
SmallVector<OpFoldResult> highPads;
|
||||||
for (Value input : inputs)
|
highPads.reserve(sourceType.getRank());
|
||||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
|
||||||
outputShape[axis] = concatDimSize;
|
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
|
||||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
|
||||||
|
|
||||||
if (llvm::all_of(inputs, isCompileTimeComputable))
|
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
|
||||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
auto* padBlock = new Block();
|
||||||
|
for (int64_t i = 0; i < sourceType.getRank(); ++i)
|
||||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
padOp.getRegion().push_back(padBlock);
|
||||||
});
|
rewriter.setInsertionPointToStart(padBlock);
|
||||||
return concatCompute.getResult(0);
|
auto zero = getOrCreateConstant(
|
||||||
|
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
|
||||||
|
tensor::YieldOp::create(rewriter, loc, zero);
|
||||||
|
rewriter.setInsertionPointAfter(padOp);
|
||||||
|
return padOp.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
static Value createPaddedBatchedInputCompute(Value input,
|
||||||
using OpRewritePattern::OpRewritePattern;
|
RankedTensorType paddedInputType,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
|
if (inputType == paddedInputType)
|
||||||
|
return input;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
|
||||||
|
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
|
||||||
|
});
|
||||||
|
return computeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<Value> materializePaddedBatchedWeight(
|
||||||
|
Value value, int64_t sourceBatch, int64_t targetBatch, RankedTensorType resultType, PatternRewriter& rewriter) {
|
||||||
|
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||||
|
if (sourceType == resultType)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
auto denseAttr = getHostConstDenseElementsAttr(value);
|
||||||
|
if (!denseAttr)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t sourceRows = sourceType.getRank() == 2 ? sourceType.getDimSize(0) : sourceType.getDimSize(1);
|
||||||
|
const int64_t sourceCols = sourceType.getRank() == 2 ? sourceType.getDimSize(1) : sourceType.getDimSize(2);
|
||||||
|
const int64_t targetRows = resultType.getDimSize(1);
|
||||||
|
const int64_t targetCols = resultType.getDimSize(2);
|
||||||
|
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||||
|
SmallVector<Attribute> resultValues(resultType.getNumElements(), rewriter.getZeroAttr(resultType.getElementType()));
|
||||||
|
|
||||||
|
for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) {
|
||||||
|
const int64_t sourceBatchIdx = sourceType.getRank() == 2 ? 0 : (sourceBatch == 1 ? 0 : batchIdx);
|
||||||
|
const int64_t sourceBatchBase = sourceType.getRank() == 2 ? 0 : sourceBatchIdx * sourceRows * sourceCols;
|
||||||
|
const int64_t targetBatchBase = batchIdx * targetRows * targetCols;
|
||||||
|
for (int64_t row = 0; row < sourceRows; ++row)
|
||||||
|
for (int64_t col = 0; col < sourceCols; ++col)
|
||||||
|
resultValues[targetBatchBase + row * targetCols + col] = sourceValues[sourceBatchBase + row * sourceCols + col];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
|
||||||
|
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value extractBatchedATile(Value a,
|
||||||
|
int64_t sourceBatchCount,
|
||||||
|
Value batch,
|
||||||
|
Value row,
|
||||||
|
Value kOffset,
|
||||||
|
RankedTensorType aTileType,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto aSliceType = RankedTensorType::get({1, 1, aTileType.getDimSize(1)}, aTileType.getElementType());
|
||||||
|
SmallVector<OpFoldResult> offsets {
|
||||||
|
sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), row, kOffset};
|
||||||
|
SmallVector<OpFoldResult> sizes {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(aTileType.getDimSize(1))};
|
||||||
|
auto slice =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, getUnitStrides(rewriter, 3));
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
aTileType,
|
||||||
|
slice,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value extractBatchedBTile(Value b,
|
||||||
|
int64_t sourceBatchCount,
|
||||||
|
Value batch,
|
||||||
|
Value kOffset,
|
||||||
|
Value hOffset,
|
||||||
|
RankedTensorType bTileType,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto bSliceType =
|
||||||
|
RankedTensorType::get({1, bTileType.getDimSize(0), bTileType.getDimSize(1)}, bTileType.getElementType());
|
||||||
|
SmallVector<OpFoldResult> offsets {
|
||||||
|
sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), kOffset, hOffset};
|
||||||
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(bTileType.getDimSize(0)),
|
||||||
|
rewriter.getIndexAttr(bTileType.getDimSize(1))};
|
||||||
|
auto slice =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, bSliceType, b, offsets, sizes, getUnitStrides(rewriter, 3));
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
bTileType,
|
||||||
|
slice,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value getBatchLaneIndex(
|
||||||
|
Value lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices, PatternRewriter& rewriter, Location loc) {
|
||||||
|
return floorDivIndexByConstant(rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices);
|
||||||
|
}
|
||||||
|
|
||||||
|
static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
|
||||||
|
Value b,
|
||||||
|
RankedTensorType aType,
|
||||||
|
int64_t aBatchCount,
|
||||||
|
RankedTensorType bType,
|
||||||
|
int64_t bBatchCount,
|
||||||
|
RankedTensorType partialPiecesType,
|
||||||
|
int64_t numOutRows,
|
||||||
|
int64_t numKSlices,
|
||||||
|
int64_t numOutHSlices,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
||||||
|
auto batchOp = createSpatComputeBatch(
|
||||||
|
rewriter,
|
||||||
|
loc,
|
||||||
|
TypeRange {partialPiecesType},
|
||||||
|
laneCount,
|
||||||
|
ValueRange {b},
|
||||||
|
ValueRange {a},
|
||||||
|
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||||
|
Value row = modIndexByConstant(rewriter, loc, args.lane, numOutRows);
|
||||||
|
Value outerLane = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows);
|
||||||
|
Value batch = getBatchLaneIndex(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||||
|
Value sliceLane = modIndexByConstant(rewriter, loc, outerLane, numKSlices * numOutHSlices);
|
||||||
|
Value kSlice = modIndexByConstant(rewriter, loc, sliceLane, numKSlices);
|
||||||
|
Value hSlice = floorDivIndexByConstant(rewriter, loc, sliceLane, numKSlices);
|
||||||
|
Value kOffset =
|
||||||
|
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice, crossbarSize.getValue());
|
||||||
|
Value hOffset =
|
||||||
|
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
|
||||||
|
|
||||||
|
auto aTileType =
|
||||||
|
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
|
||||||
|
auto bTileType = RankedTensorType::get(
|
||||||
|
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
|
||||||
|
bType.getElementType());
|
||||||
|
auto pieceType =
|
||||||
|
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
|
||||||
|
|
||||||
|
Value aTile =
|
||||||
|
extractBatchedATile(args.inputs.front(), aBatchCount, batch, row, kOffset, aTileType, rewriter, loc);
|
||||||
|
Value bTile =
|
||||||
|
extractBatchedBTile(args.weights.front(), bBatchCount, batch, kOffset, hOffset, bTileType, rewriter, loc);
|
||||||
|
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||||
|
createParallelInsertSliceIntoBatchOutput(
|
||||||
|
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2));
|
||||||
|
});
|
||||||
|
assert(succeeded(batchOp) && "expected batched MatMul VMM construction to succeed");
|
||||||
|
return *batchOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value extractDynamicBatchedBColumn(Value matrix,
|
||||||
|
int64_t sourceBatchCount,
|
||||||
|
Value batch,
|
||||||
|
Value column,
|
||||||
|
RankedTensorType vectorType,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto columnSliceType = RankedTensorType::get({1, vectorType.getDimSize(1), 1}, vectorType.getElementType());
|
||||||
|
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
||||||
|
: OpFoldResult(batch),
|
||||||
|
rewriter.getIndexAttr(0),
|
||||||
|
column};
|
||||||
|
SmallVector<OpFoldResult> sizes {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)};
|
||||||
|
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
Value columnSlice = tensor::ExtractSliceOp::create(rewriter, loc, columnSliceType, matrix, offsets, sizes, strides);
|
||||||
|
auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType());
|
||||||
|
Value collapsed = tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
collapsedType,
|
||||||
|
columnSlice,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1, 2}
|
||||||
|
})
|
||||||
|
.getResult();
|
||||||
|
return tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
vectorType,
|
||||||
|
collapsed,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1}
|
||||||
|
})
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value extractDynamicBatchedBRow(Value matrix,
|
||||||
|
int64_t sourceBatchCount,
|
||||||
|
Value batch,
|
||||||
|
Value row,
|
||||||
|
RankedTensorType vectorType,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
|
||||||
|
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
||||||
|
: OpFoldResult(batch),
|
||||||
|
row,
|
||||||
|
rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
||||||
|
auto rowSlice =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3));
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
vectorType,
|
||||||
|
rowSlice,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value extractDynamicBatchedRowVector(Value matrix,
|
||||||
|
int64_t sourceBatchCount,
|
||||||
|
Value batch,
|
||||||
|
Value row,
|
||||||
|
RankedTensorType vectorType,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
|
||||||
|
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
||||||
|
: OpFoldResult(batch),
|
||||||
|
row,
|
||||||
|
rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
||||||
|
auto rowSlice =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3));
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
vectorType,
|
||||||
|
rowSlice,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a,
|
||||||
|
int64_t aBatchCount,
|
||||||
|
Value b,
|
||||||
|
int64_t bBatchCount,
|
||||||
|
RankedTensorType aType,
|
||||||
|
RankedTensorType bType,
|
||||||
|
RankedTensorType scalarPiecesType,
|
||||||
|
RankedTensorType outType,
|
||||||
|
bool bAlreadyTransposed,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
const int64_t numBatches = outType.getDimSize(0);
|
||||||
|
const int64_t numOutRows = outType.getDimSize(1);
|
||||||
|
const int64_t numOutCols = outType.getDimSize(2);
|
||||||
|
const int64_t reductionSize = aType.getDimSize(2);
|
||||||
|
const int64_t laneCount = numBatches * numOutRows * numOutCols;
|
||||||
|
auto batchOp = createSpatComputeBatch(
|
||||||
|
rewriter,
|
||||||
|
loc,
|
||||||
|
TypeRange {scalarPiecesType},
|
||||||
|
laneCount,
|
||||||
|
ValueRange {},
|
||||||
|
ValueRange {a, b},
|
||||||
|
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||||
|
Value batch = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows * numOutCols);
|
||||||
|
Value batchLane = modIndexByConstant(rewriter, loc, args.lane, numOutRows * numOutCols);
|
||||||
|
Value row = floorDivIndexByConstant(rewriter, loc, batchLane, numOutCols);
|
||||||
|
Value column = modIndexByConstant(rewriter, loc, batchLane, numOutCols);
|
||||||
|
|
||||||
|
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||||
|
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||||
|
Value aVector =
|
||||||
|
extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc);
|
||||||
|
Value bVector =
|
||||||
|
bAlreadyTransposed
|
||||||
|
? extractDynamicBatchedBRow(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc)
|
||||||
|
: extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
|
||||||
|
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||||
|
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
createParallelInsertSliceIntoBatchOutput(
|
||||||
|
rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, getUnitStrides(rewriter, 2));
|
||||||
|
});
|
||||||
|
assert(succeeded(batchOp) && "expected batched MatMul VVDMul construction to succeed");
|
||||||
|
return *batchOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createBatchedDynamicOutputCompute(Value scalarPieces,
|
||||||
|
RankedTensorType scalarPiecesType,
|
||||||
|
RankedTensorType outType,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
const int64_t laneCount = scalarPiecesType.getDimSize(0);
|
||||||
|
const int64_t numOutRows = outType.getDimSize(1);
|
||||||
|
const int64_t numOutCols = outType.getDimSize(2);
|
||||||
|
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||||
|
auto outputScalarType = RankedTensorType::get({1, 1, 1}, outType.getElementType());
|
||||||
|
|
||||||
|
auto computeOp =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) {
|
||||||
|
Value outputInit =
|
||||||
|
tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
|
||||||
|
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||||
|
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||||
|
Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
|
||||||
|
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
|
||||||
|
rewriter.setInsertionPointToStart(loop.getBody());
|
||||||
|
|
||||||
|
Value lane = loop.getInductionVar();
|
||||||
|
Value outputAcc = loop.getRegionIterArgs().front();
|
||||||
|
Value batch = floorDivIndexByConstant(rewriter, loc, lane, numOutRows * numOutCols);
|
||||||
|
Value batchLane = modIndexByConstant(rewriter, loc, lane, numOutRows * numOutCols);
|
||||||
|
Value row = floorDivIndexByConstant(rewriter, loc, batchLane, numOutCols);
|
||||||
|
Value column = modIndexByConstant(rewriter, loc, batchLane, numOutCols);
|
||||||
|
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
Value scalar = tensor::ExtractSliceOp::create(
|
||||||
|
rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, getUnitStrides(rewriter, 2));
|
||||||
|
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
outputScalarType,
|
||||||
|
scalar,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2}
|
||||||
|
});
|
||||||
|
SmallVector<OpFoldResult> outputOffsets {batch, row, column};
|
||||||
|
SmallVector<OpFoldResult> outputSizes {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
scf::YieldOp::create(
|
||||||
|
rewriter,
|
||||||
|
loc,
|
||||||
|
tensor::InsertSliceOp::create(
|
||||||
|
rewriter, loc, expanded, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
|
||||||
|
.getResult());
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(loop);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0));
|
||||||
|
});
|
||||||
|
return computeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) {
|
||||||
|
auto transposeCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||||
|
Value transposed = ONNXTransposeOp::create(rewriter, loc, outputType, input, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||||
|
});
|
||||||
|
return transposeCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value extractBatchedReductionPiece(Value partialPiecesArg,
|
||||||
|
Value batch,
|
||||||
|
Value hSlice,
|
||||||
|
int64_t kSlice,
|
||||||
|
RankedTensorType pieceType,
|
||||||
|
int64_t numKSlices,
|
||||||
|
int64_t numOutHSlices,
|
||||||
|
int64_t numOutRows,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
Value batchOffset = multiplyIndexByConstant(
|
||||||
|
rewriter, rewriter.getInsertionBlock()->getParentOp(), batch, numOutRows * numKSlices * numOutHSlices);
|
||||||
|
Value hOffset =
|
||||||
|
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, numKSlices * numOutRows);
|
||||||
|
Value kOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice * numOutRows);
|
||||||
|
Value batchAndHSlice = arith::AddIOp::create(rewriter, loc, batchOffset, hOffset);
|
||||||
|
Value pieceOffset = arith::AddIOp::create(rewriter, loc, batchAndHSlice, kOffset);
|
||||||
|
SmallVector<OpFoldResult> offsets {pieceOffset, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||||
|
return tensor::ExtractSliceOp::create(
|
||||||
|
rewriter, loc, pieceType, partialPiecesArg, offsets, sizes, getUnitStrides(rewriter, 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value reduceBatchedPartialPiecesForHSlice(Value partialPiecesArg,
|
||||||
|
Value batch,
|
||||||
|
Value hSlice,
|
||||||
|
RankedTensorType pieceType,
|
||||||
|
int64_t numKSlices,
|
||||||
|
int64_t numOutHSlices,
|
||||||
|
int64_t numOutRows,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
SmallVector<Value> activePieces;
|
||||||
|
activePieces.reserve(numKSlices);
|
||||||
|
for (int64_t kSlice = 0; kSlice < numKSlices; ++kSlice)
|
||||||
|
activePieces.push_back(extractBatchedReductionPiece(
|
||||||
|
partialPiecesArg, batch, hSlice, kSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc));
|
||||||
|
|
||||||
|
while (activePieces.size() > 1) {
|
||||||
|
SmallVector<Value> nextPieces;
|
||||||
|
nextPieces.reserve((activePieces.size() + 1) / 2);
|
||||||
|
for (size_t pieceIndex = 0; pieceIndex + 1 < activePieces.size(); pieceIndex += 2)
|
||||||
|
nextPieces.push_back(
|
||||||
|
spatial::SpatVAddOp::create(rewriter, loc, pieceType, activePieces[pieceIndex], activePieces[pieceIndex + 1])
|
||||||
|
.getResult());
|
||||||
|
if (activePieces.size() % 2 != 0)
|
||||||
|
nextPieces.push_back(activePieces.back());
|
||||||
|
activePieces = std::move(nextPieces);
|
||||||
|
}
|
||||||
|
|
||||||
|
return activePieces.front();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createBatchedReductionCompute(Value partialPieces,
|
||||||
|
RankedTensorType partialPiecesType,
|
||||||
|
RankedTensorType outType,
|
||||||
|
RankedTensorType paddedOutType,
|
||||||
|
int64_t numBatches,
|
||||||
|
int64_t numKSlices,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto computeOp = createSpatCompute<1>(
|
||||||
|
rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value partialPiecesArg) {
|
||||||
|
const int64_t numOutRows = outType.getDimSize(1);
|
||||||
|
const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(2), crossbarSize.getValue());
|
||||||
|
auto pieceType = RankedTensorType::get({numOutRows, static_cast<int64_t>(crossbarSize.getValue())},
|
||||||
|
partialPiecesType.getElementType());
|
||||||
|
auto outputSliceType = RankedTensorType::get({1, numOutRows, static_cast<int64_t>(crossbarSize.getValue())},
|
||||||
|
partialPiecesType.getElementType());
|
||||||
|
|
||||||
|
Value outputInit =
|
||||||
|
tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult();
|
||||||
|
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||||
|
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||||
|
Value cNumBatches = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numBatches);
|
||||||
|
Value cNumOutHSlices =
|
||||||
|
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
|
||||||
|
|
||||||
|
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cNumBatches, c1, ValueRange {outputInit});
|
||||||
|
rewriter.setInsertionPointToStart(batchLoop.getBody());
|
||||||
|
Value batch = batchLoop.getInductionVar();
|
||||||
|
Value batchAcc = batchLoop.getRegionIterArgs().front();
|
||||||
|
|
||||||
|
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cNumOutHSlices, c1, ValueRange {batchAcc});
|
||||||
|
rewriter.setInsertionPointToStart(hLoop.getBody());
|
||||||
|
Value hSlice = hLoop.getInductionVar();
|
||||||
|
Value outputAcc = hLoop.getRegionIterArgs().front();
|
||||||
|
|
||||||
|
Value reduced = reduceBatchedPartialPiecesForHSlice(
|
||||||
|
partialPiecesArg, batch, hSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc);
|
||||||
|
Value expandedReduced = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
outputSliceType,
|
||||||
|
reduced,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
Value hOffset =
|
||||||
|
multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
|
||||||
|
SmallVector<OpFoldResult> outputOffsets {batch, rewriter.getIndexAttr(0), hOffset};
|
||||||
|
SmallVector<OpFoldResult> outputSizes {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||||
|
scf::YieldOp::create(
|
||||||
|
rewriter,
|
||||||
|
loc,
|
||||||
|
tensor::InsertSliceOp::create(
|
||||||
|
rewriter, loc, expandedReduced, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
|
||||||
|
.getResult());
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(hLoop);
|
||||||
|
scf::YieldOp::create(rewriter, loc, hLoop.getResult(0));
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(batchLoop);
|
||||||
|
Value paddedOutput = batchLoop.getResult(0);
|
||||||
|
Value result = paddedOutput;
|
||||||
|
if (paddedOutType != outType) {
|
||||||
|
SmallVector<OpFoldResult> outputOffsets {
|
||||||
|
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(numBatches),
|
||||||
|
rewriter.getIndexAttr(outType.getDimSize(1)),
|
||||||
|
rewriter.getIndexAttr(outType.getDimSize(2))};
|
||||||
|
result = tensor::ExtractSliceOp::create(
|
||||||
|
rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, getUnitStrides(rewriter, 3));
|
||||||
|
}
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||||
|
});
|
||||||
|
return computeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MatMulShapeInfo {
|
||||||
|
RankedTensorType lhsType;
|
||||||
|
RankedTensorType rhsType;
|
||||||
|
RankedTensorType outType;
|
||||||
|
SmallVector<int64_t> batchShape;
|
||||||
|
int64_t lhsBatch;
|
||||||
|
int64_t rhsBatch;
|
||||||
|
int64_t batch;
|
||||||
|
int64_t m;
|
||||||
|
int64_t k;
|
||||||
|
int64_t n;
|
||||||
|
};
|
||||||
|
|
||||||
|
static FailureOr<MatMulShapeInfo> analyzeMatMulShape(ONNXMatMulOp matmulOp) {
|
||||||
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
|
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
|
||||||
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
|
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
|
||||||
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
|
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
|
||||||
@@ -176,10 +692,10 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
||||||
if (failed(batchShape))
|
if (failed(batchShape))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
||||||
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
||||||
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
|
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
|
||||||
|
|
||||||
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
|
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
|
||||||
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
|
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
|
||||||
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
|
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
|
||||||
@@ -198,30 +714,38 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return MatMulShapeInfo {lhsType, rhsType, outType, *batchShape, lhsBatch, rhsBatch, batch, m, k, n};
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||||
|
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
||||||
|
if (failed(shapeInfo) || shapeInfo->outType.getRank() != 2)
|
||||||
|
return failure();
|
||||||
|
|
||||||
Location loc = matmulOp.getLoc();
|
Location loc = matmulOp.getLoc();
|
||||||
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
|
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
|
||||||
|
|
||||||
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
|
Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
|
||||||
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
|
Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
|
||||||
int64_t lhsBatchForGemm = lhsBatch;
|
int64_t lhsBatchForGemm = shapeInfo->lhsBatch;
|
||||||
int64_t rhsBatchForGemm = rhsBatch;
|
int64_t rhsBatchForGemm = shapeInfo->rhsBatch;
|
||||||
int64_t gemmM = m;
|
int64_t gemmM = shapeInfo->m;
|
||||||
int64_t gemmK = k;
|
int64_t gemmK = shapeInfo->k;
|
||||||
int64_t gemmN = n;
|
int64_t gemmN = shapeInfo->n;
|
||||||
if (useTransposedForm) {
|
if (useTransposedForm) {
|
||||||
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
||||||
lhsBatchForGemm = rhsBatch;
|
lhsBatchForGemm = shapeInfo->rhsBatch;
|
||||||
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||||
rhsBatchForGemm = lhsBatch;
|
rhsBatchForGemm = shapeInfo->lhsBatch;
|
||||||
gemmM = n;
|
gemmM = shapeInfo->n;
|
||||||
gemmN = m;
|
gemmN = shapeInfo->m;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
|
auto gemmType = RankedTensorType::get({gemmM, gemmN}, shapeInfo->outType.getElementType());
|
||||||
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
|
|
||||||
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
|
||||||
if (outType.getRank() == 2) {
|
|
||||||
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||||
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||||
Value gemmResult = ONNXGemmOp::create(rewriter,
|
Value gemmResult = ONNXGemmOp::create(rewriter,
|
||||||
@@ -237,8 +761,9 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
.getY();
|
.getY();
|
||||||
if (useTransposedForm) {
|
if (useTransposedForm) {
|
||||||
auto transposeCompute =
|
auto transposeCompute =
|
||||||
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) {
|
createSpatCompute<1>(rewriter, loc, TypeRange {shapeInfo->outType}, {}, gemmResult, [&](Value input) {
|
||||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
|
Value transposed =
|
||||||
|
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, input, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||||
});
|
});
|
||||||
gemmResult = transposeCompute.getResult(0);
|
gemmResult = transposeCompute.getResult(0);
|
||||||
@@ -246,48 +771,115 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
rewriter.replaceOp(matmulOp, gemmResult);
|
rewriter.replaceOp(matmulOp, gemmResult);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
SmallVector<Value> batchResults;
|
struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||||
batchResults.reserve(batch);
|
using OpRewritePattern::OpRewritePattern;
|
||||||
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
|
||||||
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||||
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
||||||
Value gemmResult = ONNXGemmOp::create(rewriter,
|
if (failed(shapeInfo))
|
||||||
loc,
|
return failure();
|
||||||
gemmType,
|
if (shapeInfo->outType.getRank() == 2)
|
||||||
lhsMatrix,
|
return failure();
|
||||||
rhsMatrix,
|
|
||||||
none,
|
Location loc = matmulOp.getLoc();
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
|
||||||
rewriter.getBoolAttr(false),
|
Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
|
||||||
rewriter.getBoolAttr(false))
|
Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
|
||||||
.getY();
|
int64_t lhsBatchForGemm = shapeInfo->lhsBatch;
|
||||||
auto batchResultCompute =
|
int64_t rhsBatchForGemm = shapeInfo->rhsBatch;
|
||||||
createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) {
|
int64_t gemmM = shapeInfo->m;
|
||||||
Value resultMatrix = input;
|
int64_t gemmK = shapeInfo->k;
|
||||||
|
int64_t gemmN = shapeInfo->n;
|
||||||
if (useTransposedForm) {
|
if (useTransposedForm) {
|
||||||
resultMatrix = ONNXTransposeOp::create(rewriter,
|
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
||||||
loc,
|
lhsBatchForGemm = shapeInfo->rhsBatch;
|
||||||
RankedTensorType::get({m, n}, outType.getElementType()),
|
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||||
input,
|
rhsBatchForGemm = shapeInfo->lhsBatch;
|
||||||
rewriter.getI64ArrayAttr({1, 0}));
|
gemmM = shapeInfo->n;
|
||||||
}
|
gemmN = shapeInfo->m;
|
||||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
batchedOutType,
|
|
||||||
resultMatrix,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1},
|
|
||||||
{2}
|
|
||||||
});
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
|
||||||
});
|
|
||||||
batchResults.push_back(batchResultCompute.getResult(0));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
lhs = ensureBatchedTensor(lhs, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||||
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
|
rhs = ensureBatchedTensor(rhs, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||||
|
auto lhsBatchedType = cast<RankedTensorType>(lhs.getType());
|
||||||
|
auto rhsBatchedType = cast<RankedTensorType>(rhs.getType());
|
||||||
|
auto directOutType = RankedTensorType::get({shapeInfo->batch, gemmM, gemmN}, shapeInfo->outType.getElementType());
|
||||||
|
|
||||||
|
if (isCompileTimeComputable(rhs)) {
|
||||||
|
const int64_t numKSlices = ceilIntegerDivide(gemmK, crossbarSize.getValue());
|
||||||
|
const int64_t numOutHSlices = ceilIntegerDivide(gemmN, crossbarSize.getValue());
|
||||||
|
const int64_t paddedReductionSize = numKSlices * static_cast<int64_t>(crossbarSize.getValue());
|
||||||
|
const int64_t paddedOutCols = numOutHSlices * static_cast<int64_t>(crossbarSize.getValue());
|
||||||
|
auto paddedLhsType = RankedTensorType::get(
|
||||||
|
{lhsBatchForGemm, gemmM, paddedReductionSize}, lhsBatchedType.getElementType(), lhsBatchedType.getEncoding());
|
||||||
|
auto paddedRhsType = RankedTensorType::get({shapeInfo->batch, paddedReductionSize, paddedOutCols},
|
||||||
|
rhsBatchedType.getElementType(),
|
||||||
|
rhsBatchedType.getEncoding());
|
||||||
|
auto paddedOutType =
|
||||||
|
RankedTensorType::get({shapeInfo->batch, gemmM, paddedOutCols}, shapeInfo->outType.getElementType());
|
||||||
|
|
||||||
|
auto paddedRhs = materializePaddedBatchedWeight(rhs, rhsBatchForGemm, shapeInfo->batch, paddedRhsType, rewriter);
|
||||||
|
if (succeeded(paddedRhs)) {
|
||||||
|
Value paddedLhs = createPaddedBatchedInputCompute(lhs, paddedLhsType, rewriter, loc);
|
||||||
|
const int64_t laneCount = shapeInfo->batch * gemmM * numKSlices * numOutHSlices;
|
||||||
|
auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())},
|
||||||
|
shapeInfo->outType.getElementType());
|
||||||
|
auto batchOp = createBatchedVmmBatch(paddedLhs,
|
||||||
|
*paddedRhs,
|
||||||
|
paddedLhsType,
|
||||||
|
lhsBatchForGemm,
|
||||||
|
paddedRhsType,
|
||||||
|
rhsBatchForGemm,
|
||||||
|
partialPiecesType,
|
||||||
|
gemmM,
|
||||||
|
numKSlices,
|
||||||
|
numOutHSlices,
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
Value result = createBatchedReductionCompute(batchOp.getResult(0),
|
||||||
|
partialPiecesType,
|
||||||
|
directOutType,
|
||||||
|
paddedOutType,
|
||||||
|
shapeInfo->batch,
|
||||||
|
numKSlices,
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
if (useTransposedForm)
|
||||||
|
result = transposeBatchedOutput(
|
||||||
|
result,
|
||||||
|
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||||
|
rewriter.replaceOp(matmulOp, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const int64_t laneCount = shapeInfo->batch * gemmM * gemmN;
|
||||||
|
auto scalarPiecesType = RankedTensorType::get({laneCount, 1}, shapeInfo->outType.getElementType());
|
||||||
|
auto batchOp = createBatchedVvdmulBatch(lhs,
|
||||||
|
lhsBatchForGemm,
|
||||||
|
rhs,
|
||||||
|
rhsBatchForGemm,
|
||||||
|
lhsBatchedType,
|
||||||
|
rhsBatchedType,
|
||||||
|
scalarPiecesType,
|
||||||
|
directOutType,
|
||||||
|
false,
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
Value result =
|
||||||
|
createBatchedDynamicOutputCompute(batchOp.getResult(0), scalarPiecesType, directOutType, rewriter, loc);
|
||||||
|
if (useTransposedForm)
|
||||||
|
result = transposeBatchedOutput(
|
||||||
|
result,
|
||||||
|
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||||
rewriter.replaceOp(matmulOp, result);
|
rewriter.replaceOp(matmulOp, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -296,7 +888,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<MatMulToGemm>(ctx);
|
patterns.insert<MatMulToGemm, MatMulBatchedToSpatialComputes>(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -49,11 +49,15 @@ python3 validation/operations/gen_tests.py
|
|||||||
## MatMul
|
## MatMul
|
||||||
|
|
||||||
| Test | Directory | A input | B tensor | Output | Notes |
|
| Test | Directory | A input | B tensor | Output | Notes |
|
||||||
|------------|---------------------|---------|----------|---------|------------------------------------|
|
|---------------------|----------------------------------|----------|----------|---------|-------------------------------------------------|
|
||||||
| Basic | `matmul/basic` | [2,3] | [3,4] | [2,4] | Direct 2D MatMul rewrite path |
|
| Basic | `matmul/basic` | [2,3] | [3,4] | [2,4] | Direct 2D MatMul rewrite path |
|
||||||
| Left constant | `matmul/left_constant` | [2,3] | [3,4] | [2,4] | Constant LHS transpose rewrite path |
|
| Left constant | `matmul/left_constant` | [2,3] | [3,4] | [2,4] | Constant LHS transpose rewrite path |
|
||||||
| Dynamic | `matmul/dynamic` | [2,3] | [3,4] | [2,4] | Runtime matrix operands |
|
| Dynamic | `matmul/dynamic` | [2,3] | [3,4] | [2,4] | Runtime matrix operands |
|
||||||
| Batched 3D | `matmul/batched_3d` | [2,2,3] | [2,3,4] | [2,2,4] | Matching-batch MatMul rewrite path |
|
| Batched 3D | `matmul/batched_3d` | [2,2,3] | [2,3,4] | [2,2,4] | Matching-batch direct batched lowering |
|
||||||
|
| Batched 3D dynamic | `matmul/batched_3d_dynamic` | [2,2,3] | [2,3,4] | [2,2,4] | Batched runtime operands |
|
||||||
|
| Batched left const | `matmul/batched_left_constant` | [2,2,3] | [2,3,4] | [2,2,4] | Batched constant-LHS transpose path |
|
||||||
|
| Batched RHS broadcast | `matmul/batched_rhs_broadcast` | [2,2,3] | [3,4] | [2,2,4] | Rank-2 RHS broadcast across batch |
|
||||||
|
| Batched LHS broadcast | `matmul/batched_lhs_broadcast` | [2,3] | [2,3,4] | [2,2,4] | Rank-2 LHS broadcast across batched RHS |
|
||||||
|
|
||||||
## Gemv
|
## Gemv
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -421,6 +421,53 @@ def matmul_batched_3d():
|
|||||||
save_model(model, "matmul/batched_3d", "matmul_batched_3d.onnx")
|
save_model(model, "matmul/batched_3d", "matmul_batched_3d.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def matmul_batched_3d_dynamic():
|
||||||
|
"""Batched 3D MatMul with both operands provided at runtime."""
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 2, 3])
|
||||||
|
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3, 4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4])
|
||||||
|
node = helper.make_node("MatMul", ["A", "B"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "matmul_batched_3d_dynamic", [A, B], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "matmul/batched_3d_dynamic", "matmul_batched_3d_dynamic.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def matmul_batched_left_constant():
|
||||||
|
"""Batched 3D MatMul with constant LHS and runtime RHS."""
|
||||||
|
rng = np.random.default_rng(70)
|
||||||
|
A = numpy_helper.from_array(rng.uniform(-1, 1, (2, 2, 3)).astype(np.float32), name="A")
|
||||||
|
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3, 4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4])
|
||||||
|
node = helper.make_node("MatMul", ["A", "B"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "matmul_batched_left_constant", [B], [Y], initializer=[A])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "matmul/batched_left_constant", "matmul_batched_left_constant.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def matmul_batched_rhs_broadcast():
|
||||||
|
"""Batched 3D MatMul with 2D constant RHS broadcast across batch."""
|
||||||
|
rng = np.random.default_rng(71)
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 2, 3])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4])
|
||||||
|
B = numpy_helper.from_array(rng.uniform(-1, 1, (3, 4)).astype(np.float32), name="B")
|
||||||
|
node = helper.make_node("MatMul", ["A", "B"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "matmul_batched_rhs_broadcast", [A], [Y], initializer=[B])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "matmul/batched_rhs_broadcast", "matmul_batched_rhs_broadcast.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def matmul_batched_lhs_broadcast():
|
||||||
|
"""Batched 3D MatMul with 2D runtime LHS broadcast across batched RHS."""
|
||||||
|
rng = np.random.default_rng(72)
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4])
|
||||||
|
B = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 4)).astype(np.float32), name="B")
|
||||||
|
node = helper.make_node("MatMul", ["A", "B"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "matmul_batched_lhs_broadcast", [A], [Y], initializer=[B])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "matmul/batched_lhs_broadcast", "matmul_batched_lhs_broadcast.onnx")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Pooling tests
|
# Pooling tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -972,6 +1019,10 @@ if __name__ == "__main__":
|
|||||||
matmul_left_constant()
|
matmul_left_constant()
|
||||||
matmul_dynamic()
|
matmul_dynamic()
|
||||||
matmul_batched_3d()
|
matmul_batched_3d()
|
||||||
|
matmul_batched_3d_dynamic()
|
||||||
|
matmul_batched_left_constant()
|
||||||
|
matmul_batched_rhs_broadcast()
|
||||||
|
matmul_batched_lhs_broadcast()
|
||||||
|
|
||||||
print("\nGenerating Pooling tests:")
|
print("\nGenerating Pooling tests:")
|
||||||
maxpool_basic()
|
maxpool_basic()
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user