#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; namespace onnx_mlir { namespace { static bool haveStaticPositiveShape(ArrayRef shape) { return llvm::all_of(shape, [](int64_t dim) { return dim > 0; }); } static Value extractBatchMatrix(Value value, int64_t batchIndex, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) { auto type = cast(value.getType()); if (type.getRank() == 2) return value; auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType()); SmallVector offsets = { rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector sizes = { rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)}; SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides); auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType()); return tensor::CollapseShapeOp::create(rewriter, loc, matrixType, slice, SmallVector { {0, 1}, {2} }); } static bool isConstantLikeOperand(Value value) { llvm::SmallPtrSet visited; while (auto* definingOp = value.getDefiningOp()) { if (!visited.insert(definingOp).second) return false; if (definingOp->hasTrait()) return true; if (auto extractSliceOp = dyn_cast(definingOp)) { value = extractSliceOp.getSource(); continue; } if (auto expandShapeOp = dyn_cast(definingOp)) { value = expandShapeOp.getSrc(); continue; } if (auto collapseShapeOp = dyn_cast(definingOp)) { value = collapseShapeOp.getSrc(); continue; } if (auto transposeOp = dyn_cast(definingOp)) { value = transposeOp.getData(); continue; } return false; } return false; } static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) { auto type = cast(value.getType()); auto shape = type.getShape(); if (type.getRank() == 2) { auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType()); return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0})); } auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType()); return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1})); } static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) { auto type = cast(value.getType()); auto shape = type.getShape(); RankedTensorType transposedType; SmallVector perm; if (type.getRank() == 2) { transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType()); perm = {1, 0}; } else { transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType()); perm = {0, 2, 1}; } auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) { Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm)); spatial::SpatYieldOp::create(rewriter, loc, transposed); }); return 
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
  using OpRewritePattern<ONNXMatMulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      ONNXMatMulOp matmulOp, PatternRewriter &rewriter) const override {
    auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
    auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
    auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
    if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() ||
        !rhsType.hasStaticShape() || !outType.hasStaticShape())
      return failure();
    if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) ||
        (rhsType.getRank() != 2 && rhsType.getRank() != 3) ||
        (outType.getRank() != 2 && outType.getRank() != 3))
      return failure();
    if (!haveStaticPositiveShape(lhsType.getShape()) ||
        !haveStaticPositiveShape(rhsType.getShape()) ||
        !haveStaticPositiveShape(outType.getShape()))
      return failure();

    // Batch dimensions must match, or broadcast via a unit batch.
    const int64_t lhsBatch =
        lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
    const int64_t rhsBatch =
        rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
    const int64_t batch = std::max(lhsBatch, rhsBatch);
    if ((lhsBatch != 1 && lhsBatch != batch) ||
        (rhsBatch != 1 && rhsBatch != batch))
      return failure();

    const int64_t m =
        lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
    const int64_t k =
        lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
    const int64_t rhsK =
        rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
    const int64_t n =
        rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
    if (k != rhsK)
      return failure();
    if (outType.getRank() == 2) {
      if (batch != 1 || outType.getDimSize(0) != m ||
          outType.getDimSize(1) != n)
        return failure();
    } else {
      if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m ||
          outType.getDimSize(2) != n)
        return failure();
    }

    Location loc = matmulOp.getLoc();
    // Prefer the transposed form B^T * A^T when only A is constant-like,
    // so the constant-like operand feeds the Gemm RHS.
    bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) &&
                             !isConstantLikeOperand(matmulOp.getB());

    Value lhs = matmulOp.getA();
    Value rhs = matmulOp.getB();
    int64_t lhsBatchForGemm = lhsBatch;
    int64_t rhsBatchForGemm = rhsBatch;
    int64_t gemmM = m;
    int64_t gemmK = k;
    int64_t gemmN = n;
    if (useTransposedForm) {
      lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
      lhsBatchForGemm = rhsBatch;
      rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
      rhsBatchForGemm = lhsBatch;
      gemmM = n;
      gemmN = m;
    }

    auto gemmType =
        RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
    auto batchedOutType =
        RankedTensorType::get({1, m, n}, outType.getElementType());
    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

    if (outType.getRank() == 2) {
      Value lhsMatrix = extractBatchMatrix(
          lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
      Value rhsMatrix = extractBatchMatrix(
          rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
      Value gemmResult =
          ONNXGemmOp::create(rewriter, loc, gemmType, lhsMatrix, rhsMatrix,
              none, rewriter.getF32FloatAttr(1.0f),
              rewriter.getF32FloatAttr(1.0f), rewriter.getBoolAttr(false),
              rewriter.getBoolAttr(false))
              .getY();
      if (useTransposedForm)
        gemmResult = ONNXTransposeOp::create(rewriter, loc, outType,
            gemmResult, rewriter.getI64ArrayAttr({1, 0}));
      rewriter.replaceOp(matmulOp, gemmResult);
      return success();
    }

    // Rank-3 output: emit one Gemm per batch element and concatenate the
    // re-expanded 1xMxN results along the batch axis.
    SmallVector<Value> batchResults;
    batchResults.reserve(batch);
    for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
      Value lhsMatrix = extractBatchMatrix(
          lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
      Value rhsMatrix = extractBatchMatrix(
          rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
      Value gemmResult =
          ONNXGemmOp::create(rewriter, loc, gemmType, lhsMatrix, rhsMatrix,
              none, rewriter.getF32FloatAttr(1.0f),
              rewriter.getF32FloatAttr(1.0f), rewriter.getBoolAttr(false),
              rewriter.getBoolAttr(false))
              .getY();
      if (useTransposedForm)
        gemmResult = ONNXTransposeOp::create(rewriter, loc,
            RankedTensorType::get({m, n}, outType.getElementType()),
            gemmResult, rewriter.getI64ArrayAttr({1, 0}));
      batchResults.push_back(tensor::ExpandShapeOp::create(rewriter, loc,
          batchedOutType, gemmResult,
          SmallVector<ReassociationIndices>{{0, 1}, {2}}));
    }
    Value result = createSpatConcat(rewriter, loc, /*axis=*/0, batchResults);
    rewriter.replaceOp(matmulOp, result);
    return success();
  }
};

} // namespace

void populateMatMulRewritePatterns(
    RewritePatternSet &patterns, MLIRContext *ctx) {
  patterns.insert<MatMulToGemm>(ctx);
}

} // namespace onnx_mlir
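// Usage sketch (hypothetical pass body, not part of this file): the entry
// point above is intended to be called while populating a pass's pattern
// set, along the lines of
//
//   RewritePatternSet patterns(ctx);
//   populateMatMulRewritePatterns(patterns, ctx);
//   if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
//     signalPassFailure();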