fix reshape lowering add support for grouped-convolution lowering quieter verifier with capped error messages
This commit is contained in:
@@ -2,8 +2,12 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
@@ -19,6 +23,79 @@ static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
||||
ArrayRef<int64_t> rhsBatchShape) {
|
||||
if (lhsBatchShape.empty())
|
||||
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
|
||||
if (rhsBatchShape.empty())
|
||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
|
||||
return failure();
|
||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||
}
|
||||
|
||||
static Value collapseBatchDims(Value value,
|
||||
int64_t batchSize,
|
||||
int64_t rows,
|
||||
int64_t cols,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
if (type.getRank() == 2 || type.getRank() == 3)
|
||||
return value;
|
||||
|
||||
auto collapsedType =
|
||||
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||
SmallVector<ReassociationIndices> reassociation = {
|
||||
ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
|
||||
};
|
||||
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
||||
reassociation.front().push_back(dim);
|
||||
|
||||
auto buildCollapsed = [&](Value input) -> Value {
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildCollapsed(value);
|
||||
|
||||
auto collapseCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
|
||||
});
|
||||
return collapseCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value expandBatchDims(Value value,
|
||||
RankedTensorType outputType,
|
||||
size_t batchRank,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (cast<RankedTensorType>(value.getType()) == outputType)
|
||||
return value;
|
||||
|
||||
SmallVector<ReassociationIndices> reassociation = {
|
||||
ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
|
||||
};
|
||||
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
|
||||
auto expandCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
return expandCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value extractBatchMatrix(Value value,
|
||||
int64_t batchIndex,
|
||||
int64_t batchSize,
|
||||
@@ -62,13 +139,29 @@ static Value extractBatchMatrix(Value value,
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
auto shape = type.getShape();
|
||||
RankedTensorType transposedType;
|
||||
SmallVector<int64_t> perm;
|
||||
if (type.getRank() == 2) {
|
||||
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
|
||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
perm = {1, 0};
|
||||
}
|
||||
else {
|
||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildTranspose(value);
|
||||
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -120,24 +213,25 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||
|| !outType.hasStaticShape())
|
||||
return failure();
|
||||
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|
||||
|| (outType.getRank() != 2 && outType.getRank() != 3))
|
||||
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||
|| !haveStaticPositiveShape(outType.getShape()))
|
||||
return failure();
|
||||
|
||||
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
|
||||
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
|
||||
const int64_t batch = std::max(lhsBatch, rhsBatch);
|
||||
|
||||
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|
||||
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
||||
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
|
||||
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
||||
if (failed(batchShape))
|
||||
return failure();
|
||||
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
||||
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
||||
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
|
||||
|
||||
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
|
||||
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
|
||||
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
|
||||
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
|
||||
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
|
||||
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
|
||||
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
|
||||
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
|
||||
if (k != rhsK)
|
||||
return failure();
|
||||
|
||||
@@ -146,15 +240,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
return failure();
|
||||
}
|
||||
else {
|
||||
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
|
||||
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
|
||||
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|
||||
|| outType.getDimSize(outType.getRank() - 1) != n)
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||
|
||||
Value lhs = matmulOp.getA();
|
||||
Value rhs = matmulOp.getB();
|
||||
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
|
||||
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
|
||||
int64_t lhsBatchForGemm = lhsBatch;
|
||||
int64_t rhsBatchForGemm = rhsBatch;
|
||||
int64_t gemmM = m;
|
||||
@@ -239,6 +335,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
}
|
||||
|
||||
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user