compact spatial IR through different new operations and dedicated syntax
fast spatial node merging with batch operations
This commit is contained in:
@@ -147,161 +147,148 @@ static Value buildPackedBias(bool hasBias,
|
||||
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
||||
}
|
||||
|
||||
static SmallVector<Value> createIm2colRowComputes(Value x,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType im2colType,
|
||||
RankedTensorType im2colRowType,
|
||||
RankedTensorType gemmInputRowType,
|
||||
int64_t batchSize,
|
||||
int64_t numChannelsIn,
|
||||
int64_t xHeight,
|
||||
int64_t xWidth,
|
||||
int64_t wHeight,
|
||||
int64_t wWidth,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
int64_t outWidth,
|
||||
int64_t patchSize,
|
||||
int64_t numPatches,
|
||||
int64_t numPatchesPerBatch,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value createIm2colRowComputes(Value x,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType im2colType,
|
||||
RankedTensorType im2colRowType,
|
||||
RankedTensorType gemmInputRowsType,
|
||||
int64_t batchSize,
|
||||
int64_t numChannelsIn,
|
||||
int64_t xHeight,
|
||||
int64_t xWidth,
|
||||
int64_t wHeight,
|
||||
int64_t wWidth,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
int64_t outWidth,
|
||||
int64_t patchSize,
|
||||
int64_t numPatches,
|
||||
int64_t numPatchesPerBatch,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto elemType = xType.getElementType();
|
||||
constexpr size_t numInputs = 1;
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||
SmallVector<Type> resultTypes(packedNumRows, gemmInputRowType);
|
||||
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, resultTypes, {}, x, [&](Value xArg) {
|
||||
Value paddedInput = xArg;
|
||||
auto im2colComputeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
|
||||
Value paddedInput = xArg;
|
||||
|
||||
// Pad input with zeros if needed:
|
||||
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
|
||||
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
||||
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
||||
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
||||
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padHeightBegin),
|
||||
rewriter.getIndexAttr(padWidthBegin)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padHeightEnd),
|
||||
rewriter.getIndexAttr(padWidthEnd)};
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int i = 0; i < 4; i++)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
paddedInput = padOp.getResult();
|
||||
}
|
||||
// Pad input with zeros if needed:
|
||||
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
|
||||
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
||||
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
||||
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
||||
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padHeightBegin),
|
||||
rewriter.getIndexAttr(padWidthBegin)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padHeightEnd),
|
||||
rewriter.getIndexAttr(padWidthEnd)};
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int i = 0; i < 4; i++)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
paddedInput = padOp.getResult();
|
||||
}
|
||||
|
||||
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
|
||||
// until the late PIM unrolling step.
|
||||
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
|
||||
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
|
||||
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
|
||||
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
|
||||
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
|
||||
// until the late PIM unrolling step.
|
||||
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
|
||||
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
|
||||
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
|
||||
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
|
||||
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
|
||||
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
|
||||
rewriter.setInsertionPointToStart(im2colLoop.getBody());
|
||||
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
|
||||
rewriter.setInsertionPointToStart(im2colLoop.getBody());
|
||||
|
||||
Value patchIndex = im2colLoop.getInductionVar();
|
||||
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
|
||||
Value patchIndex = im2colLoop.getInductionVar();
|
||||
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
|
||||
|
||||
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
||||
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
||||
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
||||
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numChannelsIn),
|
||||
rewriter.getIndexAttr(wHeight),
|
||||
rewriter.getIndexAttr(wWidth)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(dilationHeight),
|
||||
rewriter.getIndexAttr(dilationWidth)};
|
||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numChannelsIn),
|
||||
rewriter.getIndexAttr(wHeight),
|
||||
rewriter.getIndexAttr(wWidth)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(dilationHeight),
|
||||
rewriter.getIndexAttr(dilationWidth)};
|
||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
im2colRowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
im2colRowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
|
||||
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
|
||||
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value updatedIm2col =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
|
||||
scf::YieldOp::create(rewriter, loc, updatedIm2col);
|
||||
|
||||
rewriter.setInsertionPointAfter(im2colLoop);
|
||||
Value im2col = im2colLoop.getResult(0);
|
||||
|
||||
Value gemmInputRows = im2col;
|
||||
if (packFactor != 1) {
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
groupedType,
|
||||
paddedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
packedType,
|
||||
groupedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
}
|
||||
|
||||
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
||||
});
|
||||
|
||||
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
|
||||
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value updatedIm2col =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
|
||||
scf::YieldOp::create(rewriter, loc, updatedIm2col);
|
||||
|
||||
rewriter.setInsertionPointAfter(im2colLoop);
|
||||
Value im2col = im2colLoop.getResult(0);
|
||||
|
||||
Value gemmInputRows = im2col;
|
||||
if (packFactor != 1) {
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
groupedType,
|
||||
paddedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
packedType,
|
||||
groupedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
}
|
||||
|
||||
SmallVector<Value> rowResults;
|
||||
rowResults.reserve(packedNumRows);
|
||||
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(packFactor * patchSize)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
rowResults.push_back(
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, loc, rowResults);
|
||||
});
|
||||
|
||||
SmallVector<Value> rows;
|
||||
rows.reserve(im2colComputeOp.getNumResults());
|
||||
for (Value result : im2colComputeOp.getResults())
|
||||
rows.push_back(result);
|
||||
return rows;
|
||||
return im2colComputeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
@@ -319,15 +306,12 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
||||
Value gemmOut;
|
||||
if (packFactor == 1) {
|
||||
gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front()
|
||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
|
||||
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||
}
|
||||
else {
|
||||
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||
Value packedOutput = gemmRowArgs.size() == 1
|
||||
? gemmRowArgs.front()
|
||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
|
||||
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
@@ -509,35 +493,36 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
// A_packed: [ceil(numPatches / N), N * patchSize]
|
||||
// B_packed: [N * patchSize, N * cOut]
|
||||
// Y_packed: [ceil(numPatches / N), N * cOut]
|
||||
auto gemmInputRowType = RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
|
||||
auto gemmOutputRowType =
|
||||
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
||||
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
|
||||
xType,
|
||||
im2colType,
|
||||
rowType,
|
||||
gemmInputRowType,
|
||||
batchSize,
|
||||
numChannelsIn,
|
||||
xHeight,
|
||||
xWidth,
|
||||
wHeight,
|
||||
wWidth,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
outWidth,
|
||||
patchSize,
|
||||
numPatches,
|
||||
numPatchesPerBatch,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc);
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
|
||||
auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType);
|
||||
auto gemmOutputRowsType =
|
||||
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
||||
Value gemmInputRows = createIm2colRowComputes(x,
|
||||
xType,
|
||||
im2colType,
|
||||
rowType,
|
||||
gemmInputRowsType,
|
||||
batchSize,
|
||||
numChannelsIn,
|
||||
xHeight,
|
||||
xWidth,
|
||||
wHeight,
|
||||
wWidth,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
outWidth,
|
||||
patchSize,
|
||||
numPatches,
|
||||
numPatchesPerBatch,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc);
|
||||
|
||||
Value gemmB = buildPackedWeight(wDenseAttr,
|
||||
wTrans,
|
||||
@@ -553,25 +538,20 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
Value gemmC = buildPackedBias(
|
||||
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
||||
|
||||
SmallVector<Value> gemmRows;
|
||||
gemmRows.reserve(gemmInputRows.size());
|
||||
for (Value gemmInputRow : gemmInputRows) {
|
||||
Value gemmRow = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmOutputRowType,
|
||||
gemmInputRow,
|
||||
gemmB,
|
||||
gemmC,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
gemmRows.push_back(gemmRow);
|
||||
}
|
||||
Value gemmRows = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmOutputRowsType,
|
||||
gemmInputRows,
|
||||
gemmB,
|
||||
gemmC,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
|
||||
rewriter.replaceOp(convOp,
|
||||
createCollectedConvOutput(gemmRows,
|
||||
createCollectedConvOutput(ValueRange {gemmRows},
|
||||
convOp.getType(),
|
||||
gemmOutType,
|
||||
nhwcType,
|
||||
|
||||
Reference in New Issue
Block a user