Fix non-contiguous extract_slices produced by the MatMul lowering pattern
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m31s

This commit is contained in:
NiccoloN
2026-04-23 14:44:30 +02:00
parent cff929a083
commit 5545b0f672
6 changed files with 254 additions and 78 deletions

View File

@@ -1,10 +1,12 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
#include <filesystem> #include <filesystem>
@@ -137,6 +139,35 @@ bool isSpatialMvmVmmWeightUse(OpOperand& use) {
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex); return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
} }
// Returns true if every transitive use of `value` ends up as the weight
// operand of a spatial MVM/VMM op, looking through view-like ops
// (tensor.extract_slice / expand_shape / collapse_shape and ONNXTranspose).
// A value with no uses at all yields false.
bool hasOnlySpatialMvmVmmWeightUses(Value value) {
  // Cycle guard: a value already visited is treated as OK so the walk terminates.
  SmallPtrSet<Value, 8> visited;
  auto walkUses = [&](Value currentValue, auto& self) -> bool {
    if (!visited.insert(currentValue).second)
      return true;
    if (currentValue.use_empty())
      return false;
    return llvm::all_of(currentValue.getUses(), [&](OpOperand& use) {
      if (isSpatialMvmVmmWeightUse(use))
        return true;
      Operation* user = use.getOwner();
      // Follow view-like ops only when `currentValue` feeds their data/source
      // operand; the op's result must itself have only weight uses.
      if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user))
        return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
      if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user))
        return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
      if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(user))
        return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
      if (auto transposeOp = dyn_cast<ONNXTransposeOp>(user))
        return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
      return false;
    });
  };
  // Recursive lambda idiom: the lambda receives itself as its second argument.
  return walkUses(value, walkUses);
}
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) { void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
assert(root && "expected valid root op"); assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) { root->walk([&](pim::PimCoreOp coreOp) {

View File

@@ -42,6 +42,7 @@ bool hasWeightAlways(mlir::Operation* op);
void markWeightAlways(mlir::Operation* op); void markWeightAlways(mlir::Operation* op);
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use); bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback); void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);

View File

@@ -87,8 +87,7 @@ void ONNXToSpatialPass::runOnOperation() {
tensor::TensorDialect, tensor::TensorDialect,
arith::ArithDialect, arith::ArithDialect,
scf::SCFDialect>(); scf::SCFDialect>();
target.addDynamicallyLegalOp<ONNXMatMulOp>( target.addIllegalOp<ONNXMatMulOp>();
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXAddOp>(); target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXDivOp>(); target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>(); target.addIllegalOp<ONNXMulOp>();
@@ -391,11 +390,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) { funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight = if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
!constantOp->use_empty() && llvm::all_of(constantOp->getUses(), [](OpOperand& use) -> bool {
return isSpatialMvmVmmWeightUse(use);
});
if (isAlwaysWeight)
markWeightAlways(constantOp); markWeightAlways(constantOp);
}); });
} }

View File

