diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 0954a5c..d4dad33 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -7,7 +7,9 @@ #include "llvm/ADT/SmallVector.h" #include +#include +#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" @@ -29,6 +31,36 @@ struct ConvToGemm : OpConversionPattern { ConversionPatternRewriter& rewriter) const override; }; +struct ConvLoweringState { + Value x; + Value w; + Value b; + RankedTensorType xType; + RankedTensorType wType; + RankedTensorType outType; + int64_t batchSize; + int64_t numChannelsIn; + int64_t xHeight; + int64_t xWidth; + int64_t numChannelsOut; + int64_t wHeight; + int64_t wWidth; + int64_t outHeight; + int64_t outWidth; + int64_t group; + int64_t numChannelsInPerGroup; + int64_t numChannelsOutPerGroup; + int64_t padHeightBegin; + int64_t padHeightEnd; + int64_t padWidthBegin; + int64_t padWidthEnd; + int64_t strideHeight; + int64_t strideWidth; + int64_t dilationHeight; + int64_t dilationWidth; + bool hasBias; +}; + static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) { auto biasType = cast(bias.getType()); if (biasType.getRank() != 1) @@ -44,58 +76,621 @@ static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, }); } -static Value createPaddedRows(Value tensorValue, - RankedTensorType tensorType, - int64_t paddedRows, - ConversionPatternRewriter& rewriter, - Location loc) { - if (tensorType.getDimSize(0) == paddedRows) - return tensorValue; +static bool +isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) { + return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0; +} - auto paddedType = RankedTensorType::get( - {paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding()); - SmallVector lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; - SmallVector highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)), - rewriter.getIndexAttr(0)}; - auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads); +static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) { + assert(value > 0 && "expected positive value"); + limit = std::min(value, limit); + for (int64_t candidate = limit; candidate >= 1; --candidate) + if (value % candidate == 0) + return candidate; + return 1; +} + +static Value createZeroPaddedTensor(Value value, + RankedTensorType resultType, + ArrayRef lowPadValues, + ArrayRef highPadValues, + ConversionPatternRewriter& rewriter, + Location loc) { + auto valueType = cast(value.getType()); + if (valueType == resultType) + return value; + + SmallVector lowPads; + SmallVector highPads; + lowPads.reserve(lowPadValues.size()); + highPads.reserve(highPadValues.size()); + for (auto lowPad : lowPadValues) + lowPads.push_back(rewriter.getIndexAttr(lowPad)); + for (auto highPad : highPadValues) + highPads.push_back(rewriter.getIndexAttr(highPad)); + + auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads); auto* padBlock = new Block(); - for (int i = 0; i < 2; ++i) + for (int64_t dim = 0, rank = resultType.getRank(); dim < rank; ++dim) padBlock->addArgument(rewriter.getIndexType(), loc); padOp.getRegion().push_back(padBlock); rewriter.setInsertionPointToStart(padBlock); auto zero = getOrCreateConstant( - rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()), tensorType.getElementType()); + rewriter, padOp.getOperation(), rewriter.getZeroAttr(resultType.getElementType()), resultType.getElementType()); tensor::YieldOp::create(rewriter, loc, zero); rewriter.setInsertionPointAfter(padOp); return padOp.getResult(); } +static Value affineAddConst( + ConversionPatternRewriter& rewriter, Location loc, Value value, int64_t offset, Operation* constantAnchor) { + if (offset == 0) + return value; + + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor); +} + +static Value createConvInputPatch(Value input, + RankedTensorType patchType, + Value batchIndex, + Value channelOffset, + Value inputHeightOffset, + Value inputWidthOffset, + int64_t dilationHeight, + int64_t dilationWidth, + ConversionPatternRewriter& rewriter, + Location loc) { + const int64_t patchChannels = patchType.getDimSize(1); + const int64_t kernelHeight = patchType.getDimSize(2); + const int64_t kernelWidth = patchType.getDimSize(3); + if (dilationHeight == 1 && dilationWidth == 1) { + SmallVector offsets {batchIndex, channelOffset, inputHeightOffset, inputWidthOffset}; + SmallVector sizes {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(patchChannels), + rewriter.getIndexAttr(kernelHeight), + rewriter.getIndexAttr(kernelWidth)}; + return tensor::ExtractSliceOp::create(rewriter, loc, patchType, input, offsets, sizes, getUnitStrides(rewriter, 4)); + } + + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + auto elementType = patchType.getElementType(); + auto pixelType = RankedTensorType::get({1, patchChannels, 1, 1}, elementType, patchType.getEncoding()); + Value patch = tensor::EmptyOp::create(rewriter, loc, patchType.getShape(), elementType); + for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { + Value sourceHeightOffset = affineAddConst(rewriter, loc, inputHeightOffset, kernelH * dilationHeight, anchorOp); + for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { + Value sourceWidthOffset = affineAddConst(rewriter, loc, inputWidthOffset, kernelW * dilationWidth, anchorOp); + SmallVector sourceOffsets {batchIndex, channelOffset, sourceHeightOffset, sourceWidthOffset}; + SmallVector sourceSizes {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(patchChannels), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + Value sourcePixel = tensor::ExtractSliceOp::create( + rewriter, loc, pixelType, input, sourceOffsets, sourceSizes, getUnitStrides(rewriter, 4)); + SmallVector targetOffsets { + rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(kernelH), rewriter.getIndexAttr(kernelW)}; + patch = tensor::InsertSliceOp::create( + rewriter, loc, sourcePixel, patch, targetOffsets, sourceSizes, getUnitStrides(rewriter, 4)); + } + } + return patch; +} + +static Value createCollectedConvOutput(ValueRange gemmRows, + Type convType, + RankedTensorType gemmOutType, + RankedTensorType nhwcType, + RankedTensorType outType, + int64_t numPatches, + int64_t numChannelsOut, + int64_t packFactor, + ConversionPatternRewriter& rewriter, + Location loc); + +namespace depthwise { + +struct Tiling { + int64_t outputMultiplier; + int64_t kernelElements; + int64_t channelsPerTile; + int64_t tileInputRows; + int64_t tileOutputChannels; + int64_t numChannelTiles; + int64_t spatialPatchesPerBatch; + int64_t totalPatches; +}; + +static std::optional computeTiling(int64_t batchSize, + int64_t numChannelsIn, + int64_t numChannelsOut, + int64_t wHeight, + int64_t wWidth, + int64_t outHeight, + int64_t outWidth) { + const int64_t kernelElements = wHeight * wWidth; + const int64_t outputMultiplier = numChannelsOut / numChannelsIn; + const int64_t xbarDim = static_cast(crossbarSize.getValue()); + if (kernelElements <= 0 || outputMultiplier <= 0 || kernelElements > xbarDim || outputMultiplier > xbarDim) + return std::nullopt; + + const int64_t maxChannelsPerTile = std::min(xbarDim / kernelElements, xbarDim / outputMultiplier); + if (maxChannelsPerTile <= 0) + return std::nullopt; + + const int64_t channelsPerTile = findLargestDivisorAtMost(numChannelsIn, maxChannelsPerTile); + const int64_t tileInputRows = channelsPerTile * kernelElements; + const int64_t tileOutputChannels = channelsPerTile * outputMultiplier; + if (tileInputRows > xbarDim || tileOutputChannels > xbarDim) + return std::nullopt; + + return Tiling { + outputMultiplier, + kernelElements, + channelsPerTile, + tileInputRows, + tileOutputChannels, + numChannelsIn / channelsPerTile, + outHeight * outWidth, + batchSize * outHeight * outWidth, + }; +} + +static Value buildPackedWeights(DenseElementsAttr wDenseAttr, + RankedTensorType wType, + const Tiling& tiling, + ConversionPatternRewriter& rewriter, + Location loc) { + auto packedWeightType = RankedTensorType::get( + {tiling.numChannelTiles, tiling.tileInputRows, tiling.tileOutputChannels}, wType.getElementType()); + SmallVector packedValues(packedWeightType.getNumElements(), + cast(rewriter.getZeroAttr(wType.getElementType()))); + SmallVector sourceValues(wDenseAttr.getValues()); + + for (int64_t tileIndex = 0; tileIndex < tiling.numChannelTiles; ++tileIndex) { + const int64_t channelBase = tileIndex * tiling.channelsPerTile; + for (int64_t localChannel = 0; localChannel < tiling.channelsPerTile; ++localChannel) { + const int64_t globalChannel = channelBase + localChannel; + for (int64_t kernelIndex = 0; kernelIndex < tiling.kernelElements; ++kernelIndex) { + const int64_t kernelH = kernelIndex / wType.getDimSize(3); + const int64_t kernelW = kernelIndex % wType.getDimSize(3); + const int64_t targetRow = localChannel * tiling.kernelElements + kernelIndex; + for (int64_t multiplierIndex = 0; multiplierIndex < tiling.outputMultiplier; ++multiplierIndex) { + const int64_t globalOutChannel = globalChannel * tiling.outputMultiplier + multiplierIndex; + const int64_t sourceFlatIndex = + ((globalOutChannel * wType.getDimSize(1) * wType.getDimSize(2)) + kernelH) * wType.getDimSize(3) + kernelW; + const int64_t targetCol = localChannel * tiling.outputMultiplier + multiplierIndex; + const int64_t targetFlatIndex = + ((tileIndex * tiling.tileInputRows) + targetRow) * tiling.tileOutputChannels + targetCol; + packedValues[targetFlatIndex] = sourceValues[sourceFlatIndex]; + } + } + } + } + + auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType); +} + +static Value createPaddedInput(Value input, + RankedTensorType inputType, + int64_t padHeightBegin, + int64_t padHeightEnd, + int64_t padWidthBegin, + int64_t padWidthEnd, + ConversionPatternRewriter& rewriter, + Location loc) { + if (padHeightBegin == 0 && padHeightEnd == 0 && padWidthBegin == 0 && padWidthEnd == 0) + return input; + + auto paddedInputType = RankedTensorType::get({inputType.getDimSize(0), + inputType.getDimSize(1), + inputType.getDimSize(2) + padHeightBegin + padHeightEnd, + inputType.getDimSize(3) + padWidthBegin + padWidthEnd}, + inputType.getElementType()); + auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) { + Value padded = createZeroPaddedTensor(computeInput, + paddedInputType, + {0, 0, padHeightBegin, padWidthBegin}, + {0, 0, padHeightEnd, padWidthEnd}, + rewriter, + loc); + spatial::SpatYieldOp::create(rewriter, loc, padded); + }); + return computeOp.getResult(0); +} + +static Value createInputTile(Value input, + Value patchIndex, + Value channelTileIndex, + RankedTensorType inputTileType, + const Tiling& tiling, + int64_t strideHeight, + int64_t strideWidth, + int64_t dilationHeight, + int64_t dilationWidth, + int64_t outWidth, + ConversionPatternRewriter& rewriter, + Location loc) { + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value batchIndex = affineFloorDivConst(rewriter, loc, patchIndex, tiling.spatialPatchesPerBatch, anchorOp); + Value batchPatchIndex = affineModConst(rewriter, loc, patchIndex, tiling.spatialPatchesPerBatch, anchorOp); + Value outHeightIndex = affineFloorDivConst(rewriter, loc, batchPatchIndex, outWidth, anchorOp); + Value outWidthIndex = affineModConst(rewriter, loc, batchPatchIndex, outWidth, anchorOp); + Value inputHeightOffset = + strideHeight == 1 ? outHeightIndex : affineMulConst(rewriter, loc, outHeightIndex, strideHeight, anchorOp); + Value inputWidthOffset = + strideWidth == 1 ? outWidthIndex : affineMulConst(rewriter, loc, outWidthIndex, strideWidth, anchorOp); + Value channelOffset = tiling.channelsPerTile == 1 + ? channelTileIndex + : affineMulConst(rewriter, loc, channelTileIndex, tiling.channelsPerTile, anchorOp); + Value tile4D = createConvInputPatch(input, + inputTileType, + batchIndex, + channelOffset, + inputHeightOffset, + inputWidthOffset, + dilationHeight, + dilationWidth, + rewriter, + loc); + auto collapsedType = RankedTensorType::get({1, tiling.tileInputRows}, inputTileType.getElementType()); + return tensor::CollapseShapeOp::create(rewriter, + loc, + collapsedType, + tile4D, + SmallVector { + {0}, + {1, 2, 3} + }); +} + +static Value createWeightTile(Value packedWeights, + Value channelTileIndex, + RankedTensorType packedWeightType, + const Tiling& tiling, + ConversionPatternRewriter& rewriter, + Location loc) { + SmallVector offsets {channelTileIndex, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector sizes {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(tiling.tileInputRows), + rewriter.getIndexAttr(tiling.tileOutputChannels)}; + auto sliceType = + RankedTensorType::get({1, tiling.tileInputRows, tiling.tileOutputChannels}, packedWeightType.getElementType()); + Value slice = tensor::ExtractSliceOp::create( + rewriter, loc, sliceType, packedWeights, offsets, sizes, getUnitStrides(rewriter, 3)); + auto collapsedType = + RankedTensorType::get({tiling.tileInputRows, tiling.tileOutputChannels}, packedWeightType.getElementType()); + return tensor::CollapseShapeOp::create(rewriter, + loc, + collapsedType, + slice, + SmallVector { + {0, 1}, + {2} + }); +} + +static Value createBiasTile( + Value bias, Value channelTileIndex, const Tiling& tiling, ConversionPatternRewriter& rewriter, Location loc) { + auto biasType = cast(bias.getType()); + auto biasTileType = RankedTensorType::get({1, tiling.tileOutputChannels}, biasType.getElementType()); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value channelOffset = tiling.tileOutputChannels == 1 + ? channelTileIndex + : affineMulConst(rewriter, loc, channelTileIndex, tiling.tileOutputChannels, anchorOp); + SmallVector offsets {rewriter.getIndexAttr(0), channelOffset}; + SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tiling.tileOutputChannels)}; + return tensor::ExtractSliceOp::create(rewriter, loc, biasTileType, bias, offsets, sizes, getUnitStrides(rewriter, 2)); +} + +static Value insertOutputTile(Value rowTile, + Value rowAccumulator, + Value channelTileIndex, + const Tiling& tiling, + ConversionPatternRewriter& rewriter, + Location loc) { + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value channelOffset = tiling.tileOutputChannels == 1 + ? channelTileIndex + : affineMulConst(rewriter, loc, channelTileIndex, tiling.tileOutputChannels, anchorOp); + SmallVector offsets {rewriter.getIndexAttr(0), channelOffset}; + SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tiling.tileOutputChannels)}; + return tensor::InsertSliceOp::create( + rewriter, loc, rowTile, rowAccumulator, offsets, sizes, getUnitStrides(rewriter, 2)); +} + +static FailureOr reconstructDepthwiseGemmRows(Value pieces, + RankedTensorType piecesType, + RankedTensorType gemmOutType, + const Tiling& tiling, + ConversionPatternRewriter& rewriter, + Location loc) { + auto collectedOp = createSpatCompute<1>(rewriter, loc, TypeRange {gemmOutType}, {}, pieces, [&](Value piecesArg) { + auto rowType = RankedTensorType::get({1, gemmOutType.getDimSize(1)}, gemmOutType.getElementType()); + Value outputInit = tensor::EmptyOp::create(rewriter, loc, gemmOutType.getShape(), gemmOutType.getElementType()); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); + Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); + Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, tiling.totalPatches); + Value cNumChannelTiles = getOrCreateIndexConstant(rewriter, anchorOp, tiling.numChannelTiles); + + auto patchLoop = buildNormalizedScfFor( + rewriter, + loc, + c0, + cNumPatches, + c1, + ValueRange {outputInit}, + [&](OpBuilder&, + Location nestedLoc, + Value patchIndex, + ValueRange patchIterArgs, + SmallVectorImpl& patchYielded) { + Value outputAcc = patchIterArgs.front(); + Value rowInit = tensor::EmptyOp::create(rewriter, nestedLoc, rowType.getShape(), rowType.getElementType()); + auto tileLoop = buildNormalizedScfFor( + rewriter, + nestedLoc, + c0, + cNumChannelTiles, + c1, + ValueRange {rowInit}, + [&](OpBuilder&, + Location tileLoc, + Value channelTileIndex, + ValueRange tileIterArgs, + SmallVectorImpl& tileYielded) { + Value rowAcc = tileIterArgs.front(); + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + AffineExpr d1 = getAffineDimExpr(1, context); + Value laneIndex = createOrFoldAffineApply( + rewriter, tileLoc, (d0 * tiling.totalPatches) + d1, ValueRange {channelTileIndex, patchIndex}, anchorOp); + auto rowTileType = RankedTensorType::get({1, tiling.tileOutputChannels}, piecesType.getElementType()); + SmallVector pieceOffsets {laneIndex, rewriter.getIndexAttr(0)}; + SmallVector pieceSizes {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(tiling.tileOutputChannels)}; + Value rowTile = tensor::ExtractSliceOp::create( + rewriter, tileLoc, rowTileType, piecesArg, pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2)); + Value rowNext = insertOutputTile(rowTile, rowAcc, channelTileIndex, tiling, rewriter, tileLoc); + tileYielded.push_back(rowNext); + return success(); + }); + if (failed(tileLoop)) + return failure(); + + SmallVector rowOffsets {patchIndex, rewriter.getIndexAttr(0)}; + SmallVector rowSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(gemmOutType.getDimSize(1))}; + Value outputNext = tensor::InsertSliceOp::create(rewriter, + nestedLoc, + tileLoop->results.front(), + outputAcc, + rowOffsets, + rowSizes, + getUnitStrides(rewriter, 2)) + .getResult(); + patchYielded.push_back(outputNext); + return success(); + }); + if (failed(patchLoop)) + return failure(); + + spatial::SpatYieldOp::create(rewriter, loc, patchLoop->results.front()); + return success(); + }); + if (failed(collectedOp)) + return failure(); + return collectedOp->getResult(0); +} + +static bool canUseStructuredRewrite(const ConvLoweringState& state) { + if (!getHostConstDenseElementsAttr(state.w)) + return false; + + if (!computeTiling(state.batchSize, + state.numChannelsIn, + state.numChannelsOut, + state.wHeight, + state.wWidth, + state.outHeight, + state.outWidth)) { + return false; + } + + if (isa(state.b.getDefiningOp())) + return true; + + auto biasType = dyn_cast(state.b.getType()); + if (!biasType) + return false; + if (biasType.getRank() == 1) + return biasType.getDimSize(0) == state.numChannelsOut; + if (biasType.getRank() != 2) + return false; + return biasType.getDimSize(0) == 1 && biasType.getDimSize(1) == state.numChannelsOut; +} + +static FailureOr +rewriteConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPatternRewriter& rewriter, Location loc) { + auto wDenseAttr = getHostConstDenseElementsAttr(state.w); + if (!wDenseAttr) { + convOp.emitOpError("requires constant-derived weights for structured depthwise Spatial lowering"); + return failure(); + } + + auto tiling = computeTiling(state.xType.getDimSize(0), + state.xType.getDimSize(1), + state.outType.getDimSize(1), + state.wType.getDimSize(2), + state.wType.getDimSize(3), + state.outType.getDimSize(2), + state.outType.getDimSize(3)); + if (!tiling) { + convOp.emitOpError("failed to derive a structured depthwise tiling that fits Spatial weighted VMM lowering"); + return failure(); + } + + Value paddedInput = createPaddedInput(state.x, + state.xType, + state.padHeightBegin, + state.padHeightEnd, + state.padWidthBegin, + state.padWidthEnd, + rewriter, + loc); + Value packedWeights = buildPackedWeights(wDenseAttr, state.wType, *tiling, rewriter, loc); + + Value expandedBias; + SmallVector batchInputs {paddedInput}; + if (!isa(state.b.getDefiningOp())) { + expandedBias = expandBiasIfNeeded(state.b, rewriter, loc); + auto biasType = dyn_cast(expandedBias.getType()); + if (!biasType || biasType.getRank() != 2 || biasType.getDimSize(0) != 1 + || biasType.getDimSize(1) != state.outType.getDimSize(1)) { + convOp.emitOpError("requires bias sliceable as tensor<1xCout> for structured depthwise Spatial lowering"); + return failure(); + } + batchInputs.push_back(expandedBias); + } + + auto gemmOutType = + RankedTensorType::get({tiling->totalPatches, state.outType.getDimSize(1)}, state.outType.getElementType()); + auto piecesType = RankedTensorType::get({tiling->totalPatches * tiling->numChannelTiles, tiling->tileOutputChannels}, + state.outType.getElementType()); + auto paddedInputType = cast(paddedInput.getType()); + auto inputTileType = + RankedTensorType::get({1, tiling->channelsPerTile, state.wType.getDimSize(2), state.wType.getDimSize(3)}, + paddedInputType.getElementType()); + + auto batchOp = createSpatComputeBatch( + rewriter, + loc, + TypeRange {piecesType}, + tiling->totalPatches * tiling->numChannelTiles, + ValueRange {packedWeights}, + batchInputs, + [&](detail::SpatComputeBatchBodyArgs args) { + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value patchIndex = affineModConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp); + Value channelTileIndex = affineFloorDivConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp); + Value inputTile = createInputTile(args.inputs.front(), + patchIndex, + channelTileIndex, + inputTileType, + *tiling, + state.strideHeight, + state.strideWidth, + state.dilationHeight, + state.dilationWidth, + state.outType.getDimSize(3), + rewriter, + loc); + Value weightTile = createWeightTile(args.weights.front(), + channelTileIndex, + cast(args.weights.front().getType()), + *tiling, + rewriter, + loc); + auto rowTileType = RankedTensorType::get({1, tiling->tileOutputChannels}, state.outType.getElementType()); + Value rowTile = spatial::SpatVMMOp::create(rewriter, loc, rowTileType, weightTile, inputTile).getResult(); + if (args.inputs.size() > 1) { + Value biasTile = createBiasTile(args.inputs[1], channelTileIndex, *tiling, rewriter, loc); + rowTile = spatial::SpatVAddOp::create(rewriter, loc, rowTileType, rowTile, biasTile).getResult(); + } + + SmallVector outputOffsets {args.lane, rewriter.getIndexAttr(0)}; + SmallVector outputSizes {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(tiling->tileOutputChannels)}; + createParallelInsertSliceIntoBatchOutput( + rewriter, loc, rowTile, args.outputs.front(), outputOffsets, outputSizes, getUnitStrides(rewriter, 2)); + }); + if (failed(batchOp)) + return failure(); + + auto nhwcType = RankedTensorType::get( + {state.xType.getDimSize(0), state.outType.getDimSize(2), state.outType.getDimSize(3), state.outType.getDimSize(1)}, + state.outType.getElementType()); + auto collectedRows = + reconstructDepthwiseGemmRows(batchOp->getResult(0), piecesType, gemmOutType, *tiling, rewriter, loc); + if (failed(collectedRows)) + return failure(); + + return createCollectedConvOutput(ValueRange {*collectedRows}, + state.outType, + gemmOutType, + nhwcType, + state.outType, + tiling->totalPatches, + state.outType.getDimSize(1), + /*packFactor=*/1, + rewriter, + loc); +} + +} // namespace depthwise + +namespace standard { + +struct ConvGemmPlan { + int64_t patchSize; + int64_t numPatchesPerBatch; + int64_t numPatches; + int64_t maxParallelPixels; + int64_t effectiveMaxParallelPixels; + int64_t packedNumRows; + + RankedTensorType im2colType; + RankedTensorType im2colRowType; + RankedTensorType gemmInputRowsType; + RankedTensorType wFlatType; + RankedTensorType wTransType; + RankedTensorType gemmOutType; + RankedTensorType gemmOutputRowsType; + RankedTensorType nhwcType; +}; + +static Value createPaddedRows(Value rows, + RankedTensorType rowsType, + int64_t paddedRows, + ConversionPatternRewriter& rewriter, + Location loc) { + if (rowsType.getDimSize(0) == paddedRows) + return rows; + + auto paddedType = + RankedTensorType::get({paddedRows, rowsType.getDimSize(1)}, rowsType.getElementType(), rowsType.getEncoding()); + return createZeroPaddedTensor( + rows, paddedType, {0, 0}, {paddedRows - rowsType.getDimSize(0), 0}, rewriter, loc); +} + static Value packRowsForParallelGemm( Value rows, RankedTensorType rowsType, int64_t packFactor, ConversionPatternRewriter& rewriter, Location loc) { if (packFactor == 1) return rows; - const int64_t packedNumRows = ceilIntegerDivide(rowsType.getDimSize(0), packFactor); - const int64_t paddedNumRows = packedNumRows * packFactor; + const int64_t paddedNumRows = ceilIntegerDivide(rowsType.getDimSize(0), packFactor) * packFactor; + const int64_t packedNumRows = paddedNumRows / packFactor; const int64_t rowWidth = rowsType.getDimSize(1); auto groupedType = RankedTensorType::get({packedNumRows, packFactor, rowWidth}, rowsType.getElementType(), rowsType.getEncoding()); auto packedType = RankedTensorType::get({packedNumRows, packFactor * rowWidth}, rowsType.getElementType(), rowsType.getEncoding()); - Value paddedRows = createPaddedRows(rows, rowsType, paddedNumRows, rewriter, loc); - Value groupedRows = tensor::ExpandShapeOp::create(rewriter, - loc, - groupedType, - paddedRows, - SmallVector { - {0, 1}, - {2} + Value padded = createPaddedRows(rows, rowsType, paddedNumRows, rewriter, loc); + Value grouped = tensor::ExpandShapeOp::create(rewriter, + loc, + groupedType, + padded, + SmallVector { + {0, 1}, + {2} }); return tensor::CollapseShapeOp::create(rewriter, loc, packedType, - groupedRows, + grouped, SmallVector { {0}, {1, 2} @@ -121,62 +716,82 @@ static Value unpackRowsFromParallelGemm(Value packedRows, auto unpackedType = RankedTensorType::get({unpackedRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding()); - Value expandedRows = tensor::ExpandShapeOp::create(rewriter, - loc, - expandedType, - packedRows, - SmallVector { - {0}, - {1, 2} + Value expanded = tensor::ExpandShapeOp::create(rewriter, + loc, + expandedType, + packedRows, + SmallVector { + {0}, + {1, 2} }); - Value paddedRows = tensor::CollapseShapeOp::create(rewriter, - loc, - paddedType, - expandedRows, - SmallVector { - {0, 1}, - {2} + Value padded = tensor::CollapseShapeOp::create(rewriter, + loc, + paddedType, + expanded, + SmallVector { + {0, 1}, + {2} }); if (paddedNumRows == unpackedRows) - return paddedRows; + return padded; SmallVector offsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector sizes {rewriter.getIndexAttr(unpackedRows), rewriter.getIndexAttr(rowWidth)}; - SmallVector strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - return tensor::ExtractSliceOp::create(rewriter, loc, unpackedType, paddedRows, offsets, sizes, strides); + return tensor::ExtractSliceOp::create(rewriter, loc, unpackedType, padded, offsets, sizes, getUnitStrides(rewriter, 2)); } -static Value buildPackedWeight(DenseElementsAttr wDenseAttr, - Value wTrans, - RankedTensorType wType, - int64_t numChannelsIn, - int64_t numChannelsOut, - int64_t wHeight, - int64_t wWidth, - int64_t patchSize, - int64_t packFactor, - ConversionPatternRewriter& rewriter, - Location loc) { - if (packFactor == 1) +static Value createWeightMatrix( + Value weights, const ConvGemmPlan& plan, ConversionPatternRewriter& rewriter, Location loc) { + auto buildWeightMatrix = [&](Value weight) -> Value { + Value flattened = tensor::CollapseShapeOp::create(rewriter, + loc, + plan.wFlatType, + weight, + SmallVector { + {0}, + {1, 2, 3} + }); + return ONNXTransposeOp::create(rewriter, loc, plan.wTransType, flattened, rewriter.getI64ArrayAttr({1, 0})) + .getResult(); + }; + + if (isCompileTimeComputable(weights)) + return buildWeightMatrix(weights); + + auto computeOp = + createSpatCompute<1>(rewriter, loc, TypeRange {plan.wTransType}, {}, ValueRange {weights}, [&](Value weight) { + spatial::SpatYieldOp::create(rewriter, loc, buildWeightMatrix(weight)); + }); + return computeOp.getResult(0); +} + +static Value buildPackedWeights(DenseElementsAttr wDenseAttr, + Value wTrans, + const ConvLoweringState& state, + const ConvGemmPlan& plan, + ConversionPatternRewriter& rewriter, + Location loc) { + if (plan.effectiveMaxParallelPixels == 1) return wTrans; - auto packedWeightType = - RankedTensorType::get({packFactor * patchSize, packFactor * numChannelsOut}, wType.getElementType()); + auto packedWeightType = RankedTensorType::get( + {plan.effectiveMaxParallelPixels * plan.patchSize, plan.effectiveMaxParallelPixels * state.numChannelsOut}, + state.wType.getElementType()); SmallVector sourceValues(wDenseAttr.getValues()); SmallVector packedValues(packedWeightType.getNumElements(), - cast(rewriter.getZeroAttr(wType.getElementType()))); + cast(rewriter.getZeroAttr(state.wType.getElementType()))); - for (int64_t copyId = 0; copyId < packFactor; copyId++) { - for (int64_t outChannel = 0; outChannel < numChannelsOut; outChannel++) { - for (int64_t inChannel = 0; inChannel < numChannelsIn; inChannel++) { - for (int64_t kernelH = 0; kernelH < wHeight; kernelH++) { - for (int64_t kernelW = 0; kernelW < wWidth; kernelW++) { + for (int64_t copyId = 0; copyId < plan.effectiveMaxParallelPixels; ++copyId) { + for (int64_t outChannel = 0; outChannel < state.numChannelsOut; ++outChannel) { + for (int64_t inChannel = 0; inChannel < state.numChannelsIn; ++inChannel) { + for (int64_t kernelH = 0; kernelH < state.wHeight; ++kernelH) { + for (int64_t kernelW = 0; kernelW < state.wWidth; ++kernelW) { const int64_t sourceFlatIndex = - (((outChannel * numChannelsIn) + inChannel) * wHeight + kernelH) * wWidth + kernelW; - const int64_t patchIndex = ((inChannel * wHeight) + kernelH) * wWidth + kernelW; - const int64_t targetRow = copyId * patchSize + patchIndex; - const int64_t targetCol = copyId * numChannelsOut + outChannel; - packedValues[targetRow * (packFactor * numChannelsOut) + targetCol] = sourceValues[sourceFlatIndex]; + (((outChannel * state.numChannelsIn) + inChannel) * state.wHeight + kernelH) * state.wWidth + kernelW; + const int64_t patchIndex = ((inChannel * state.wHeight) + kernelH) * state.wWidth + kernelW; + const int64_t targetRow = copyId * plan.patchSize + patchIndex; + const int64_t targetCol = copyId * state.numChannelsOut + outChannel; + packedValues[targetRow * packedWeightType.getDimSize(1) + targetCol] = sourceValues[sourceFlatIndex]; } } } @@ -187,124 +802,95 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr, return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType); } -static Value createConvWeightMatrix( - Value w, RankedTensorType wFlatType, RankedTensorType wTransType, ConversionPatternRewriter& rewriter, Location loc) { - auto buildWeightMatrix = [&](Value weight) -> Value { - Value wFlat = tensor::CollapseShapeOp::create(rewriter, - loc, - wFlatType, - weight, - SmallVector { - {0}, - {1, 2, 3} - }); - return ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})).getResult(); - }; - - if (isCompileTimeComputable(w)) - return buildWeightMatrix(w); - - auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {wTransType}, {}, ValueRange {w}, [&](Value weight) { - spatial::SpatYieldOp::create(rewriter, loc, buildWeightMatrix(weight)); - }); - return computeOp.getResult(0); -} - -static Value buildPackedBias(bool hasBias, - Value gemmBias, +static Value buildPackedBias(Value gemmBias, Value biasMatrix, DenseElementsAttr biasDenseAttr, - RankedTensorType outType, - int64_t numChannelsOut, - int64_t packFactor, + const ConvLoweringState& state, + const ConvGemmPlan& plan, ConversionPatternRewriter& rewriter, Location loc) { - if (!hasBias) + if (!state.hasBias) return gemmBias; - if (packFactor == 1) + if (plan.effectiveMaxParallelPixels == 1) return biasMatrix; SmallVector sourceValues(biasDenseAttr.getValues()); SmallVector packedValues; - packedValues.reserve(packFactor * numChannelsOut); - for (int64_t copyId = 0; copyId < packFactor; copyId++) + packedValues.reserve(plan.effectiveMaxParallelPixels * state.numChannelsOut); + for (int64_t copyId = 0; copyId < plan.effectiveMaxParallelPixels; ++copyId) packedValues.append(sourceValues.begin(), sourceValues.end()); - auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType()); + auto packedBiasType = + RankedTensorType::get({1, plan.effectiveMaxParallelPixels * state.numChannelsOut}, state.outType.getElementType()); auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedBiasAttr, packedBiasType); } -static Value createIm2colRowComputes(Value x, - RankedTensorType xType, - RankedTensorType im2colType, - RankedTensorType im2colRowType, - RankedTensorType gemmInputRowsType, - int64_t batchSize, - int64_t numChannelsIn, - int64_t xHeight, - int64_t xWidth, - int64_t wHeight, - int64_t wWidth, - int64_t padHeightBegin, - int64_t padHeightEnd, - int64_t padWidthBegin, - int64_t padWidthEnd, - int64_t strideHeight, - int64_t strideWidth, - int64_t dilationHeight, - int64_t dilationWidth, - int64_t outWidth, - int64_t patchSize, - int64_t numPatches, - int64_t numPatchesPerBatch, - int64_t packFactor, - ConversionPatternRewriter& rewriter, - Location loc) { - auto elemType = xType.getElementType(); +static ConvGemmPlan +buildConvGemmPlan(const ConvLoweringState& state, bool canPackWeightsAsConstants, bool canPackBiasAsConstants) { + ConvGemmPlan plan; + plan.patchSize = state.numChannelsIn * state.wHeight * state.wWidth; + plan.numPatchesPerBatch = state.outHeight * state.outWidth; + plan.numPatches = state.batchSize * plan.numPatchesPerBatch; + const int64_t wMaxDim = std::max(plan.patchSize, state.numChannelsOut); + plan.maxParallelPixels = std::max(1, static_cast(crossbarSize.getValue()) / wMaxDim); + plan.effectiveMaxParallelPixels = + (canPackWeightsAsConstants && canPackBiasAsConstants) ? plan.maxParallelPixels : 1; + plan.packedNumRows = ceilIntegerDivide(plan.numPatches, plan.effectiveMaxParallelPixels); + + auto elemType = state.xType.getElementType(); + auto outElemType = state.outType.getElementType(); + plan.im2colType = RankedTensorType::get({plan.numPatches, plan.patchSize}, elemType); + plan.im2colRowType = RankedTensorType::get({1, plan.patchSize}, elemType); + plan.gemmInputRowsType = + RankedTensorType::get({plan.packedNumRows, plan.effectiveMaxParallelPixels * plan.patchSize}, elemType); + plan.wFlatType = RankedTensorType::get({state.numChannelsOut, plan.patchSize}, state.wType.getElementType()); + plan.wTransType = RankedTensorType::get({plan.patchSize, state.numChannelsOut}, state.wType.getElementType()); + plan.gemmOutType = RankedTensorType::get({plan.numPatches, state.numChannelsOut}, outElemType); + plan.gemmOutputRowsType = + RankedTensorType::get({plan.packedNumRows, plan.effectiveMaxParallelPixels * state.numChannelsOut}, outElemType); + plan.nhwcType = + RankedTensorType::get({state.batchSize, state.outHeight, state.outWidth, state.numChannelsOut}, outElemType); + return plan; +} + +static Value createIm2colRows( + const ConvLoweringState& state, const ConvGemmPlan& plan, ConversionPatternRewriter& rewriter, Location loc) { constexpr size_t numInputs = 1; auto im2colComputeOp = - createSpatCompute(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) { + createSpatCompute(rewriter, loc, TypeRange {plan.gemmInputRowsType}, {}, state.x, [&](Value xArg) { + auto elemType = state.xType.getElementType(); Value paddedInput = xArg; - - // Pad input with zeros if needed: - // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth] - if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) { - const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd; - const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd; - auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType); - SmallVector lowPads = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(padHeightBegin), - rewriter.getIndexAttr(padWidthBegin)}; - SmallVector highPads = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(padHeightEnd), - rewriter.getIndexAttr(padWidthEnd)}; - auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads); - auto* padBlock = new Block(); - for (int i = 0; i < 4; i++) - padBlock->addArgument(rewriter.getIndexType(), loc); - padOp.getRegion().push_back(padBlock); - rewriter.setInsertionPointToStart(padBlock); - auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getFloatAttr(elemType, 0.0), elemType); - tensor::YieldOp::create(rewriter, loc, zero); - rewriter.setInsertionPointAfter(padOp); - paddedInput = padOp.getResult(); + if (state.padHeightBegin || state.padHeightEnd || state.padWidthBegin || state.padWidthEnd) { + auto paddedInputType = RankedTensorType::get( + {state.batchSize, + state.numChannelsIn, + state.xHeight + state.padHeightBegin + state.padHeightEnd, + state.xWidth + state.padWidthBegin + state.padWidthEnd}, + elemType); + paddedInput = createZeroPaddedTensor(paddedInput, + paddedInputType, + {0, 0, state.padHeightBegin, state.padWidthBegin}, + {0, 0, state.padHeightEnd, state.padWidthEnd}, + rewriter, + loc); } - // Build im2col [numPatches, patchSize] incrementally to keep the IR small - // until the late PIM unrolling step. - Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType); + // Keep the standard im2col view of convolution, flipped so filters sit in + // B / crossbar columns: + // A (im2col): [numPatches, patchSize] -- one row per output spatial position + // B (weights): [patchSize, cOut] + // Gemm output: [numPatches, cOut] + Value im2colInit = tensor::EmptyOp::create(rewriter, loc, plan.im2colType.getShape(), elemType); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); - auto c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); - auto c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); - auto cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, numPatches); - auto cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, numPatchesPerBatch); - auto cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, outWidth); - auto cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight); - auto cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth); + Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); + Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); + Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.numPatches); + Value cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, plan.numPatchesPerBatch); + Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth); + Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, state.strideHeight); + Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.strideWidth); auto im2colLoop = buildNormalizedScfFor( rewriter, @@ -322,44 +908,44 @@ static Value createIm2colRowComputes(Value x, Value inputHeightOffset = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight); Value inputWidthOffset = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth); - SmallVector offsets = { - batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset}; - SmallVector sizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(numChannelsIn), - rewriter.getIndexAttr(wHeight), - rewriter.getIndexAttr(wWidth)}; - SmallVector strides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(dilationHeight), - rewriter.getIndexAttr(dilationWidth)}; - auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); - Value patch = - tensor::ExtractSliceOp::create(rewriter, nestedLoc, patchType, paddedInput, offsets, sizes, strides); - + auto patchType = + RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elemType); + Value patch = createConvInputPatch(paddedInput, + patchType, + batchIndex, + c0, + inputHeightOffset, + inputWidthOffset, + state.dilationHeight, + state.dilationWidth, + rewriter, + nestedLoc); Value row = tensor::CollapseShapeOp::create(rewriter, nestedLoc, - im2colRowType, + plan.im2colRowType, patch, SmallVector { {0}, {1, 2, 3} }); - SmallVector rowOffsets = {patchIndex, rewriter.getIndexAttr(0)}; - SmallVector rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)}; - SmallVector rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Value updatedIm2col = - tensor::InsertSliceOp::create(rewriter, nestedLoc, row, im2colAcc, rowOffsets, rowSizes, rowStrides); - yielded.push_back(updatedIm2col); + SmallVector rowOffsets {patchIndex, rewriter.getIndexAttr(0)}; + SmallVector rowSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(plan.patchSize)}; + Value next = tensor::InsertSliceOp::create( + rewriter, nestedLoc, row, im2colAcc, rowOffsets, rowSizes, getUnitStrides(rewriter, 2)); + yielded.push_back(next); return success(); }); if (failed(im2colLoop)) return failure(); - Value im2col = im2colLoop->results.front(); - Value gemmInputRows = im2col; - if (packFactor != 1) - gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc); + Value gemmInputRows = im2colLoop->results.front(); + // Pack N old im2col rows into one longer row so one GEMM can cover N + // pixels in parallel. The corresponding packed weight matrix contains N + // block-diagonal copies of W^T, and the packed output must be unpacked + // back to one row per spatial patch. + if (plan.effectiveMaxParallelPixels != 1) + gemmInputRows = packRowsForParallelGemm(gemmInputRows, plan.im2colType, plan.effectiveMaxParallelPixels, rewriter, loc); spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows); return success(); @@ -369,6 +955,52 @@ static Value createIm2colRowComputes(Value x, return im2colComputeOp->getResult(0); } +static Value rewriteConv(const ConvLoweringState& state, ConversionPatternRewriter& rewriter, Location loc) { + auto wDenseAttr = getHostConstDenseElementsAttr(state.w); + Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); + Value biasMatrix; + DenseElementsAttr biasDenseAttr; + if (state.hasBias) { + gemmBias = state.b; + biasDenseAttr = getHostConstDenseElementsAttr(state.b); + biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc); + } + + ConvGemmPlan plan = + buildConvGemmPlan(state, static_cast(wDenseAttr), !state.hasBias || static_cast(biasDenseAttr)); + // Prepare weight matrix W for crossbar storage: + // W: [Cout, Cin, KH, KW] -> [Cout, patchSize] -> [patchSize, Cout] + Value weightMatrix = createWeightMatrix(state.w, plan, rewriter, loc); + Value gemmInputRows = createIm2colRows(state, plan, rewriter, loc); + Value gemmB = buildPackedWeights(wDenseAttr, weightMatrix, state, plan, rewriter, loc); + Value gemmC = buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, plan, rewriter, loc); + + Value gemmRows = ONNXGemmOp::create(rewriter, + loc, + plan.gemmOutputRowsType, + gemmInputRows, + gemmB, + gemmC, + rewriter.getF32FloatAttr(1.0f), + rewriter.getF32FloatAttr(1.0f), + rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)) + .getY(); + + return createCollectedConvOutput(ValueRange {gemmRows}, + state.outType, + plan.gemmOutType, + plan.nhwcType, + state.outType, + plan.numPatches, + state.numChannelsOut, + plan.effectiveMaxParallelPixels, + rewriter, + loc); +} + +} // namespace standard + static Value createCollectedConvOutput(ValueRange gemmRows, Type convType, RankedTensorType gemmOutType, @@ -386,19 +1018,14 @@ static Value createCollectedConvOutput(ValueRange gemmRows, } else { Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs); - gemmOut = unpackRowsFromParallelGemm(packedOutput, - cast(packedOutput.getType()), - numPatches, - numChannelsOut, - packFactor, - rewriter, - loc); + gemmOut = standard::unpackRowsFromParallelGemm( + packedOutput, cast(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc); } - // Restore to NCHW layout: + // Restore output layout: // [numPatches, numChannelsOut] - // -> [1, outHeight, outWidth, numChannelsOut] - // -> [1, numChannelsOut, outHeight, outWidth] + // -> [N, Hout, Wout, Cout] + // -> [N, Cout, Hout, Wout] Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, loc, nhwcType, @@ -413,232 +1040,82 @@ static Value createCollectedConvOutput(ValueRange gemmRows, return collectComputeOp.getResult(0); } -static Value lowerSingleConvGroup(Value x, - Value w, - Value b, - RankedTensorType xType, - RankedTensorType wType, - RankedTensorType outType, - int64_t padHeightBegin, - int64_t padHeightEnd, - int64_t padWidthBegin, - int64_t padWidthEnd, - int64_t strideHeight, - int64_t strideWidth, - int64_t dilationHeight, - int64_t dilationWidth, - ConversionPatternRewriter& rewriter, - Location loc) { - const int64_t batchSize = xType.getDimSize(0); - const int64_t numChannelsIn = xType.getDimSize(1); - const int64_t xHeight = xType.getDimSize(2); - const int64_t xWidth = xType.getDimSize(3); - const int64_t numChannelsOut = wType.getDimSize(0); - const int64_t wHeight = wType.getDimSize(2); - const int64_t wWidth = wType.getDimSize(3); - const int64_t outHeight = outType.getDimSize(2); - const int64_t outWidth = outType.getDimSize(3); +static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, ONNXConvOpAdaptor convOpAdaptor) { + ConvLoweringState state; + state.x = convOpAdaptor.getX(); + state.w = convOpAdaptor.getW(); + state.b = convOpAdaptor.getB(); + state.xType = cast(state.x.getType()); + state.wType = cast(state.w.getType()); + state.outType = cast(convOp.getY().getType()); - // im2col layout (flipped with respect to the standard, so filters sit in B = crossbar): - // A (im2col): [numPatches, patchSize] -- one row per output spatial position - // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns - // Gemm output: [numPatches, cOut] - const int64_t patchSize = numChannelsIn * wHeight * wWidth; - const int64_t numPatchesPerBatch = outHeight * outWidth; - const int64_t numPatches = batchSize * numPatchesPerBatch; - - auto elemType = xType.getElementType(); - auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType); - auto rowType = RankedTensorType::get({1, patchSize}, elemType); - auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType()); - auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType()); - auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType()); - auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType()); - - const int64_t xbarSize = static_cast(crossbarSize.getValue()); - const int64_t wMaxDim = std::max(patchSize, numChannelsOut); - const int64_t maxParallelPixels = std::max(1, xbarSize / wMaxDim); - auto wDenseAttr = getHostConstDenseElementsAttr(w); - - // Prepare weight matrix W for crossbar storage: - // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut] - Value wTrans = createConvWeightMatrix(w, wFlatType, wTransType, rewriter, loc); - - // Pass bias through directly; Gemm handles rank-1 C canonicalization. - bool hasB = !isa(b.getDefiningOp()); - Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); - Value biasMatrix; - DenseElementsAttr biasDenseAttr; - if (hasB) { - gemmBias = b; - biasDenseAttr = getHostConstDenseElementsAttr(b); - biasMatrix = expandBiasIfNeeded(b, rewriter, loc); - } - const bool canPackWeightsAsConstants = static_cast(wDenseAttr); - const bool canPackBiasAsConstants = !hasB || static_cast(biasDenseAttr); - const int64_t effectiveMaxParallelPixels = - (canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1; - - // Keep the standard im2col view of convolution: - // A (im2col): [numPatches, patchSize] -- one row per output spatial position - // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns - // and optionally repack several old rows into one GEMM row to use the available crossbar size better. - // - // We want to process N pixels at the same time. Instead of doing N separate operations - // of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix - // containing N copies of W^T and concatenate N im2col rows into one longer row: - // A_packed: [ceil(numPatches / N), N * patchSize] - // B_packed: [N * patchSize, N * cOut] - // Y_packed: [ceil(numPatches / N), N * cOut] - const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels); - auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType); - auto gemmOutputRowsType = - RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType()); - Value gemmInputRows = createIm2colRowComputes(x, - xType, - im2colType, - rowType, - gemmInputRowsType, - batchSize, - numChannelsIn, - xHeight, - xWidth, - wHeight, - wWidth, - padHeightBegin, - padHeightEnd, - padWidthBegin, - padWidthEnd, - strideHeight, - strideWidth, - dilationHeight, - dilationWidth, - outWidth, - patchSize, - numPatches, - numPatchesPerBatch, - effectiveMaxParallelPixels, - rewriter, - loc); - - Value gemmB = buildPackedWeight(wDenseAttr, - wTrans, - wType, - numChannelsIn, - numChannelsOut, - wHeight, - wWidth, - patchSize, - effectiveMaxParallelPixels, - rewriter, - loc); - Value gemmC = buildPackedBias( - hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc); - - Value gemmRows = ONNXGemmOp::create(rewriter, - loc, - gemmOutputRowsType, - gemmInputRows, - gemmB, - gemmC, - rewriter.getF32FloatAttr(1.0f), - rewriter.getF32FloatAttr(1.0f), - rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)) - .getY(); - - return createCollectedConvOutput(ValueRange {gemmRows}, - outType, - gemmOutType, - nhwcType, - outType, - numPatches, - numChannelsOut, - effectiveMaxParallelPixels, - rewriter, - loc); -} - -} // namespace - -LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, - ONNXConvOpAdaptor convOpAdaptor, - ConversionPatternRewriter& rewriter) const { - Location loc = convOp.getLoc(); - Value x = convOpAdaptor.getX(); - Value w = convOpAdaptor.getW(); - Value b = convOpAdaptor.getB(); - - auto xType = cast(x.getType()); - auto wType = cast(w.getType()); - auto outType = cast(convOp.getY().getType()); - - if (!xType.hasStaticShape()) { + if (!state.xType.hasStaticShape()) { pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input"); return failure(); } - if (!wType.hasStaticShape()) { + if (!state.wType.hasStaticShape()) { pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight"); return failure(); } - if (!outType.hasStaticShape()) { + if (!state.outType.hasStaticShape()) { pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result"); return failure(); } - if (xType.getRank() != 4) { - pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4}); + if (state.xType.getRank() != 4) { + pim::emitUnsupportedRankDiagnostic(convOp, "conv input", state.xType.getRank(), {4}); return failure(); } - if (wType.getRank() != 4) { - pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4}); + if (state.wType.getRank() != 4) { + pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", state.wType.getRank(), {4}); return failure(); } - if (outType.getRank() != 4) { - pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4}); + if (state.outType.getRank() != 4) { + pim::emitUnsupportedRankDiagnostic(convOp, "conv result", state.outType.getRank(), {4}); return failure(); } - if (convOp.getGroup() < 1) { + + state.group = convOp.getGroup(); + if (state.group < 1) { convOp.emitOpError("requires group >= 1 for Spatial lowering"); return failure(); } - const int64_t batchSize = xType.getDimSize(0); - const int64_t numChannelsIn = xType.getDimSize(1); - const int64_t xHeight = xType.getDimSize(2); - const int64_t xWidth = xType.getDimSize(3); - const int64_t numChannelsOut = wType.getDimSize(0); - const int64_t wHeight = wType.getDimSize(2); - const int64_t wWidth = wType.getDimSize(3); - const int64_t outHeight = outType.getDimSize(2); - const int64_t outWidth = outType.getDimSize(3); - const int64_t group = convOp.getGroup(); - const bool hasB = !isa(b.getDefiningOp()); + state.batchSize = state.xType.getDimSize(0); + state.numChannelsIn = state.xType.getDimSize(1); + state.xHeight = state.xType.getDimSize(2); + state.xWidth = state.xType.getDimSize(3); + state.numChannelsOut = state.wType.getDimSize(0); + state.wHeight = state.wType.getDimSize(2); + state.wWidth = state.wType.getDimSize(3); + state.outHeight = state.outType.getDimSize(2); + state.outWidth = state.outType.getDimSize(3); + state.hasBias = !isa(state.b.getDefiningOp()); - if (numChannelsIn % group != 0) { - convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group + if (state.numChannelsIn % state.group != 0) { + convOp.emitOpError() << "requires input channels " << state.numChannelsIn << " to be divisible by group " + << state.group << " for Spatial lowering"; + return failure(); + } + if (state.numChannelsOut % state.group != 0) { + convOp.emitOpError() << "requires output channels " << state.numChannelsOut << " to be divisible by group " + << state.group << " for Spatial lowering"; + return failure(); + } + + state.numChannelsInPerGroup = state.numChannelsIn / state.group; + state.numChannelsOutPerGroup = state.numChannelsOut / state.group; + if (state.wType.getDimSize(1) != state.numChannelsInPerGroup) { + convOp.emitOpError() << "requires grouped conv weight input channels " << state.wType.getDimSize(1) + << " to match input channels per group " << state.numChannelsInPerGroup << " for Spatial lowering"; return failure(); } - if (numChannelsOut % group != 0) { - convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group - << " for Spatial lowering"; + if (state.wType.getDimSize(0) != state.numChannelsOut) { + convOp.emitOpError() << "requires weight output channels " << state.wType.getDimSize(0) + << " to match result channels " << state.numChannelsOut << " for Spatial lowering"; return failure(); } - const int64_t numChannelsInPerGroup = numChannelsIn / group; - const int64_t numChannelsOutPerGroup = numChannelsOut / group; - if (wType.getDimSize(1) != numChannelsInPerGroup) { - convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1) - << " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering"; - return failure(); - } - if (wType.getDimSize(0) != numChannelsOut) { - convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels " - << numChannelsOut << " for Spatial lowering"; - return failure(); - } - - // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0) const auto stridesAttr = convOp.getStrides(); const auto dilationsAttr = convOp.getDilations(); const auto padsAttr = convOp.getPads(); @@ -656,79 +1133,104 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, return failure(); } - const int64_t strideHeight = getOptionalI64Attr(stridesAttr, 0, 1); - const int64_t strideWidth = getOptionalI64Attr(stridesAttr, 1, 1); - const int64_t dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1); - const int64_t dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1); - - int64_t padHeightBegin = 0; - int64_t padHeightEnd = 0; - int64_t padWidthBegin = 0; - int64_t padWidthEnd = 0; + state.strideHeight = getOptionalI64Attr(stridesAttr, 0, 1); + state.strideWidth = getOptionalI64Attr(stridesAttr, 1, 1); + state.dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1); + state.dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1); + state.padHeightBegin = 0; + state.padHeightEnd = 0; + state.padWidthBegin = 0; + state.padWidthEnd = 0; if (padsAttr) { - padHeightBegin = getI64Attr(*padsAttr, 0); - padWidthBegin = getI64Attr(*padsAttr, 1); - padHeightEnd = getI64Attr(*padsAttr, 2); - padWidthEnd = getI64Attr(*padsAttr, 3); + state.padHeightBegin = getI64Attr(*padsAttr, 0); + state.padWidthBegin = getI64Attr(*padsAttr, 1); + state.padHeightEnd = getI64Attr(*padsAttr, 2); + state.padWidthEnd = getI64Attr(*padsAttr, 3); + return state; } - else { - // Compute padding from auto_pad attribute - const auto autoPad = convOp.getAutoPad(); - if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { - const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1; - const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1; - const int64_t totalPadH = - std::max(static_cast(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight); - const int64_t totalPadW = - std::max(static_cast(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth); - if (autoPad == "SAME_UPPER") { - padHeightBegin = totalPadH / 2; - padHeightEnd = totalPadH - padHeightBegin; - padWidthBegin = totalPadW / 2; - padWidthEnd = totalPadW - padWidthBegin; - } - else { // SAME_LOWER - padHeightEnd = totalPadH / 2; - padHeightBegin = totalPadH - padHeightEnd; - padWidthEnd = totalPadW / 2; - padWidthBegin = totalPadW - padWidthEnd; - } + const auto autoPad = convOp.getAutoPad(); + if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { + const int64_t effectiveKernelH = (state.wHeight - 1) * state.dilationHeight + 1; + const int64_t effectiveKernelW = (state.wWidth - 1) * state.dilationWidth + 1; + const int64_t totalPadH = + std::max(static_cast(0), (state.outHeight - 1) * state.strideHeight + effectiveKernelH - state.xHeight); + const int64_t totalPadW = + std::max(static_cast(0), (state.outWidth - 1) * state.strideWidth + effectiveKernelW - state.xWidth); + + if (autoPad == "SAME_UPPER") { + state.padHeightBegin = totalPadH / 2; + state.padHeightEnd = totalPadH - state.padHeightBegin; + state.padWidthBegin = totalPadW / 2; + state.padWidthEnd = totalPadW - state.padWidthBegin; } - else if (autoPad != "NOTSET" && autoPad != "VALID") { - convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering"; - return failure(); + else { + state.padHeightEnd = totalPadH / 2; + state.padHeightBegin = totalPadH - state.padHeightEnd; + state.padWidthEnd = totalPadW / 2; + state.padWidthBegin = totalPadW - state.padWidthEnd; } - // "NOTSET" or "VALID" -> all pads stay 0 + return state; } - if (group == 1) { - rewriter.replaceOp(convOp, - lowerSingleConvGroup(x, - w, - b, - xType, - wType, - outType, - padHeightBegin, - padHeightEnd, - padWidthBegin, - padWidthEnd, - strideHeight, - strideWidth, - dilationHeight, - dilationWidth, - rewriter, - loc)); - return success(); + if (autoPad != "NOTSET" && autoPad != "VALID") { + convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering"; + return failure(); } - SmallVector xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc); - SmallVector wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc); + return state; +} + +static LogicalResult +rewriteUngroupedConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPatternRewriter& rewriter) { + rewriter.replaceOp(convOp, standard::rewriteConv(state, rewriter, convOp.getLoc())); + return success(); +} + +static LogicalResult +rewriteDepthwiseConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPatternRewriter& rewriter) { + FailureOr result = depthwise::rewriteConv(convOp, state, rewriter, convOp.getLoc()); + if (failed(result)) + return failure(); + + rewriter.replaceOp(convOp, *result); + return success(); +} + +static ConvLoweringState makeGroupedConvLoweringState( + const ConvLoweringState& parent, Value groupX, Value groupW, Value groupB, RankedTensorType groupOutType) { + ConvLoweringState state = parent; + state.x = groupX; + state.w = groupW; + state.b = groupB; + state.xType = cast(groupX.getType()); + state.wType = cast(groupW.getType()); + state.outType = groupOutType; + state.batchSize = state.xType.getDimSize(0); + state.numChannelsIn = state.xType.getDimSize(1); + state.xHeight = state.xType.getDimSize(2); + state.xWidth = state.xType.getDimSize(3); + state.numChannelsOut = state.wType.getDimSize(0); + state.wHeight = state.wType.getDimSize(2); + state.wWidth = state.wType.getDimSize(3); + state.outHeight = state.outType.getDimSize(2); + state.outWidth = state.outType.getDimSize(3); + state.group = 1; + state.numChannelsInPerGroup = state.numChannelsIn; + state.numChannelsOutPerGroup = state.numChannelsOut; + state.hasBias = !isa(groupB.getDefiningOp()); + return state; +} + +static LogicalResult +rewriteGroupedConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPatternRewriter& rewriter) { + SmallVector xSlices = sliceTensor(state.x, /*axis=*/1, state.numChannelsInPerGroup, rewriter, convOp.getLoc()); + SmallVector wSlices = + sliceTensor(state.w, /*axis=*/0, state.numChannelsOutPerGroup, rewriter, convOp.getLoc()); SmallVector bSlices; - if (hasB) { - auto biasType = cast(b.getType()); + if (state.hasBias) { + auto biasType = cast(state.b.getType()); int64_t biasAxis = -1; if (biasType.getRank() == 1) biasAxis = 0; @@ -739,50 +1241,38 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, << biasType.getRank(); return failure(); } - bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc); + bSlices = sliceTensor(state.b, biasAxis, state.numChannelsOutPerGroup, rewriter, convOp.getLoc()); } - if (xSlices.size() != static_cast(group) || wSlices.size() != static_cast(group) - || (hasB && bSlices.size() != static_cast(group))) { + if (xSlices.size() != static_cast(state.group) || wSlices.size() != static_cast(state.group) + || (state.hasBias && bSlices.size() != static_cast(state.group))) { convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering"); return failure(); } SmallVector groupResults; - groupResults.reserve(group); - auto groupOutType = - RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType()); - Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); - for (int64_t groupId = 0; groupId < group; groupId++) { + groupResults.reserve(state.group); + auto groupOutType = RankedTensorType::get( + {state.batchSize, state.numChannelsOutPerGroup, state.outHeight, state.outWidth}, state.outType.getElementType()); + Value noBias = ONNXNoneOp::create(rewriter, convOp.getLoc(), rewriter.getNoneType()); + for (int64_t groupId = 0; groupId < state.group; groupId++) { Value groupX = xSlices[groupId]; Value groupW = wSlices[groupId]; - Value groupB = hasB ? bSlices[groupId] : noBias; - groupResults.push_back(lowerSingleConvGroup(groupX, - groupW, - groupB, - cast(groupX.getType()), - cast(groupW.getType()), - groupOutType, - padHeightBegin, - padHeightEnd, - padWidthBegin, - padWidthEnd, - strideHeight, - strideWidth, - dilationHeight, - dilationWidth, - rewriter, - loc)); + Value groupB = state.hasBias ? bSlices[groupId] : noBias; + ConvLoweringState groupState = makeGroupedConvLoweringState(state, groupX, groupW, groupB, groupOutType); + groupResults.push_back(standard::rewriteConv(groupState, rewriter, convOp.getLoc())); } Value result; if (llvm::all_of(groupResults, isCompileTimeComputable)) { - result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults); + result = createSpatConcat(rewriter, convOp.getLoc(), /*axis=*/1, groupResults); } else { - auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) { - spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args)); - }); + auto concatCompute = + createSpatCompute(rewriter, convOp.getLoc(), TypeRange {state.outType}, {}, groupResults, [&](ValueRange args) { + spatial::SpatYieldOp::create( + rewriter, convOp.getLoc(), createSpatConcat(rewriter, convOp.getLoc(), /*axis=*/1, args)); + }); result = concatCompute.getResult(0); } @@ -790,6 +1280,27 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, return success(); } +} // namespace + +LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, + ONNXConvOpAdaptor convOpAdaptor, + ConversionPatternRewriter& rewriter) const { + FailureOr state = analyzeConvLoweringState(convOp, convOpAdaptor); + if (failed(state)) + return failure(); + + if (isDepthwiseConv(state->group, state->numChannelsIn, state->numChannelsOut, state->numChannelsInPerGroup)) { + if (depthwise::canUseStructuredRewrite(*state)) + return rewriteDepthwiseConv(convOp, *state, rewriter); + return rewriteGroupedConv(convOp, *state, rewriter); + } + + if (state->group == 1) + return rewriteUngroupedConv(convOp, *state, rewriter); + + return rewriteGroupedConv(convOp, *state, rewriter); +} + void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } } // namespace onnx_mlir diff --git a/validation/operations/conv/grouped_many_groups/conv_grouped_many_groups.onnx b/validation/operations/conv/grouped_many_groups/conv_grouped_many_groups.onnx index 7df7a8e..53e138f 100644 Binary files a/validation/operations/conv/grouped_many_groups/conv_grouped_many_groups.onnx and b/validation/operations/conv/grouped_many_groups/conv_grouped_many_groups.onnx differ diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py index 49f491b..a1fdf18 100644 --- a/validation/operations/gen_tests.py +++ b/validation/operations/gen_tests.py @@ -262,7 +262,7 @@ def conv_grouped_many_groups(): X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1024, 2, 2]) Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 2, 2]) W = numpy_helper.from_array( - np.random.default_rng(77).uniform(-1, 1, (1024, 64, 1, 1)).astype(np.float32), name="W") + np.random.default_rng(77).uniform(-1, 1, (1024, 16, 1, 1)).astype(np.float32), name="W") node = helper.make_node("Conv", ["X", "W"], ["Y"], kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0], group=64) graph = helper.make_graph([node], "conv_grouped_many_groups", [X], [Y], initializer=[W])