Different type of convolution
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -950,7 +950,12 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||||
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
||||||
if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector || !shapeInfo->outputBatchShape.empty())
|
if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const bool hasNonSingletonOutputBatch =
|
||||||
|
!shapeInfo->outputBatchShape.empty() && getStaticShapeElementCount(shapeInfo->outputBatchShape) != 1;
|
||||||
|
if (hasNonSingletonOutputBatch)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Location loc = matmulOp.getLoc();
|
Location loc = matmulOp.getLoc();
|
||||||
@@ -991,7 +996,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
gemmResult =
|
gemmResult =
|
||||||
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}))
|
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}))
|
||||||
.getResult();
|
.getResult();
|
||||||
rewriter.replaceOp(matmulOp, gemmResult);
|
|
||||||
|
if (shapeInfo->outputBatchShape.empty()) {
|
||||||
|
rewriter.replaceOp(matmulOp, gemmResult);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto directOutType =
|
||||||
|
RankedTensorType::get({1, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding());
|
||||||
|
Value batchedResult = ensureBatchedTensor(gemmResult, /*batchSize=*/1, shapeInfo->m, shapeInfo->n, rewriter, loc);
|
||||||
|
Value finalResult = finalizeNormalizedMatMulResult(batchedResult, directOutType, *shapeInfo, rewriter, loc);
|
||||||
|
rewriter.replaceOp(matmulOp, finalResult);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user