#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

/// Decomposes a static-shaped rank-2 x rank-3 ONNXMatMulOp
///   A: [m, k] x B: [batch, k, n] -> Y: [batch, m, n]
/// into batch*n rank-2 Gemm ops (one per output column of each batch),
/// concatenated inside a spatial.weighted_compute region, then reshaped and
/// transposed back to the original [batch, m, n] result.
///
/// NOTE: the unrolling is proportional to batch * n, so this is only suitable
/// for small static shapes.
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
  using OpRewritePattern<ONNXMatMulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp,
                                PatternRewriter &rewriter) const override {
    // Match only fully static, ranked tensors: the rewrite unrolls over
    // batch * n at compile time.
    auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
    auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
    auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
    if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() ||
        !rhsType.hasStaticShape() || !outType.hasStaticShape())
      return failure();
    if (lhsType.getRank() != 2 || rhsType.getRank() != 3 ||
        outType.getRank() != 3)
      return failure();

    const int64_t batch = rhsType.getDimSize(0);
    const int64_t k = rhsType.getDimSize(1);
    const int64_t n = rhsType.getDimSize(2);
    const int64_t m = lhsType.getDimSize(0);
    // Dimensions must agree: A:[m,k] x B:[batch,k,n] -> Y:[batch,m,n].
    if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch ||
        outType.getDimSize(1) != m || outType.getDimSize(2) != n)
      return failure();

    Location loc = matmulOp.getLoc();
    auto lhsTransposedType =
        RankedTensorType::get({k, m}, lhsType.getElementType());
    auto rhsSliceType =
        RankedTensorType::get({1, k, 1}, rhsType.getElementType());
    auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
    auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
    auto gemmOutType =
        RankedTensorType::get({batch * n, m}, outType.getElementType());
    auto gemmExpandedType =
        RankedTensorType::get({batch, n, m}, outType.getElementType());

    // Pre-transpose A ([m,k] -> [k,m]) so each per-column Gemm computes
    // (B[b][:, j])^T x A^T, i.e. one row of (A x B[b])^T.
    Value lhsTransposed = ONNXTransposeOp::create(
        rewriter, loc, lhsTransposedType, matmulOp.getA(),
        rewriter.getI64ArrayAttr({1, 0}));
    // Shared "no bias" operand for all Gemm ops below.
    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

    SmallVector<Value> gemmRows;
    gemmRows.reserve(batch * n);
    for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
      for (int64_t colIdx = 0; colIdx < n; colIdx++) {
        // Slice out column colIdx of B[batchIdx] as a [1, k, 1] tensor ...
        SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batchIdx),
                                             rewriter.getIndexAttr(0),
                                             rewriter.getIndexAttr(colIdx)};
        SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
                                           rewriter.getIndexAttr(k),
                                           rewriter.getIndexAttr(1)};
        SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
                                             rewriter.getIndexAttr(1),
                                             rewriter.getIndexAttr(1)};
        Value rhsSlice = tensor::ExtractSliceOp::create(
            rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes,
            strides);
        // ... then collapse {0}/{1,2} down to a [1, k] row vector.
        Value rhsRow = tensor::CollapseShapeOp::create(
            rewriter, loc, rhsRowType, rhsSlice,
            SmallVector<ReassociationIndices>{{0}, {1, 2}});
        // gemmRow = rhsRow x lhsTransposed : [1,k] x [k,m] -> [1,m];
        // alpha = beta = 1.0, no transposes on the Gemm itself.
        // NOTE(review): transA/transB are passed as BoolAttr here — confirm
        // this matches the ONNXGemmOp builder (some builders take si64).
        auto gemmOp = ONNXGemmOp::create(
            rewriter, loc, gemmRowType, rhsRow, lhsTransposed, none,
            rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f),
            rewriter.getBoolAttr(false), rewriter.getBoolAttr(false));
        gemmRows.push_back(gemmOp.getY());
      }
    }

    // Concatenate all [1, m] rows into [batch * n, m] inside a
    // spatial.weighted_compute region; the region block takes one argument
    // per row so the concat operates on block arguments, not on the outer
    // SSA values directly.
    auto concatComputeOp = spatial::SpatWeightedCompute::create(
        rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);
    auto *concatBlock = new Block();
    for (Value gemmRow : gemmRows)
      concatBlock->addArgument(gemmRow.getType(), loc);
    concatComputeOp.getBody().push_back(concatBlock);
    rewriter.setInsertionPointToStart(concatBlock);
    auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0,
                                             concatBlock->getArguments());
    spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
    rewriter.setInsertionPointAfter(concatComputeOp);

    // [batch * n, m] -> [batch, n, m], then transpose {0,2,1} to the final
    // [batch, m, n] result type.
    Value gemmOut = concatComputeOp.getResult(0);
    Value gemmExpanded = tensor::ExpandShapeOp::create(
        rewriter, loc, gemmExpandedType, gemmOut,
        SmallVector<ReassociationIndices>{{0, 1}, {2}});
    Value result = ONNXTransposeOp::create(
        rewriter, loc, outType, gemmExpanded,
        rewriter.getI64ArrayAttr({0, 2, 1}));
    rewriter.replaceOp(matmulOp, result);
    return success();
  }
};

} // namespace

/// Registers the rank-3 MatMul decomposition pattern into `patterns`.
void populateMatMulRewritePatterns(RewritePatternSet &patterns,
                                   MLIRContext *ctx) {
  patterns.insert<MatMulRank3ToGemm>(ctx);
}

} // namespace onnx_mlir