#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include #include #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; namespace onnx_mlir { namespace { static bool haveStaticPositiveShape(ArrayRef shape) { return llvm::all_of(shape, [](int64_t dim) { return dim > 0; }); } static int64_t getStaticShapeElementCount(ArrayRef shape) { return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies {}); } static FailureOr> inferSupportedBatchShape(ArrayRef lhsBatchShape, ArrayRef rhsBatchShape) { if (lhsBatchShape.empty()) return SmallVector(rhsBatchShape.begin(), rhsBatchShape.end()); if (rhsBatchShape.empty()) return SmallVector(lhsBatchShape.begin(), lhsBatchShape.end()); if (!llvm::equal(lhsBatchShape, rhsBatchShape)) return failure(); return SmallVector(lhsBatchShape.begin(), lhsBatchShape.end()); } static Value collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) { auto type = cast(value.getType()); if (type.getRank() == 2 || type.getRank() == 3) return value; auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding()); SmallVector reassociation = { ReassociationIndices {}, ReassociationIndices {static_cast(type.getRank() - 2)}, ReassociationIndices {static_cast(type.getRank() - 1)} }; for (int64_t dim = 0; dim < type.getRank() - 2; ++dim) reassociation.front().push_back(dim); auto buildCollapsed = [&](Value input) -> Value { return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation); }; if (isHostFoldableValue(value)) return buildCollapsed(value); auto collapseCompute = createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) { spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input)); }); return collapseCompute.getResult(0); } static Value expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) { if (cast(value.getType()) == outputType) return value; SmallVector reassociation = { ReassociationIndices {}, ReassociationIndices {static_cast(batchRank)}, ReassociationIndices {static_cast(batchRank + 1)} }; for (size_t dim = 0; dim < batchRank; ++dim) reassociation.front().push_back(static_cast(dim)); auto expandCompute = createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) { Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation); spatial::SpatYieldOp::create(rewriter, loc, expanded); }); return expandCompute.getResult(0); } static Value extractBatchMatrix(Value value, int64_t batchIndex, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) { auto type = cast(value.getType()); if (type.getRank() == 2) return value; auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType()); SmallVector offsets = { rewriter.getIndexAttr(batchSize == 1 ? 
static Value extractBatchMatrix(Value value, int64_t batchIndex,
    int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter,
    Location loc) {
  auto type = cast<RankedTensorType>(value.getType());
  if (type.getRank() == 2)
    return value;
  auto sliceType =
      RankedTensorType::get({1, rows, cols}, type.getElementType());
  SmallVector<OpFoldResult> offsets = {
      rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex),
      rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
      rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
  SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
      rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
  auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
  auto buildMatrix = [&](Value input) -> Value {
    Value slice = tensor::ExtractSliceOp::create(
        rewriter, loc, sliceType, input, offsets, sizes, strides);
    return tensor::CollapseShapeOp::create(rewriter, loc, matrixType, slice,
        SmallVector<ReassociationIndices>{{0, 1}, {2}});
  };
  if (isHostFoldableValue(value))
    return buildMatrix(value);
  auto batchMatrixCompute = createSpatCompute<1>(rewriter, loc,
      TypeRange{matrixType}, {}, ValueRange{value}, [&](Value input) {
        spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
      });
  return batchMatrixCompute.getResult(0);
}

// Transposes the two innermost dimensions, folding on the host when the
// input is host-foldable and emitting a spatial compute region otherwise.
static Value transposeLastTwoDims(
    Value value, PatternRewriter& rewriter, Location loc) {
  auto type = cast<RankedTensorType>(value.getType());
  auto shape = type.getShape();
  RankedTensorType transposedType;
  SmallVector<int64_t> perm;
  if (type.getRank() == 2) {
    transposedType =
        RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
    perm = {1, 0};
  } else {
    transposedType = RankedTensorType::get(
        {shape[0], shape[2], shape[1]}, type.getElementType());
    perm = {0, 2, 1};
  }
  auto buildTranspose = [&](Value input) -> Value {
    return ONNXTransposeOp::create(
        rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
  };
  if (isHostFoldableValue(value))
    return buildTranspose(value);
  auto transposeCompute = createSpatCompute<1>(rewriter, loc,
      TypeRange{transposedType}, {}, ValueRange{value}, [&](Value input) {
        spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
      });
  return transposeCompute.getResult(0);
}

// Like transposeLastTwoDims, but always materializes the transpose inside a
// spatial compute region, even for host-foldable inputs.
static Value transposeLastTwoDimsInCompute(
    Value value, PatternRewriter& rewriter, Location loc) {
  auto type = cast<RankedTensorType>(value.getType());
  auto shape = type.getShape();
  RankedTensorType transposedType;
  SmallVector<int64_t> perm;
  if (type.getRank() == 2) {
    transposedType =
        RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
    perm = {1, 0};
  } else {
    transposedType = RankedTensorType::get(
        {shape[0], shape[2], shape[1]}, type.getElementType());
    perm = {0, 2, 1};
  }
  auto transposeCompute = createSpatCompute<1>(rewriter, loc,
      TypeRange{transposedType}, {}, ValueRange{value}, [&](Value input) {
        Value transposed = ONNXTransposeOp::create(rewriter, loc,
            transposedType, input, rewriter.getI64ArrayAttr(perm));
        spatial::SpatYieldOp::create(rewriter, loc, transposed);
      });
  return transposeCompute.getResult(0);
}

// Concatenates `inputs` along `axis`, on the host when every input is
// host-foldable and inside a spatial compute region otherwise.
static Value concatValues(
    ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
  auto firstType = cast<RankedTensorType>(inputs.front().getType());
  SmallVector<int64_t> outputShape(
      firstType.getShape().begin(), firstType.getShape().end());
  int64_t concatDimSize = 0;
  for (Value input : inputs)
    concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
  outputShape[axis] = concatDimSize;
  auto resultType = RankedTensorType::get(
      outputShape, firstType.getElementType(), firstType.getEncoding());
  if (llvm::all_of(inputs, isHostFoldableValue))
    return createSpatConcat(rewriter, loc, axis, inputs);
  auto concatCompute = createSpatCompute(rewriter, loc, TypeRange{resultType},
      {}, inputs, [&](ValueRange args) {
        spatial::SpatYieldOp::create(
            rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
      });
  return concatCompute.getResult(0);
}

// Rewrites ONNXMatMulOp into one ONNXGemmOp per (collapsed) batch element.
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
  using OpRewritePattern<ONNXMatMulOp>::OpRewritePattern;
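  // Matches MatMuls with fully static, positive shapes whose batch
  // dimensions either match exactly or are absent on one side. When A is
  // host-foldable but B is not, A*B is computed as (B^T * A^T)^T so that the
  // host-foldable operand becomes the Gemm's right-hand operand.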
  LogicalResult matchAndRewrite(
      ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
    auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
    auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
    auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
    if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() ||
        !rhsType.hasStaticShape() || !outType.hasStaticShape())
      return failure();
    if (lhsType.getRank() < 2 || rhsType.getRank() < 2 ||
        outType.getRank() < 2)
      return failure();
    if (!haveStaticPositiveShape(lhsType.getShape()) ||
        !haveStaticPositiveShape(rhsType.getShape()) ||
        !haveStaticPositiveShape(outType.getShape()))
      return failure();
    SmallVector<int64_t> lhsBatchShape(
        lhsType.getShape().begin(), lhsType.getShape().end() - 2);
    SmallVector<int64_t> rhsBatchShape(
        rhsType.getShape().begin(), rhsType.getShape().end() - 2);
    auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
    if (failed(batchShape))
      return failure();
    const int64_t lhsBatch =
        lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
    const int64_t rhsBatch =
        rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
    const int64_t batch =
        batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
    const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
    const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
    const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
    const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
    if (k != rhsK)
      return failure();
    if (outType.getRank() == 2) {
      if (batch != 1 || outType.getDimSize(0) != m ||
          outType.getDimSize(1) != n)
        return failure();
    } else {
      SmallVector<int64_t> outBatchShape(
          outType.getShape().begin(), outType.getShape().end() - 2);
      if (!llvm::equal(outBatchShape, *batchShape) ||
          outType.getDimSize(outType.getRank() - 2) != m ||
          outType.getDimSize(outType.getRank() - 1) != n)
        return failure();
    }
    Location loc = matmulOp.getLoc();
    bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) &&
                             !isHostFoldableValue(matmulOp.getB());
    Value lhs =
        collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
    Value rhs =
        collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
    int64_t lhsBatchForGemm = lhsBatch;
    int64_t rhsBatchForGemm = rhsBatch;
    int64_t gemmM = m;
    int64_t gemmK = k;
    int64_t gemmN = n;
    if (useTransposedForm) {
      // Swap to the transposed form: lhs = B^T, rhs = A^T, result = (A*B)^T.
      lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
      lhsBatchForGemm = rhsBatch;
      rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
      rhsBatchForGemm = lhsBatch;
      gemmM = n;
      gemmN = m;
    }
    auto gemmType =
        RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
    auto batchedOutType =
        RankedTensorType::get({1, m, n}, outType.getElementType());
    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
    if (outType.getRank() == 2) {
      Value lhsMatrix = extractBatchMatrix(
          lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
      Value rhsMatrix = extractBatchMatrix(
          rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
      Value gemmResult =
          ONNXGemmOp::create(rewriter, loc, gemmType, lhsMatrix, rhsMatrix,
              none, rewriter.getF32FloatAttr(1.0f),
              rewriter.getF32FloatAttr(1.0f), rewriter.getBoolAttr(false),
              rewriter.getBoolAttr(false))
              .getY();
      if (useTransposedForm) {
        // Undo the transposed form to recover A*B.
        auto transposeCompute = createSpatCompute<1>(rewriter, loc,
            TypeRange{outType}, {}, gemmResult, [&](Value input) {
              Value transposed = ONNXTransposeOp::create(rewriter, loc,
                  outType, input, rewriter.getI64ArrayAttr({1, 0}));
              spatial::SpatYieldOp::create(rewriter, loc, transposed);
            });
        gemmResult = transposeCompute.getResult(0);
      }
      rewriter.replaceOp(matmulOp, gemmResult);
      return success();
    }
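    // Batched case: emit one Gemm per flattened batch index, expand each
    // result back to [1, m, n], concatenate along the batch axis, and
    // restore the original batch dimensions.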
    SmallVector<Value> batchResults;
    batchResults.reserve(batch);
    for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
      Value lhsMatrix = extractBatchMatrix(
          lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
      Value rhsMatrix = extractBatchMatrix(
          rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
      Value gemmResult =
          ONNXGemmOp::create(rewriter, loc, gemmType, lhsMatrix, rhsMatrix,
              none, rewriter.getF32FloatAttr(1.0f),
              rewriter.getF32FloatAttr(1.0f), rewriter.getBoolAttr(false),
              rewriter.getBoolAttr(false))
              .getY();
      auto batchResultCompute = createSpatCompute<1>(rewriter, loc,
          TypeRange{batchedOutType}, {}, gemmResult, [&](Value input) {
            Value resultMatrix = input;
            if (useTransposedForm) {
              resultMatrix = ONNXTransposeOp::create(rewriter, loc,
                  RankedTensorType::get({m, n}, outType.getElementType()),
                  input, rewriter.getI64ArrayAttr({1, 0}));
            }
            Value expanded = tensor::ExpandShapeOp::create(rewriter, loc,
                batchedOutType, resultMatrix,
                SmallVector<ReassociationIndices>{{0, 1}, {2}});
            spatial::SpatYieldOp::create(rewriter, loc, expanded);
          });
      batchResults.push_back(batchResultCompute.getResult(0));
    }
    Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
    result =
        expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
    rewriter.replaceOp(matmulOp, result);
    return success();
  }
};

} // namespace

void populateMatMulRewritePatterns(
    RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<MatMulToGemm>(ctx);
}

} // namespace onnx_mlir