From 5545b0f672c2cb7698a1a0d3deb41e3360d97a07 Mon Sep 17 00:00:00 2001
From: NiccoloN
Date: Thu, 23 Apr 2026 14:44:30 +0200
Subject: [PATCH] fix MatMul pattern non-contiguous extract_slices

---
 src/PIM/Common/PimCommon.cpp               |  31 ++
 src/PIM/Common/PimCommon.hpp               |   1 +
 .../ONNXToSpatial/ONNXToSpatialPass.cpp    |   9 +-
 .../ONNXToSpatial/Patterns/Math/MatMul.cpp | 271 +++++++++++++-----
 .../SpatialToPim/SpatialToPimPass.cpp      |  12 +-
 .../OpBufferizationInterfaces.cpp          |   8 +-
 6 files changed, 254 insertions(+), 78 deletions(-)

diff --git a/src/PIM/Common/PimCommon.cpp b/src/PIM/Common/PimCommon.cpp
index 3a9d382..3ca1839 100644
--- a/src/PIM/Common/PimCommon.cpp
+++ b/src/PIM/Common/PimCommon.cpp
@@ -1,10 +1,12 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/raw_os_ostream.h"
 
 #include <cassert>
@@ -137,6 +139,35 @@ bool isSpatialMvmVmmWeightUse(OpOperand& use) {
   return hasMvmVmmWeightUse(computeOp, operandIndex);
 }
 
+bool hasOnlySpatialMvmVmmWeightUses(Value value) {
+  SmallPtrSet<Value, 8> visited;
+  auto walkUses = [&](Value currentValue, auto& self) -> bool {
+    if (!visited.insert(currentValue).second)
+      return true;
+    if (currentValue.use_empty())
+      return false;
+
+    return llvm::all_of(currentValue.getUses(), [&](OpOperand& use) {
+      if (isSpatialMvmVmmWeightUse(use))
+        return true;
+
+      Operation* user = use.getOwner();
+      if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user))
+        return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
+      if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user))
+        return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
+      if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(user))
+        return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
+      if (auto transposeOp = dyn_cast<ONNXTransposeOp>(user))
+        return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
+
+      return false;
+    });
+  };
+
+  return walkUses(value, walkUses);
+}
+
 void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
   assert(root && "expected valid root op");
   root->walk([&](pim::PimCoreOp coreOp) {
diff --git a/src/PIM/Common/PimCommon.hpp b/src/PIM/Common/PimCommon.hpp
index 7b42d35..d65ad8c 100644
--- a/src/PIM/Common/PimCommon.hpp
+++ b/src/PIM/Common/PimCommon.hpp
@@ -42,6 +42,7 @@ bool hasWeightAlways(mlir::Operation* op);
 void markWeightAlways(mlir::Operation* op);
 
 bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
+bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
 
 void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
 
diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
index 71b7619..28e7078 100644
--- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
@@ -87,8 +87,7 @@ void ONNXToSpatialPass::runOnOperation() {
       tensor::TensorDialect,
       arith::ArithDialect,
       scf::SCFDialect>();
-  target.addDynamicallyLegalOp<ONNXMatMulOp>(
-      [](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
+  target.addIllegalOp<ONNXMatMulOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
@@ -391,11 +390,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
 
 void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
   funcOp.walk([&](arith::ConstantOp constantOp) {
-    bool isAlwaysWeight =
-        !constantOp->use_empty() && llvm::all_of(constantOp->getUses(), [](OpOperand& use) -> bool {
-          return isSpatialMvmVmmWeightUse(use);
-        });
-    if (isAlwaysWeight)
+    if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
       markWeightAlways(constantOp);
   });
 }
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp
index fa5bb20..5bf5801 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp
@@ -2,6 +2,7 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
 
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
@@ -14,7 +15,108 @@ using namespace mlir;
 namespace onnx_mlir {
 namespace {
 
-struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
+static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
+  return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
+}
+
+static Value extractBatchMatrix(Value value,
+    int64_t batchIndex,
+    int64_t batchSize,
+    int64_t rows,
+    int64_t cols,
+    PatternRewriter& rewriter,
+    Location loc) {
+  auto type = cast<RankedTensorType>(value.getType());
+  if (type.getRank() == 2)
+    return value;
+
+  auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType());
+  SmallVector<OpFoldResult> offsets = {
+      rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
+  SmallVector<OpFoldResult> sizes = {
+      rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
+  SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
+  Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides);
+
+  auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
+  return tensor::CollapseShapeOp::create(rewriter,
+      loc,
+      matrixType,
+      slice,
+      SmallVector<ReassociationIndices> {
+          {0, 1},
+          {2}
+      });
+}
+
+static bool isConstantLikeOperand(Value value) {
+  llvm::SmallPtrSet<Operation*, 8> visited;
+
+  while (auto* definingOp = value.getDefiningOp()) {
+    if (!visited.insert(definingOp).second)
+      return false;
+    if (definingOp->hasTrait<OpTrait::ConstantLike>())
+      return true;
+
+    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
+      value = extractSliceOp.getSource();
+      continue;
+    }
+    if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
+      value = expandShapeOp.getSrc();
+      continue;
+    }
+    if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
+      value = collapseShapeOp.getSrc();
+      continue;
+    }
+    if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
+      value = transposeOp.getData();
+      continue;
+    }
+
+    return false;
+  }
+
+  return false;
+}
+
+static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
+  auto type = cast<RankedTensorType>(value.getType());
+  auto shape = type.getShape();
+  if (type.getRank() == 2) {
+    auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
+    return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
+  }
+
+  auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
+  return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
+}
+
+static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
+  auto type = cast<RankedTensorType>(value.getType());
+  auto shape = type.getShape();
+  RankedTensorType transposedType;
+  SmallVector<int64_t> perm;
+  if (type.getRank() == 2) {
+    transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
+    perm = {1, 0};
+  }
+  else {
+    transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
+    perm = {0, 2, 1};
+  }
+
+  auto transposeCompute =
+      createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
+        Value transposed =
+            ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
+        spatial::SpatYieldOp::create(rewriter, loc, transposed);
+      });
+  return transposeCompute.getResult(0);
+}
+
+struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
@@ -24,80 +126,115 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
     if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
         || !outType.hasStaticShape())
       return failure();
-    if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
+    if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
+        || (outType.getRank() != 2 && outType.getRank() != 3))
+      return failure();
+    if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
+        || !haveStaticPositiveShape(outType.getShape()))
       return failure();
 
-    const int64_t batch = rhsType.getDimSize(0);
-    const int64_t k = rhsType.getDimSize(1);
-    const int64_t n = rhsType.getDimSize(2);
-    const int64_t m = lhsType.getDimSize(0);
-    if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
-        || outType.getDimSize(2) != n)
+    const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
+    const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
+    const int64_t batch = std::max(lhsBatch, rhsBatch);
+
+    if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
       return failure();
 
-    Location loc = matmulOp.getLoc();
-    auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
-    auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
-    auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
-    auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
-    auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
-    auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
+    const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
+    const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
+    const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
+    const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
+    if (k != rhsK)
+      return failure();
 
-    Value lhsTransposed =
-        ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
-    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
-
-    SmallVector<Value> gemmRows;
-    gemmRows.reserve(batch * n);
-    for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
-      for (int64_t colIdx = 0; colIdx < n; colIdx++) {
-        SmallVector<OpFoldResult> offsets = {
-            rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
-        SmallVector<OpFoldResult> sizes = {
-            rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
-        SmallVector<OpFoldResult> strides = {
-            rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
-        Value rhsSlice =
-            tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
-        Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
-            loc,
-            rhsRowType,
-            rhsSlice,
-            SmallVector<ReassociationIndices> {
-                {0},
-                {1, 2}
-            });
-
-        auto gemmOp = ONNXGemmOp::create(rewriter,
-            loc,
-            gemmRowType,
-            rhsRow,
-            lhsTransposed,
-            none,
-            rewriter.getF32FloatAttr(1.0f),
-            rewriter.getF32FloatAttr(1.0f),
-            rewriter.getBoolAttr(false),
-            rewriter.getBoolAttr(false));
-        gemmRows.push_back(gemmOp.getY());
-      }
+    if (outType.getRank() == 2) {
+      if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
+        return failure();
+    }
+    else {
+      if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
+        return failure();
     }
 
-    auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) {
-      auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs);
-      spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
-    });
+    Location loc = matmulOp.getLoc();
+    bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB());
 
-    Value gemmOut = concatComputeOp.getResult(0);
-    Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
-        loc,
-        gemmExpandedType,
-        gemmOut,
-        SmallVector<ReassociationIndices> {
-            {0, 1},
-            {2}
-        });
-    Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
+    Value lhs = matmulOp.getA();
+    Value rhs = matmulOp.getB();
+    int64_t lhsBatchForGemm = lhsBatch;
+    int64_t rhsBatchForGemm = rhsBatch;
+    int64_t gemmM = m;
+    int64_t gemmK = k;
+    int64_t gemmN = n;
+    if (useTransposedForm) {
+      lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
+      lhsBatchForGemm = rhsBatch;
+      rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
+      rhsBatchForGemm = lhsBatch;
+      gemmM = n;
+      gemmN = m;
+    }
+    auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
+    auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
+    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
+
+    if (outType.getRank() == 2) {
+      Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
+      Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
+      Value gemmResult = ONNXGemmOp::create(rewriter,
+          loc,
+          gemmType,
+          lhsMatrix,
+          rhsMatrix,
+          none,
+          rewriter.getF32FloatAttr(1.0f),
+          rewriter.getF32FloatAttr(1.0f),
+          rewriter.getBoolAttr(false),
+          rewriter.getBoolAttr(false))
+          .getY();
+      if (useTransposedForm)
+        gemmResult =
+            ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}));
+      rewriter.replaceOp(matmulOp, gemmResult);
+      return success();
+    }
+
+    SmallVector<Value> batchResults;
+    batchResults.reserve(batch);
+    for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
+      Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
+      Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
+      Value gemmResult = ONNXGemmOp::create(rewriter,
+          loc,
+          gemmType,
+          lhsMatrix,
+          rhsMatrix,
+          none,
+          rewriter.getF32FloatAttr(1.0f),
+          rewriter.getF32FloatAttr(1.0f),
+          rewriter.getBoolAttr(false),
+          rewriter.getBoolAttr(false))
+          .getY();
+      if (useTransposedForm)
+        gemmResult = ONNXTransposeOp::create(
+            rewriter,
+            loc,
+            RankedTensorType::get({m, n}, outType.getElementType()),
+            gemmResult,
+            rewriter.getI64ArrayAttr({1, 0}));
+      batchResults.push_back(tensor::ExpandShapeOp::create(rewriter,
+          loc,
+          batchedOutType,
+          gemmResult,
+          SmallVector<ReassociationIndices> {
+              {0, 1},
+              {2}
+          }));
+    }
+
+    Value result = batchResults.size() == 1
+        ? batchResults.front()
+        : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, batchResults).getResult();
     rewriter.replaceOp(matmulOp, result);
     return success();
   }
@@ -106,7 +243,7 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
 } // namespace
 
 void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
-  patterns.insert<MatMulRank3ToGemm>(ctx);
+  patterns.insert<MatMulToGemm>(ctx);
 }
 
 } // namespace onnx_mlir
diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
index ecc5d0d..ebdb429 100644
--- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
+++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
@@ -419,8 +419,16 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
     if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
       auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
       auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
-      enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
-      outTensorOperand.setType(newType);
+      if (outTensorOperand == vmmOp.getInput()) {
+        rewriter.setInsertionPoint(vmmOp);
+        auto newOutputBuffer =
+            tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
+        vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
+      }
+      else {
+        enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
+        outTensorOperand.setType(newType);
+      }
       resultTensor.setType(newType);
 
       IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp
index e217cd6..e3445f8 100644
--- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp
+++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp
@@ -178,8 +178,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface, pim::VMMOp> {
+    Value contiguousInput = <...>(*inputOpt, op->getLoc(), rewriter);
+
     replaceOpWithNewBufferizedOp<pim::VMMOp>(
-        rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
+        rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
     return success();
   }
 };
@@ -203,8 +205,10 @@ struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface, pim::MVMOp> {
+    Value contiguousInput = <...>(*inputOpt, op->getLoc(), rewriter);
+
     replaceOpWithNewBufferizedOp<pim::MVMOp>(
-        rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
+        rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
     return success();
   }
 };
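
Reviewer aid, not part of the patch: the standalone plain-C++ sketch below models the semantics the new MatMulToGemm pattern lowers to — one 2-D GEMM per batch element, a batch-1 operand broadcast by always reading slice 0 (the `batchSize == 1 ? 0 : batchIndex` offset in extractBatchMatrix), and the transposed form A*B == transpose(B^T * A^T) used when A is constant-like so the constant stays on the weight side of each GEMM. All names in the sketch (gemm, batchedMatMul, ...) are illustrative and do not exist in this repository.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Plain 2-D GEMM: C(m x n) = A(m x k) * B(k x n).
static Matrix gemm(const Matrix& a, const Matrix& b) {
  size_t m = a.size(), k = b.size(), n = b[0].size();
  Matrix c(m, std::vector<float>(n, 0.0f));
  for (size_t i = 0; i < m; ++i)
    for (size_t p = 0; p < k; ++p)
      for (size_t j = 0; j < n; ++j)
        c[i][j] += a[i][p] * b[p][j];
  return c;
}

static Matrix transposed(const Matrix& a) {
  Matrix t(a[0].size(), std::vector<float>(a.size()));
  for (size_t i = 0; i < a.size(); ++i)
    for (size_t j = 0; j < a[0].size(); ++j)
      t[j][i] = a[i][j];
  return t;
}

// Batched matmul the way the pattern emits it: one GEMM per batch element,
// broadcasting a batch-1 operand by reusing its only slice.
static std::vector<Matrix> batchedMatMul(const std::vector<Matrix>& lhs,
                                         const std::vector<Matrix>& rhs,
                                         bool useTransposedForm) {
  size_t batch = std::max(lhs.size(), rhs.size());
  std::vector<Matrix> out;
  for (size_t b = 0; b < batch; ++b) {
    const Matrix& a = lhs[lhs.size() == 1 ? 0 : b]; // mirrors extractBatchMatrix
    const Matrix& w = rhs[rhs.size() == 1 ? 0 : b];
    // Transposed form keeps a constant-like A on the GEMM weight side:
    // A*B == (B^T * A^T)^T, so both branches compute the same product.
    out.push_back(useTransposedForm ? transposed(gemm(transposed(w), transposed(a)))
                                    : gemm(a, w));
  }
  return out;
}

int main() {
  std::vector<Matrix> lhs = {{{1, 2}, {3, 4}}};                   // 1x2x2, broadcast over batch
  std::vector<Matrix> rhs = {{{1, 0}, {0, 1}}, {{2, 0}, {0, 2}}}; // 2x2x2
  auto direct = batchedMatMul(lhs, rhs, false);
  auto viaTranspose = batchedMatMul(lhs, rhs, true);
  assert(direct == viaTranspose);       // both forms agree
  std::cout << direct[1][1][0] << "\n"; // prints 6 (= 3 * 2)
  return 0;
}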