#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include <functional>
|
|
#include <numeric>
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
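// Returns true when every dimension of `shape` is static and strictly
// positive (dynamic dimensions are encoded as negative sentinels, so they
// fail the check as well).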
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
  return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}

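// Returns the total element count of a static shape (product of all
// dimensions; an empty shape yields 1).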
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}

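// Merges the batch dimensions of the two MatMul operands. Only the trivial
// broadcast cases are supported: either one side has no batch dimensions, or
// both sides have identical batch shapes. Anything else fails the match.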
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(
    ArrayRef<int64_t> lhsBatchShape, ArrayRef<int64_t> rhsBatchShape) {
  if (lhsBatchShape.empty())
    return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
  if (rhsBatchShape.empty())
    return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
  if (!llvm::equal(lhsBatchShape, rhsBatchShape))
    return failure();
  return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
}

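// Collapses all leading batch dimensions of `value` into a single dimension,
// yielding a (batchSize x rows x cols) tensor. Rank-2 and rank-3 inputs are
// returned unchanged. The collapse is emitted directly for host-foldable
// values and wrapped in a spat compute region otherwise.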
static Value collapseBatchDims(Value value,
    int64_t batchSize,
    int64_t rows,
    int64_t cols,
    PatternRewriter& rewriter,
    Location loc) {
  auto type = cast<RankedTensorType>(value.getType());
  if (type.getRank() == 2 || type.getRank() == 3)
    return value;

  auto collapsedType =
      RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
  SmallVector<ReassociationIndices> reassociation = {
      ReassociationIndices {},
      ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
      ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
  };
  for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
    reassociation.front().push_back(dim);

  auto buildCollapsed = [&](Value input) -> Value {
    return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
  };

  if (isHostFoldableValue(value))
    return buildCollapsed(value);

  auto collapseCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
        spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
      });
  return collapseCompute.getResult(0);
}

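// Expands the single batch dimension of a (batch x rows x cols) tensor back
// into the `batchRank` leading dimensions of `outputType`. The expansion is
// always emitted inside a spat compute region.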
static Value expandBatchDims(Value value,
    RankedTensorType outputType,
    size_t batchRank,
    PatternRewriter& rewriter,
    Location loc) {
  if (cast<RankedTensorType>(value.getType()) == outputType)
    return value;

  SmallVector<ReassociationIndices> reassociation = {
      ReassociationIndices {},
      ReassociationIndices {static_cast<int64_t>(batchRank)},
      ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
  };
  for (size_t dim = 0; dim < batchRank; ++dim)
    reassociation.front().push_back(static_cast<int64_t>(dim));

  auto expandCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
        Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
        spatial::SpatYieldOp::create(rewriter, loc, expanded);
      });
  return expandCompute.getResult(0);
}

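// Extracts the 2-D matrix at `batchIndex` from a batched (batchSize x rows x
// cols) tensor. A batch size of 1 always reads slice 0, so a single matrix can
// be reused across all batch iterations. Rank-2 inputs are returned unchanged.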
static Value extractBatchMatrix(Value value,
    int64_t batchIndex,
    int64_t batchSize,
    int64_t rows,
    int64_t cols,
    PatternRewriter& rewriter,
    Location loc) {
  auto type = cast<RankedTensorType>(value.getType());
  if (type.getRank() == 2)
    return value;

  auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType());
  SmallVector<OpFoldResult> offsets = {
      rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> sizes = {
      rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
  SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
  auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
  auto buildMatrix = [&](Value input) -> Value {
    Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
    return tensor::CollapseShapeOp::create(rewriter,
        loc,
        matrixType,
        slice,
        SmallVector<ReassociationIndices> {
            {0, 1},
            {2}
        });
  };

  if (isHostFoldableValue(value))
    return buildMatrix(value);

  auto batchMatrixCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
        spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
      });
  return batchMatrixCompute.getResult(0);
}

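// Transposes the last two dimensions of a rank-2 or rank-3 tensor, emitting
// the transpose directly for host-foldable values and inside a spat compute
// region otherwise.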
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
  auto type = cast<RankedTensorType>(value.getType());
  auto shape = type.getShape();
  RankedTensorType transposedType;
  SmallVector<int64_t> perm;
  if (type.getRank() == 2) {
    transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
    perm = {1, 0};
  } else {
    transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
    perm = {0, 2, 1};
  }

  auto buildTranspose = [&](Value input) -> Value {
    return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
  };

  if (isHostFoldableValue(value))
    return buildTranspose(value);

  auto transposeCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
        spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
      });
  return transposeCompute.getResult(0);
}

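// Same as transposeLastTwoDims, but always emits the transpose inside a spat
// compute region, regardless of host foldability.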
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
  auto type = cast<RankedTensorType>(value.getType());
  auto shape = type.getShape();
  RankedTensorType transposedType;
  SmallVector<int64_t> perm;
  if (type.getRank() == 2) {
    transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
    perm = {1, 0};
  } else {
    transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
    perm = {0, 2, 1};
  }

  auto transposeCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
        Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
        spatial::SpatYieldOp::create(rewriter, loc, transposed);
      });
  return transposeCompute.getResult(0);
}

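// Concatenates `inputs` along `axis`. Host-foldable inputs are concatenated
// directly; otherwise the concatenation is wrapped in a spat compute region.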
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
  auto firstType = cast<RankedTensorType>(inputs.front().getType());
  SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
  int64_t concatDimSize = 0;
  for (Value input : inputs)
    concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
  outputShape[axis] = concatDimSize;
  auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());

  if (llvm::all_of(inputs, isHostFoldableValue))
    return createSpatConcat(rewriter, loc, axis, inputs);

  auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
    spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
  });
  return concatCompute.getResult(0);
}

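// Lowers ONNXMatMulOp with static shapes to one ONNXGemmOp per batch entry,
// concatenating and re-expanding the per-batch results to the original output
// shape. For example, A:[2,3,4,5] x B:[2,3,5,6] is collapsed to 6 batches of
// [4,5] x [5,6] GEMMs whose [1,4,6] results are concatenated and expanded back
// to [2,3,4,6].
//
// When A is host-foldable but B is not, the product is computed in transposed
// form, A*B = (B^T * A^T)^T, so the host-foldable operand feeds the GEMM as
// its right-hand input.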
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
    auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
    auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
    auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
    if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
        || !outType.hasStaticShape())
      return failure();
    if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
      return failure();
    if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
        || !haveStaticPositiveShape(outType.getShape()))
      return failure();

    SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
    SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
    auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
    if (failed(batchShape))
      return failure();
    const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
    const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
    const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);

    const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
    const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
    const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
    const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
    if (k != rhsK)
      return failure();

    if (outType.getRank() == 2) {
      if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
        return failure();
    } else {
      SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
      if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
          || outType.getDimSize(outType.getRank() - 1) != n)
        return failure();
    }

    Location loc = matmulOp.getLoc();
    bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());

    Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
    Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
    int64_t lhsBatchForGemm = lhsBatch;
    int64_t rhsBatchForGemm = rhsBatch;
    int64_t gemmM = m;
    int64_t gemmK = k;
    int64_t gemmN = n;
    if (useTransposedForm) {
      // Compute A*B as (B^T * A^T)^T so the host-foldable operand A becomes
      // the right-hand GEMM input. Transpose the already collapsed operands so
      // inputs with more than one batch dimension are handled as rank-3
      // tensors, which is all transposeLastTwoDims supports.
      Value collapsedLhs = lhs;
      lhs = transposeLastTwoDimsInCompute(rhs, rewriter, loc);
      lhsBatchForGemm = rhsBatch;
      rhs = transposeLastTwoDims(collapsedLhs, rewriter, loc);
      rhsBatchForGemm = lhsBatch;
      gemmM = n;
      gemmN = m;
    }

    auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
    auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

    if (outType.getRank() == 2) {
      Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
      Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
      Value gemmResult = ONNXGemmOp::create(rewriter,
          loc,
          gemmType,
          lhsMatrix,
          rhsMatrix,
          none,
          rewriter.getF32FloatAttr(1.0f),
          rewriter.getF32FloatAttr(1.0f),
          rewriter.getBoolAttr(false),
          rewriter.getBoolAttr(false))
                             .getY();
      if (useTransposedForm) {
        // Undo the (B^T * A^T) trick by transposing the GEMM result back to
        // the expected output shape.
        auto transposeCompute =
            createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, ValueRange {gemmResult}, [&](Value input) {
              Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
              spatial::SpatYieldOp::create(rewriter, loc, transposed);
            });
        gemmResult = transposeCompute.getResult(0);
      }
      rewriter.replaceOp(matmulOp, gemmResult);
      return success();
    }

    SmallVector<Value> batchResults;
    batchResults.reserve(batch);
    for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
      Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
      Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
      Value gemmResult = ONNXGemmOp::create(rewriter,
          loc,
          gemmType,
          lhsMatrix,
          rhsMatrix,
          none,
          rewriter.getF32FloatAttr(1.0f),
          rewriter.getF32FloatAttr(1.0f),
          rewriter.getBoolAttr(false),
          rewriter.getBoolAttr(false))
                             .getY();
      auto batchResultCompute =
          createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, ValueRange {gemmResult}, [&](Value input) {
            Value resultMatrix = input;
            if (useTransposedForm) {
              resultMatrix = ONNXTransposeOp::create(rewriter,
                  loc,
                  RankedTensorType::get({m, n}, outType.getElementType()),
                  input,
                  rewriter.getI64ArrayAttr({1, 0}));
            }
            Value expanded = tensor::ExpandShapeOp::create(rewriter,
                loc,
                batchedOutType,
                resultMatrix,
                SmallVector<ReassociationIndices> {
                    {0, 1},
                    {2}
                });
            spatial::SpatYieldOp::create(rewriter, loc, expanded);
          });
      batchResults.push_back(batchResultCompute.getResult(0));
    }

    Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
    result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
    rewriter.replaceOp(matmulOp, result);
    return success();
  }
};

} // namespace

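// Registers the MatMul-to-Gemm lowering pattern with the given pattern set.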
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<MatMulToGemm>(ctx);
}

} // namespace onnx_mlir