reformat code
This commit is contained in:
@@ -58,8 +58,14 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value rhsSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
|
||||
Value rhsRow = tensor::CollapseShapeOp::create(
|
||||
rewriter, loc, rhsRowType, rhsSlice, SmallVector<ReassociationIndices>{{0}, {1, 2}});
|
||||
Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
rhsRowType,
|
||||
rhsSlice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
|
||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
@@ -89,10 +95,15 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
|
||||
rewriter.setInsertionPointAfter(concatComputeOp);
|
||||
Value gemmOut = concatComputeOp.getResult(0);
|
||||
Value gemmExpanded = tensor::ExpandShapeOp::create(
|
||||
rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}});
|
||||
Value result = ONNXTransposeOp::create(
|
||||
rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
gemmExpandedType,
|
||||
gemmOut,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
|
||||
Reference in New Issue
Block a user