Different type of convolution

This commit is contained in:
ilgeco
2026-06-18 10:59:02 +02:00
parent 4ab24eb288
commit 3a985b3675
2 changed files with 1165 additions and 2 deletions
@@ -950,7 +950,12 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
auto shapeInfo = analyzeMatMulShape(matmulOp);
if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector || !shapeInfo->outputBatchShape.empty())
if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector)
return failure();
const bool hasNonSingletonOutputBatch =
!shapeInfo->outputBatchShape.empty() && getStaticShapeElementCount(shapeInfo->outputBatchShape) != 1;
if (hasNonSingletonOutputBatch)
return failure();
Location loc = matmulOp.getLoc();
@@ -991,7 +996,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
gemmResult =
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}))
.getResult();
rewriter.replaceOp(matmulOp, gemmResult);
if (shapeInfo->outputBatchShape.empty()) {
rewriter.replaceOp(matmulOp, gemmResult);
return success();
}
auto directOutType =
RankedTensorType::get({1, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding());
Value batchedResult = ensureBatchedTensor(gemmResult, /*batchSize=*/1, shapeInfo->m, shapeInfo->n, rewriter, loc);
Value finalResult = finalizeNormalizedMatMulResult(batchedResult, directOutType, *shapeInfo, rewriter, loc);
rewriter.replaceOp(matmulOp, finalResult);
return success();
}
};