@@ -2,6 +2,7 @@
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
@@ -14,7 +15,108 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> { static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
// Extracts the 2-D [rows x cols] matrix for batch `batchIndex` out of
// `value`. A rank-2 value is returned unchanged; a rank-3 value is sliced
// along its leading batch dimension (batch 0 when batchSize == 1, i.e. a
// broadcast operand) and collapsed into a matrix.
static Value extractBatchMatrix(Value value,
                                int64_t batchIndex,
                                int64_t batchSize,
                                int64_t rows,
                                int64_t cols,
                                PatternRewriter& rewriter,
                                Location loc) {
  auto valueType = cast<RankedTensorType>(value.getType());
  if (valueType.getRank() == 2)
    return value;

  Type elementType = valueType.getElementType();
  // A broadcast operand (batchSize == 1) always reads batch 0.
  const int64_t sliceOffset = batchSize == 1 ? 0 : batchIndex;
  auto idx = [&](int64_t v) -> OpFoldResult { return rewriter.getIndexAttr(v); };
  SmallVector<OpFoldResult> offsets {idx(sliceOffset), idx(0), idx(0)};
  SmallVector<OpFoldResult> sizes {idx(1), idx(rows), idx(cols)};
  SmallVector<OpFoldResult> strides {idx(1), idx(1), idx(1)};
  auto sliceType = RankedTensorType::get({1, rows, cols}, elementType);
  Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides);
  // Fold the unit batch dimension into the row dimension:
  // [1, rows, cols] -> [rows, cols].
  auto matrixType = RankedTensorType::get({rows, cols}, elementType);
  return tensor::CollapseShapeOp::create(
      rewriter, loc, matrixType, slice, SmallVector<ReassociationIndices> {{0, 1}, {2}});
}
// Walks backwards through view-like producers (tensor.extract_slice,
// expand_shape, collapse_shape, ONNXTranspose) and reports whether the
// chain bottoms out at a ConstantLike op. A block argument, any other
// producer kind, or a revisited op (defensive cycle guard) yields false.
static bool isConstantLikeOperand(Value value) {
  llvm::SmallPtrSet<Operation*, 8> visited;
  Value current = value;
  while (Operation* producer = current.getDefiningOp()) {
    if (!visited.insert(producer).second)
      return false;  // already seen this producer: bail out
    if (producer->hasTrait<OpTrait::ConstantLike>())
      return true;
    // Step through a view-like op to its source operand; anything else
    // ends the walk unsuccessfully.
    Value next;
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(producer))
      next = sliceOp.getSource();
    else if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(producer))
      next = expandOp.getSrc();
    else if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(producer))
      next = collapseOp.getSrc();
    else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(producer))
      next = transposeOp.getData();
    else
      return false;
    current = next;
  }
  return false;  // reached a block argument
}
// Emits an ONNXTransposeOp that swaps the last two dimensions of `value`,
// a rank-2 or rank-3 ranked tensor (rank 3 keeps the batch dimension fixed).
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
  auto tensorType = cast<RankedTensorType>(value.getType());
  ArrayRef<int64_t> dims = tensorType.getShape();
  const bool batched = tensorType.getRank() != 2;
  SmallVector<int64_t> permutation =
      batched ? SmallVector<int64_t> {0, 2, 1} : SmallVector<int64_t> {1, 0};
  // Result shape is the input shape reordered by the permutation.
  SmallVector<int64_t> newDims;
  for (int64_t axis : permutation)
    newDims.push_back(dims[axis]);
  auto resultType = RankedTensorType::get(newDims, tensorType.getElementType());
  return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
}
// Like transposeLastTwoDims, but builds the ONNXTransposeOp inside a
// spatial compute region (via createSpatCompute) and returns the compute
// op's single result. Input must be a rank-2 or rank-3 ranked tensor.
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
  auto type = cast<RankedTensorType>(value.getType());
  auto shape = type.getShape();
  RankedTensorType transposedType;
  SmallVector<int64_t> perm;
  if (type.getRank() == 2) {
    transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
    perm = {1, 0};
  }
  else {
    // Rank-3 case: keep the batch dimension, swap the trailing two.
    transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
    perm = {0, 2, 1};
  }
  auto transposeCompute =
      createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
        // Region body: transpose the block argument and yield the result.
        Value transposed =
            ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
        spatial::SpatYieldOp::create(rewriter, loc, transposed);
      });
  return transposeCompute.getResult(0);
}
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
@@ -24,80 +126,115 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape() if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape()) || !outType.hasStaticShape())
return failure(); return failure();
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3) if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|| (outType.getRank() != 2 && outType.getRank() != 3))
return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|| !haveStaticPositiveShape(outType.getShape()))
return failure(); return failure();
const int64_t batch = rhsType.getDimSize(0); const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
const int64_t k = rhsType.getDimSize(1); const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
const int64_t n = rhsType.getDimSize(2); const int64_t batch = std::max(lhsBatch, rhsBatch);
const int64_t m = lhsType.getDimSize(0);
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|| outType.getDimSize(2) != n)
return failure(); return failure();
Location loc = matmulOp.getLoc(); const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType()); const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType()); const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType()); const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType()); if (k != rhsK)
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType()); return failure();
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
Value lhsTransposed = if (outType.getRank() == 2) {
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0})); if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); return failure();
}
SmallVector<Value> gemmRows; else {
gemmRows.reserve(batch * n); if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) { return failure();
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value rhsSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
loc,
rhsRowType,
rhsSlice,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
auto gemmOp = ONNXGemmOp::create(rewriter,
loc,
gemmRowType,
rhsRow,
lhsTransposed,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
gemmRows.push_back(gemmOp.getY());
}
} }
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) { Location loc = matmulOp.getLoc();
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs); bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB());
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
});
Value gemmOut = concatComputeOp.getResult(0); Value lhs = matmulOp.getA();
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter, Value rhs = matmulOp.getB();
loc, int64_t lhsBatchForGemm = lhsBatch;
gemmExpandedType, int64_t rhsBatchForGemm = rhsBatch;
gemmOut, int64_t gemmM = m;
SmallVector<ReassociationIndices> { int64_t gemmK = k;
{0, 1}, int64_t gemmN = n;
{2} if (useTransposedForm) {
}); lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1})); lhsBatchForGemm = rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = lhsBatch;
gemmM = n;
gemmN = m;
}
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
if (outType.getRank() == 2) {
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
Value gemmResult = ONNXGemmOp::create(rewriter,
loc,
gemmType,
lhsMatrix,
rhsMatrix,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
if (useTransposedForm)
gemmResult = ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}));
rewriter.replaceOp(matmulOp, gemmResult);
return success();
}
SmallVector<Value> batchResults;
batchResults.reserve(batch);
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
Value gemmResult = ONNXGemmOp::create(rewriter,
loc,
gemmType,
lhsMatrix,
rhsMatrix,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
if (useTransposedForm)
gemmResult = ONNXTransposeOp::create(
rewriter,
loc,
RankedTensorType::get({m, n}, outType.getElementType()),
gemmResult,
rewriter.getI64ArrayAttr({1, 0}));
batchResults.push_back(tensor::ExpandShapeOp::create(rewriter,
loc,
batchedOutType,
gemmResult,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
}));
}
Value result = batchResults.size() == 1
? batchResults.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, batchResults).getResult();
rewriter.replaceOp(matmulOp, result); rewriter.replaceOp(matmulOp, result);
return success(); return success();
} }
@@ -106,7 +243,7 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
} // namespace } // namespace
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<MatMulRank3ToGemm>(ctx); patterns.insert<MatMulToGemm>(ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -419,8 +419,16 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
if (outShape[1] != static_cast<int64_t>(crossbarSize)) { if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)}; auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType()); auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain); if (outTensorOperand == vmmOp.getInput()) {
outTensorOperand.setType(newType); rewriter.setInsertionPoint(vmmOp);
auto newOutputBuffer =
tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
}
else {
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
outTensorOperand.setType(newType);
}
resultTensor.setType(newType); resultTensor.setType(newType);
IntegerAttr zeroAttr = rewriter.getIndexAttr(0); IntegerAttr zeroAttr = rewriter.getIndexAttr(0);

View File

@@ -178,8 +178,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVMMOp>( replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt); rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
return success(); return success();
} }
}; };
@@ -203,8 +205,10 @@ struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface,
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimMVMOp>( replaceOpWithNewBufferizedOp<PimMVMOp>(
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt); rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
return success(); return success();
} }
}; };