Different type of convolution
This commit is contained in:
@@ -950,7 +950,12 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
||||
if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector || !shapeInfo->outputBatchShape.empty())
|
||||
if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector)
|
||||
return failure();
|
||||
|
||||
const bool hasNonSingletonOutputBatch =
|
||||
!shapeInfo->outputBatchShape.empty() && getStaticShapeElementCount(shapeInfo->outputBatchShape) != 1;
|
||||
if (hasNonSingletonOutputBatch)
|
||||
return failure();
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
@@ -991,7 +996,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
gemmResult =
|
||||
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}))
|
||||
.getResult();
|
||||
rewriter.replaceOp(matmulOp, gemmResult);
|
||||
|
||||
if (shapeInfo->outputBatchShape.empty()) {
|
||||
rewriter.replaceOp(matmulOp, gemmResult);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto directOutType =
|
||||
RankedTensorType::get({1, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding());
|
||||
Value batchedResult = ensureBatchedTensor(gemmResult, /*batchSize=*/1, shapeInfo->m, shapeInfo->n, rewriter, loc);
|
||||
Value finalResult = finalizeNormalizedMatMulResult(batchedResult, directOutType, *shapeInfo, rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, finalResult);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user