This commit is contained in:
@@ -2242,8 +2242,8 @@ static FailureOr<Value> rewriteInputKTiledConv(const ConvLoweringState& state,
|
||||
rewriter, reduceLoc, paddedRowType, paddedPatchRow, aOffsets, aSizes, unitStrides);
|
||||
SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
|
||||
Value bTile = tensor::ExtractSliceOp::create(
|
||||
rewriter, reduceLoc, weightTileType, weightArg, bOffsets, bSizes, unitStrides);
|
||||
Value bTile = extractStaticSliceOrIdentity(
|
||||
rewriter, reduceLoc, weightArg, weightTileType, bOffsets, bSizes, unitStrides);
|
||||
Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
|
||||
reduceYielded.push_back(
|
||||
spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, acc, piece).getResult());
|
||||
@@ -2912,8 +2912,13 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
|
||||
rewriter, reduceLoc, paddedRowType, paddedRow, aOffsets, aSizes, getUnitStrides(rewriter, 2));
|
||||
SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
|
||||
Value bTile = tensor::ExtractSliceOp::create(
|
||||
rewriter, reduceLoc, paddedWeightTileType, args.weights.front(), bOffsets, bSizes, getUnitStrides(rewriter, 2));
|
||||
Value bTile = extractStaticSliceOrIdentity(rewriter,
|
||||
reduceLoc,
|
||||
args.weights.front(),
|
||||
paddedWeightTileType,
|
||||
bOffsets,
|
||||
bSizes,
|
||||
getUnitStrides(rewriter, 2));
|
||||
Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
|
||||
reduceYielded.push_back(
|
||||
spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, reduceIterArgs.front(), piece).getResult());
|
||||
|
||||
Reference in New Issue
Block a user