This commit is contained in:
@@ -19,14 +19,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Reports whether every extent in `shape` is strictly greater than zero.
///
/// Returns true for an empty shape (there is no extent to reject).
/// NOTE(review): presumably used to reject dynamic or zero-sized dimensions;
/// the encoding of dynamic extents is not visible here — confirm.
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
  for (int64_t extent : shape) {
    // A single non-positive extent is enough to reject the whole shape.
    if (extent <= 0)
      return false;
  }
  return true;
}
|
||||
|
||||
/// Computes the product of all extents in `shape`.
///
/// Returns 1 for an empty shape. No overflow checking is performed; the
/// multiplication is carried out in int64_t, front to back.
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
  int64_t elementCount = 1;
  for (int64_t extent : shape)
    elementCount *= extent;
  return elementCount;
}
|
||||
|
||||
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
||||
ArrayRef<int64_t> rhsBatchShape) {
|
||||
if (lhsBatchShape.empty())
|
||||
@@ -54,15 +46,7 @@ collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, Pa
|
||||
auto buildCollapsed = [&](Value input) -> Value {
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
||||
};
|
||||
|
||||
if (isCompileTimeComputable(value))
|
||||
return buildCollapsed(value);
|
||||
|
||||
auto collapseCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
|
||||
});
|
||||
return collapseCompute.getResult(0);
|
||||
return materializeOrComputeUnary(value, collapsedType, rewriter, loc, buildCollapsed);
|
||||
}
|
||||
|
||||
static Value
|
||||
@@ -76,12 +60,10 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt
|
||||
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
|
||||
auto expandCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
return expandCompute.getResult(0);
|
||||
auto buildExpanded = [&](Value input) -> Value {
|
||||
return tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation).getResult();
|
||||
};
|
||||
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
|
||||
}
|
||||
|
||||
static Value extractBatchMatrix(Value value,
|
||||
@@ -100,7 +82,7 @@ static Value extractBatchMatrix(Value value,
|
||||
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, 3);
|
||||
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
||||
auto buildMatrix = [&](Value input) -> Value {
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||
@@ -114,14 +96,7 @@ static Value extractBatchMatrix(Value value,
|
||||
});
|
||||
};
|
||||
|
||||
if (isCompileTimeComputable(value))
|
||||
return buildMatrix(value);
|
||||
|
||||
auto batchMatrixCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
|
||||
});
|
||||
return batchMatrixCompute.getResult(0);
|
||||
return materializeOrComputeUnary(value, matrixType, rewriter, loc, buildMatrix);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -138,18 +113,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
};
|
||||
|
||||
if (isCompileTimeComputable(value))
|
||||
return buildTranspose(value);
|
||||
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -166,10 +130,11 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
@@ -203,8 +168,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
return failure();
|
||||
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||
|| !haveStaticPositiveShape(outType.getShape()))
|
||||
if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
||||
|
||||
Reference in New Issue
Block a user