diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 8a3670f..569f404 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -1334,8 +1334,10 @@ static Value createCollectedConvOutput(ValueRange gemmRows, int64_t numPatches, int64_t numChannelsOut, int64_t packFactor, + ArrayRef distributedConsumers, ConversionPatternRewriter& rewriter, Location loc); +static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b); namespace depthwise { @@ -1637,15 +1639,17 @@ static bool canUseStructuredRewrite(const ConvLoweringState& state) { if (!getHostConstDenseElementsAttr(state.w)) return false; - if (!computeTiling(state.batchSize, - state.numChannelsIn, - state.numChannelsOut, - state.wHeight, - state.wWidth, - state.outHeight, - state.outWidth)) { + auto tiling = computeTiling(state.batchSize, + state.numChannelsIn, + state.numChannelsOut, + state.wHeight, + state.wWidth, + state.outHeight, + state.outWidth); + if (!tiling) + return false; + if (tiling->numChannelTiles > static_cast(crossbarCountInCore.getValue())) return false; - } if (isa(state.b.getDefiningOp())) return true; @@ -1711,19 +1715,51 @@ rewriteConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPattern auto inputTileType = RankedTensorType::get({1, tiling->channelsPerTile, state.wType.getDimSize(2), state.wType.getDimSize(3)}, paddedInputType.getElementType()); + SmallVector batchWeights; + if (tiling->numChannelTiles == 1) { + Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + batchWeights.push_back(createWeightTile(packedWeights, + c0, + cast(packedWeights.getType()), + *tiling, + rewriter, + loc)); + } + else { + batchWeights.push_back(packedWeights); + } auto batchOp = createSpatComputeBatch( rewriter, loc, TypeRange {piecesType}, tiling->totalPatches * tiling->numChannelTiles, - ValueRange {packedWeights}, + batchWeights, batchInputs, [&](detail::SpatComputeBatchBodyArgs args) { + auto pickInputByRank = [&](int64_t rank) -> Value { + for (Value input : args.inputs) { + auto inputType = dyn_cast(input.getType()); + if (inputType && inputType.getRank() == rank) + return input; + } + return Value(); + }; + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); - Value patchIndex = affineModConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp); - Value channelTileIndex = affineFloorDivConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp); - Value inputTile = createInputTile(args.inputs.front(), + Value patchIndex = tiling->numChannelTiles == 1 + ? args.lane + : affineModConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp); + Value channelTileIndex = tiling->numChannelTiles == 1 + ? getOrCreateIndexConstant(rewriter, anchorOp, 0) + : affineFloorDivConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp); + Value paddedInputArg = pickInputByRank(/*rank=*/4); + if (!paddedInputArg) { + convOp.emitOpError("structured depthwise batch body requires a rank-4 padded input block argument"); + return failure(); + } + + Value inputTile = createInputTile(paddedInputArg, patchIndex, channelTileIndex, inputTileType, @@ -1735,24 +1771,31 @@ rewriteConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPattern state.outType.getDimSize(3), rewriter, loc); - Value weightTile = createWeightTile(args.weights.front(), - channelTileIndex, - cast(args.weights.front().getType()), - *tiling, - rewriter, - loc); + Value weightTile = tiling->numChannelTiles == 1 + ? args.weights.front() + : createWeightTile(args.weights.front(), + channelTileIndex, + cast(args.weights.front().getType()), + *tiling, + rewriter, + loc); auto rowTileType = RankedTensorType::get({1, tiling->tileOutputChannels}, state.outType.getElementType()); Value rowTile = spatial::SpatVMMOp::create(rewriter, loc, rowTileType, weightTile, inputTile).getResult(); if (args.inputs.size() > 1) { - Value biasTile = createBiasTile(args.inputs[1], channelTileIndex, *tiling, rewriter, loc); + Value biasArg = pickInputByRank(/*rank=*/2); + if (!biasArg) { + convOp.emitOpError("structured depthwise batch body requires a rank-2 bias block argument when bias is present"); + return failure(); + } + Value biasTile = tiling->numChannelTiles == 1 ? biasArg : createBiasTile(biasArg, channelTileIndex, *tiling, rewriter, loc); rowTile = spatial::SpatVAddOp::create(rewriter, loc, rowTileType, rowTile, biasTile).getResult(); } - SmallVector outputOffsets {args.lane, rewriter.getIndexAttr(0)}; SmallVector outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tiling->tileOutputChannels)}; createParallelInsertSliceIntoBatchOutput( rewriter, loc, rowTile, args.outputs.front(), outputOffsets, outputSizes, getUnitStrides(rewriter, 2)); + return success(); }); if (failed(batchOp)) return failure(); @@ -1760,12 +1803,16 @@ rewriteConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPattern auto nhwcType = RankedTensorType::get( {state.xType.getDimSize(0), state.outType.getDimSize(2), state.outType.getDimSize(3), state.outType.getDimSize(1)}, state.outType.getElementType()); - auto collectedRows = - reconstructDepthwiseGemmRows(batchOp->getResult(0), piecesType, gemmOutType, *tiling, rewriter, loc); - if (failed(collectedRows)) - return failure(); + Value collectedRows = batchOp->getResult(0); + if (tiling->numChannelTiles != 1) { + auto reconstructedRows = + reconstructDepthwiseGemmRows(batchOp->getResult(0), piecesType, gemmOutType, *tiling, rewriter, loc); + if (failed(reconstructedRows)) + return failure(); + collectedRows = *reconstructedRows; + } - return createCollectedConvOutput(ValueRange {*collectedRows}, + return createCollectedConvOutput(ValueRange {collectedRows}, state.outType, gemmOutType, nhwcType, @@ -1773,6 +1820,7 @@ rewriteConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPattern tiling->totalPatches, state.outType.getDimSize(1), /*packFactor=*/1, + {}, rewriter, loc); } @@ -1784,7 +1832,9 @@ namespace standard { struct ConvGemmPlan { int64_t patchSize; int64_t numPatchesPerBatch; - int64_t numPatches; + int64_t globalNumPatches; + int64_t chunkStart; + int64_t chunkNumPatches; int64_t maxParallelPixels; int64_t effectiveMaxParallelPixels; int64_t packedNumRows; @@ -1799,6 +1849,38 @@ struct ConvGemmPlan { RankedTensorType nhwcType; }; +static ConvGemmPlan +buildConvGemmPlan(const ConvLoweringState& state, + bool canPackWeightsAsConstants, + bool canPackBiasAsConstants, + int64_t chunkStart, + int64_t chunkNumPatches, + std::optional forcedPackFactor = std::nullopt); + +static PreparedConvInput prepareInputForIm2Col(const ConvLoweringState& state, + ConversionPatternRewriter& rewriter, + Location loc) { + if (state.padHeightBegin == 0 && state.padHeightEnd == 0 && state.padWidthBegin == 0 && state.padWidthEnd == 0) + return {state.x, state.xType}; + + auto paddedType = RankedTensorType::get({state.batchSize, + state.numChannelsIn, + state.xHeight + state.padHeightBegin + state.padHeightEnd, + state.xWidth + state.padWidthBegin + state.padWidthEnd}, + state.xType.getElementType()); + auto paddedInputOp = + createSpatCompute<1>(rewriter, loc, TypeRange {paddedType}, {}, state.x, [&](Value inputArg) { + Value paddedInput = createZeroPaddedTensor(inputArg, + paddedType, + {0, 0, state.padHeightBegin, state.padWidthBegin}, + {0, 0, state.padHeightEnd, state.padWidthEnd}, + rewriter, + loc); + spatial::SpatYieldOp::create(rewriter, loc, paddedInput); + }); + return {paddedInputOp.getResult(0), paddedType}; +} + static Value createPaddedRows(Value rows, RankedTensorType rowsType, int64_t paddedRows, @@ -1913,6 +1995,255 @@ static Value createWeightMatrix( return computeOp.getResult(0); } +static Value createPaddedConvMatrix(Value matrix, + RankedTensorType sourceType, + RankedTensorType paddedType, + ConversionPatternRewriter& rewriter, + Location loc) { + if (sourceType == paddedType) + return matrix; + return createZeroPaddedTensor(matrix, + paddedType, + {0, 0}, + {paddedType.getDimSize(0) - sourceType.getDimSize(0), + paddedType.getDimSize(1) - sourceType.getDimSize(1)}, + rewriter, + loc); +} + +static Value createPaddedConstantMatrix(DenseElementsAttr sourceAttr, + RankedTensorType sourceType, + RankedTensorType paddedType, + ConversionPatternRewriter& rewriter) { + SmallVector paddedValues( + paddedType.getNumElements(), cast(rewriter.getZeroAttr(paddedType.getElementType()))); + SmallVector sourceValues(sourceAttr.getValues()); + const int64_t sourceRows = sourceType.getDimSize(0); + const int64_t sourceCols = sourceType.getDimSize(1); + const int64_t paddedCols = paddedType.getDimSize(1); + for (int64_t row = 0; row < sourceRows; ++row) + for (int64_t col = 0; col < sourceCols; ++col) + paddedValues[row * paddedCols + col] = sourceValues[row * sourceCols + col]; + auto paddedAttr = DenseElementsAttr::get(paddedType, paddedValues); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), paddedAttr, paddedType); +} + +static Value createPaddedInputKTiledWeightConstant(DenseElementsAttr sourceAttr, + const ConvLoweringState& state, + int64_t paddedK, + int64_t paddedC, + ConversionPatternRewriter& rewriter) { + auto paddedType = RankedTensorType::get({paddedK, paddedC}, state.wType.getElementType()); + SmallVector sourceValues(sourceAttr.getValues()); + SmallVector paddedValues( + paddedType.getNumElements(), cast(rewriter.getZeroAttr(paddedType.getElementType()))); + for (int64_t outChannel = 0; outChannel < state.numChannelsOut; ++outChannel) { + for (int64_t inChannel = 0; inChannel < state.numChannelsIn; ++inChannel) { + for (int64_t kernelH = 0; kernelH < state.wHeight; ++kernelH) { + for (int64_t kernelW = 0; kernelW < state.wWidth; ++kernelW) { + const int64_t sourceFlatIndex = + (((outChannel * state.numChannelsIn) + inChannel) * state.wHeight + kernelH) * state.wWidth + kernelW; + const int64_t patchIndex = ((inChannel * state.wHeight) + kernelH) * state.wWidth + kernelW; + paddedValues[patchIndex * paddedC + outChannel] = sourceValues[sourceFlatIndex]; + } + } + } + } + auto paddedAttr = DenseElementsAttr::get(paddedType, paddedValues); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), paddedAttr, paddedType); +} + +static FailureOr rewriteInputKTiledConv(const ConvLoweringState& state, + ArrayRef distributedConsumers, + ConversionPatternRewriter& rewriter, + Location loc) { + PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc); + ConvGeometry geo = buildConvGeometry(state); + const int64_t xbarDim = geo.xbarSize; + const int64_t numKSlices = ceilIntegerDivide(geo.k, xbarDim); + const int64_t paddedK = numKSlices * xbarDim; + const uint64_t maxLanesPerBatch = + std::max(1, + static_cast(crossbarCountInCore.getValue()) + / static_cast(std::max(1, numKSlices * 4))); + const uint64_t rowChunkWidth = std::max( + 1, + std::min({chooseStreamChunkPositions(geo, /*packFactor=*/1), + maxLanesPerBatch, + static_cast(state.outWidth)})); + const auto elementType = state.outType.getElementType(); + auto wDenseAttr = getHostConstDenseElementsAttr(state.w); + if (!wDenseAttr) + return failure(); + + Value paddedWeight = createPaddedInputKTiledWeightConstant(wDenseAttr, state, paddedK, xbarDim, rewriter); + + Value paddedBias; + RankedTensorType paddedBiasType; + if (state.hasBias) { + Value biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc); + auto biasMatrixType = cast(biasMatrix.getType()); + paddedBiasType = RankedTensorType::get({1, xbarDim}, elementType); + if (auto biasDenseAttr = getHostConstDenseElementsAttr(state.b)) + paddedBias = createPaddedConstantMatrix(biasDenseAttr, biasMatrixType, paddedBiasType, rewriter); + else + paddedBias = materializeOrComputeUnary( + biasMatrix, paddedBiasType, rewriter, loc, [&](Value biasValue) { + return createPaddedConvMatrix(biasValue, biasMatrixType, paddedBiasType, rewriter, loc); + }); + } + + SmallVector chunkRows; + const int64_t totalPatches = state.batchSize * state.outHeight * state.outWidth; + chunkRows.reserve( + state.batchSize * state.outHeight * ceilIntegerDivide(state.outWidth, static_cast(rowChunkWidth))); + for (int64_t batchIndex = 0; batchIndex < state.batchSize; ++batchIndex) { + for (int64_t outHeightIndex = 0; outHeightIndex < state.outHeight; ++outHeightIndex) { + for (int64_t outWidthChunkStart = 0; outWidthChunkStart < state.outWidth; + outWidthChunkStart += static_cast(rowChunkWidth)) { + const int64_t chunkNumPatches = + std::min(static_cast(rowChunkWidth), state.outWidth - outWidthChunkStart); + auto chunkRowsType = RankedTensorType::get({chunkNumPatches, state.numChannelsOut}, elementType); + auto paddedRowType = RankedTensorType::get({1, xbarDim}, elementType); + auto paddedChunkRowType = RankedTensorType::get({1, paddedK}, elementType); + auto patchType = RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elementType); + auto collapsedPatchType = RankedTensorType::get({1, geo.k}, elementType); + auto weightTileType = RankedTensorType::get({xbarDim, xbarDim}, state.wType.getElementType()); + auto rowType = RankedTensorType::get({1, state.numChannelsOut}, elementType); + SmallVector inputsStorage {preparedInput.value}; + if (state.hasBias) + inputsStorage.push_back(paddedBias); + ValueRange inputs(inputsStorage); + auto chunkCompute = spatial::SpatCompute::create(rewriter, loc, TypeRange {chunkRowsType}, ValueRange {paddedWeight}, inputs); + auto* block = new Block(); + block->addArgument(paddedWeight.getType(), loc); + for (Value input : inputs) + block->addArgument(input.getType(), loc); + chunkCompute.getBody().push_back(block); + rewriter.setInsertionPointToStart(block); + + auto buildChunk = [&]() -> LogicalResult { + Value weightArg = block->getArgument(0); + Value inputArg = block->getArgument(1); + Value biasArg = state.hasBias ? block->getArgument(2) : Value(); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value cBatchIndex = getOrCreateIndexConstant(rewriter, anchorOp, batchIndex); + Value cZero = getOrCreateIndexConstant(rewriter, anchorOp, 0); + Value cKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices); + Value cOne = getOrCreateIndexConstant(rewriter, anchorOp, 1); + Value cXbar = getOrCreateIndexConstant(rewriter, anchorOp, xbarDim); + Value cInputHeightOffset = + getOrCreateIndexConstant(rewriter, anchorOp, outHeightIndex * state.strideHeight); + Value chunkRowsValue = tensor::EmptyOp::create(rewriter, loc, chunkRowsType.getShape(), elementType); + + auto widthLoop = buildNormalizedScfFor( + rewriter, + loc, + cZero, + getOrCreateIndexConstant(rewriter, anchorOp, chunkNumPatches), + cOne, + ValueRange {chunkRowsValue}, + [&](OpBuilder&, Location nestedLoc, Value widthIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value laneWithChunkOffset = affineAddConst(rewriter, nestedLoc, widthIndex, outWidthChunkStart, anchorOp); + Value inputWidthOffset = createOrFoldAffineApply(rewriter, + nestedLoc, + getAffineDimExpr(0, rewriter.getContext()) * state.strideWidth, + ValueRange {laneWithChunkOffset}, + anchorOp); + Value patch = createConvInputPatch(inputArg, + patchType, + cBatchIndex, + cZero, + cInputHeightOffset, + inputWidthOffset, + state.dilationHeight, + state.dilationWidth, + rewriter, + nestedLoc); + Value patchRow = tensor::CollapseShapeOp::create(rewriter, + nestedLoc, + collapsedPatchType, + patch, + SmallVector { + {0}, + {1, 2, 3} + }); + Value paddedPatchRow = createZeroPaddedTensor( + patchRow, paddedChunkRowType, {0, 0}, {0, paddedK - geo.k}, rewriter, nestedLoc); + + auto zeroAttr = DenseElementsAttr::get(paddedRowType, rewriter.getZeroAttr(elementType)); + Value zeroRow = getOrCreateConstant(rewriter, anchorOp, zeroAttr, paddedRowType); + auto kLoop = buildNormalizedScfFor( + rewriter, + nestedLoc, + cZero, + cKSlices, + cOne, + ValueRange {zeroRow}, + [&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl& reduceYielded) { + Value acc = reduceIterArgs.front(); + Value kOffset = arith::MulIOp::create(rewriter, reduceLoc, kSlice, cXbar); + SmallVector aOffsets {rewriter.getIndexAttr(0), kOffset}; + SmallVector aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)}; + SmallVector unitStrides = getUnitStrides(rewriter, 2); + Value aTile = tensor::ExtractSliceOp::create( + rewriter, reduceLoc, paddedRowType, paddedPatchRow, aOffsets, aSizes, unitStrides); + SmallVector bOffsets {kOffset, rewriter.getIndexAttr(0)}; + SmallVector bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)}; + Value bTile = tensor::ExtractSliceOp::create( + rewriter, reduceLoc, weightTileType, weightArg, bOffsets, bSizes, unitStrides); + Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult(); + reduceYielded.push_back( + spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, acc, piece).getResult()); + return success(); + }); + if (failed(kLoop)) + return failure(); + + Value reduced = kLoop->results.front(); + if (state.hasBias) + reduced = spatial::SpatVAddOp::create(rewriter, nestedLoc, paddedRowType, reduced, biasArg).getResult(); + + Value row = reduced; + if (state.numChannelsOut != xbarDim) { + SmallVector rowOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector rowSizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)}; + row = tensor::ExtractSliceOp::create( + rewriter, nestedLoc, rowType, reduced, rowOffsets, rowSizes, getUnitStrides(rewriter, 2)); + } + + SmallVector outputOffsets {widthIndex, rewriter.getIndexAttr(0)}; + SmallVector outputSizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)}; + Value updatedRows = tensor::InsertSliceOp::create( + rewriter, nestedLoc, row, iterArgs.front(), outputOffsets, outputSizes, getUnitStrides(rewriter, 2)); + yielded.push_back(updatedRows); + return success(); + }); + if (failed(widthLoop)) + return failure(); + spatial::SpatYieldOp::create(rewriter, loc, widthLoop->results.front()); + return success(); + }; + if (failed(buildChunk())) { + rewriter.setInsertionPointAfter(chunkCompute); + rewriter.eraseOp(chunkCompute); + return failure(); + } + rewriter.setInsertionPointAfter(chunkCompute); + chunkRows.push_back(chunkCompute.getResult(0)); + } + } + } + + auto nhwcType = RankedTensorType::get({state.batchSize, state.outHeight, state.outWidth, state.numChannelsOut}, + elementType); + return createCollectedConvOutput( + chunkRows, state.outType, cast(chunkRows.front().getType()), nhwcType, state.outType, totalPatches, + state.numChannelsOut, /*packFactor=*/1, distributedConsumers, rewriter, loc); +} + static Value buildPackedWeights(DenseElementsAttr wDenseAttr, Value wTrans, const ConvLoweringState& state, @@ -1976,26 +2307,35 @@ static Value buildPackedBias(Value gemmBias, } static ConvGemmPlan -buildConvGemmPlan(const ConvLoweringState& state, bool canPackWeightsAsConstants, bool canPackBiasAsConstants) { +buildConvGemmPlan(const ConvLoweringState& state, + bool canPackWeightsAsConstants, + bool canPackBiasAsConstants, + int64_t chunkStart, + int64_t chunkNumPatches, + std::optional forcedPackFactor) { ConvGemmPlan plan; plan.patchSize = state.numChannelsIn * state.wHeight * state.wWidth; plan.numPatchesPerBatch = state.outHeight * state.outWidth; - plan.numPatches = state.batchSize * plan.numPatchesPerBatch; + plan.globalNumPatches = state.batchSize * plan.numPatchesPerBatch; + plan.chunkStart = chunkStart; + plan.chunkNumPatches = chunkNumPatches; const int64_t wMaxDim = std::max(plan.patchSize, state.numChannelsOut); - plan.maxParallelPixels = std::max(1, static_cast(crossbarSize.getValue()) / wMaxDim); + plan.maxParallelPixels = forcedPackFactor + ? *forcedPackFactor + : std::max(1, static_cast(crossbarSize.getValue()) / wMaxDim); plan.effectiveMaxParallelPixels = (canPackWeightsAsConstants && canPackBiasAsConstants) ? plan.maxParallelPixels : 1; - plan.packedNumRows = ceilIntegerDivide(plan.numPatches, plan.effectiveMaxParallelPixels); + plan.packedNumRows = ceilIntegerDivide(plan.chunkNumPatches, plan.effectiveMaxParallelPixels); auto elemType = state.xType.getElementType(); auto outElemType = state.outType.getElementType(); - plan.im2colType = RankedTensorType::get({plan.numPatches, plan.patchSize}, elemType); + plan.im2colType = RankedTensorType::get({plan.chunkNumPatches, plan.patchSize}, elemType); plan.im2colRowType = RankedTensorType::get({1, plan.patchSize}, elemType); plan.gemmInputRowsType = RankedTensorType::get({plan.packedNumRows, plan.effectiveMaxParallelPixels * plan.patchSize}, elemType); plan.wFlatType = RankedTensorType::get({state.numChannelsOut, plan.patchSize}, state.wType.getElementType()); plan.wTransType = RankedTensorType::get({plan.patchSize, state.numChannelsOut}, state.wType.getElementType()); - plan.gemmOutType = RankedTensorType::get({plan.numPatches, state.numChannelsOut}, outElemType); + plan.gemmOutType = RankedTensorType::get({plan.chunkNumPatches, state.numChannelsOut}, outElemType); plan.gemmOutputRowsType = RankedTensorType::get({plan.packedNumRows, plan.effectiveMaxParallelPixels * state.numChannelsOut}, outElemType); plan.nhwcType = @@ -2003,28 +2343,15 @@ buildConvGemmPlan(const ConvLoweringState& state, bool canPackWeightsAsConstants return plan; } -static Value createIm2colRows( - const ConvLoweringState& state, const ConvGemmPlan& plan, ConversionPatternRewriter& rewriter, Location loc) { +static Value createIm2colRows(const ConvLoweringState& state, + const PreparedConvInput& preparedInput, + const ConvGemmPlan& plan, + ConversionPatternRewriter& rewriter, + Location loc) { constexpr size_t numInputs = 1; auto im2colComputeOp = - createSpatCompute(rewriter, loc, TypeRange {plan.gemmInputRowsType}, {}, state.x, [&](Value xArg) { - auto elemType = state.xType.getElementType(); - Value paddedInput = xArg; - if (state.padHeightBegin || state.padHeightEnd || state.padWidthBegin || state.padWidthEnd) { - auto paddedInputType = RankedTensorType::get( - {state.batchSize, - state.numChannelsIn, - state.xHeight + state.padHeightBegin + state.padHeightEnd, - state.xWidth + state.padWidthBegin + state.padWidthEnd}, - elemType); - paddedInput = createZeroPaddedTensor(paddedInput, - paddedInputType, - {0, 0, state.padHeightBegin, state.padWidthBegin}, - {0, 0, state.padHeightEnd, state.padWidthEnd}, - rewriter, - loc); - } - + createSpatCompute(rewriter, loc, TypeRange {plan.gemmInputRowsType}, {}, preparedInput.value, [&](Value xArg) { + auto elemType = preparedInput.type.getElementType(); // Keep the standard im2col view of convolution, flipped so filters sit in // B / crossbar columns: // A (im2col): [numPatches, patchSize] -- one row per output spatial position @@ -2034,7 +2361,8 @@ static Value createIm2colRows( Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); - Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.numPatches); + Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkNumPatches); + Value cChunkStart = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkStart); Value cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, plan.numPatchesPerBatch); Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth); Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, state.strideHeight); @@ -2049,8 +2377,9 @@ static Value createIm2colRows( ValueRange {im2colInit}, [&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { Value im2colAcc = iterArgs.front(); - Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, patchIndex, cNumPatchesPerBatch); - Value batchPatchIndex = arith::RemUIOp::create(rewriter, nestedLoc, patchIndex, cNumPatchesPerBatch); + Value globalPatchIndex = arith::AddIOp::create(rewriter, nestedLoc, patchIndex, cChunkStart); + Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, globalPatchIndex, cNumPatchesPerBatch); + Value batchPatchIndex = arith::RemUIOp::create(rewriter, nestedLoc, globalPatchIndex, cNumPatchesPerBatch); Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth); Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth); Value inputHeightOffset = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight); @@ -2058,7 +2387,7 @@ static Value createIm2colRows( auto patchType = RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elemType); - Value patch = createConvInputPatch(paddedInput, + Value patch = createConvInputPatch(xArg, patchType, batchIndex, c0, @@ -2103,8 +2432,81 @@ static Value createIm2colRows( return im2colComputeOp->getResult(0); } -static Value rewriteConv(const ConvLoweringState& state, ConversionPatternRewriter& rewriter, Location loc) { +static Value maybeUnpackChunkRows(Value gemmRows, + const ConvGemmPlan& plan, + ConversionPatternRewriter& rewriter, + Location loc) { + if (plan.effectiveMaxParallelPixels == 1) + return gemmRows; + auto unpackedType = RankedTensorType::get( + {plan.chunkNumPatches, plan.gemmOutType.getDimSize(1)}, plan.gemmOutType.getElementType(), plan.gemmOutType.getEncoding()); + auto unpackCompute = createSpatCompute<1>(rewriter, loc, TypeRange {unpackedType}, {}, gemmRows, [&](Value rowsArg) { + Value unpacked = unpackRowsFromParallelGemm(rowsArg, + cast(rowsArg.getType()), + plan.chunkNumPatches, + plan.gemmOutType.getDimSize(1), + plan.effectiveMaxParallelPixels, + rewriter, + loc); + spatial::SpatYieldOp::create(rewriter, loc, unpacked); + }); + return unpackCompute.getResult(0); +} + +static Value createChunkedConvRows(const ConvLoweringState& state, + const PreparedConvInput& preparedInput, + Value weightMatrix, + Value gemmBias, + Value biasMatrix, + DenseElementsAttr wDenseAttr, + DenseElementsAttr biasDenseAttr, + int64_t forcedPackFactor, + uint64_t chunkPositions, + ConversionPatternRewriter& rewriter, + Location loc) { + SmallVector chunkRows; + const int64_t totalPatches = state.batchSize * state.outHeight * state.outWidth; + for (int64_t chunkStart = 0; chunkStart < totalPatches; chunkStart += static_cast(chunkPositions)) { + const int64_t chunkNumPatches = std::min(static_cast(chunkPositions), totalPatches - chunkStart); + ConvGemmPlan chunkPlan = buildConvGemmPlan(state, + static_cast(wDenseAttr), + !state.hasBias || static_cast(biasDenseAttr), + chunkStart, + chunkNumPatches, + forcedPackFactor); + Value chunkInputRows = createIm2colRows(state, preparedInput, chunkPlan, rewriter, loc); + Value chunkB = buildPackedWeights(wDenseAttr, weightMatrix, state, chunkPlan, rewriter, loc); + Value chunkC = buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, chunkPlan, rewriter, loc); + Value chunkGemmRows = ONNXGemmOp::create(rewriter, + loc, + chunkPlan.gemmOutputRowsType, + chunkInputRows, + chunkB, + chunkC, + rewriter.getF32FloatAttr(1.0f), + rewriter.getF32FloatAttr(1.0f), + rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)) + .getY(); + chunkRows.push_back(maybeUnpackChunkRows(chunkGemmRows, chunkPlan, rewriter, loc)); + } + + if (chunkRows.size() == 1) + return chunkRows.front(); + + auto rowType = RankedTensorType::get({totalPatches, state.numChannelsOut}, state.outType.getElementType()); + auto collectRows = createSpatCompute(rewriter, loc, TypeRange {rowType}, {}, chunkRows, [&](ValueRange rows) { + spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, rows)); + }); + return collectRows.getResult(0); +} + +static Value rewritePackedIm2ColConv(const ConvLoweringState& state, + ArrayRef distributedConsumers, + ConversionPatternRewriter& rewriter, + Location loc) { auto wDenseAttr = getHostConstDenseElementsAttr(state.w); + PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc); Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); Value biasMatrix; DenseElementsAttr biasDenseAttr; @@ -2115,11 +2517,12 @@ static Value rewriteConv(const ConvLoweringState& state, ConversionPatternRewrit } ConvGemmPlan plan = - buildConvGemmPlan(state, static_cast(wDenseAttr), !state.hasBias || static_cast(biasDenseAttr)); + buildConvGemmPlan(state, static_cast(wDenseAttr), !state.hasBias || static_cast(biasDenseAttr), 0, + state.batchSize * state.outHeight * state.outWidth); // Prepare weight matrix W for crossbar storage: // W: [Cout, Cin, KH, KW] -> [Cout, patchSize] -> [patchSize, Cout] Value weightMatrix = createWeightMatrix(state.w, plan, rewriter, loc); - Value gemmInputRows = createIm2colRows(state, plan, rewriter, loc); + Value gemmInputRows = createIm2colRows(state, preparedInput, plan, rewriter, loc); Value gemmB = buildPackedWeights(wDenseAttr, weightMatrix, state, plan, rewriter, loc); Value gemmC = buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, plan, rewriter, loc); @@ -2140,15 +2543,339 @@ static Value rewriteConv(const ConvLoweringState& state, ConversionPatternRewrit plan.gemmOutType, plan.nhwcType, state.outType, - plan.numPatches, + plan.chunkNumPatches, state.numChannelsOut, plan.effectiveMaxParallelPixels, + distributedConsumers, rewriter, loc); } +static Value rewriteStreamedConv(const ConvLoweringState& state, + ArrayRef distributedConsumers, + ConversionPatternRewriter& rewriter, + Location loc, + int64_t forcedPackFactor) { + auto wDenseAttr = getHostConstDenseElementsAttr(state.w); + PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc); + Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); + Value biasMatrix; + DenseElementsAttr biasDenseAttr; + if (state.hasBias) { + gemmBias = state.b; + biasDenseAttr = getHostConstDenseElementsAttr(state.b); + biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc); + } + + ConvGemmPlan seedPlan = buildConvGemmPlan( + state, static_cast(wDenseAttr), !state.hasBias || static_cast(biasDenseAttr), 0, 1, forcedPackFactor); + Value weightMatrix = createWeightMatrix(state.w, seedPlan, rewriter, loc); + ConvGeometry geo = buildConvGeometry(state); + uint64_t chunkPositions = chooseStreamChunkPositions(geo, forcedPackFactor); + Value collectedRows = createChunkedConvRows(state, + preparedInput, + weightMatrix, + gemmBias, + biasMatrix, + wDenseAttr, + biasDenseAttr, + forcedPackFactor, + chunkPositions, + rewriter, + loc); + auto gemmOutType = cast(collectedRows.getType()); + auto nhwcType = RankedTensorType::get({state.batchSize, state.outHeight, state.outWidth, state.numChannelsOut}, + state.outType.getElementType()); + return createCollectedConvOutput( + ValueRange {collectedRows}, state.outType, gemmOutType, nhwcType, state.outType, gemmOutType.getDimSize(0), + state.numChannelsOut, /*packFactor=*/1, distributedConsumers, rewriter, loc); +} + } // namespace standard +static RankedTensorType getRowStripFragmentType(RankedTensorType tensorType, int64_t width) { + return RankedTensorType::get( + {tensorType.getDimSize(0), tensorType.getDimSize(1), 1, width}, tensorType.getElementType(), tensorType.getEncoding()); +} + +static SmallVector buildRowStripFragments(RankedTensorType tensorType) { + SmallVector fragments; + const int64_t height = tensorType.getDimSize(2); + const int64_t width = tensorType.getDimSize(3); + const int64_t channels = tensorType.getDimSize(1); + fragments.reserve(height); + for (int64_t row = 0; row < height; ++row) { + fragments.push_back(DistributedFragmentInfo { + {0, 0, row, 0}, + {1, channels, 1, width}, + {1, 1, 1, 1}, + row, + }); + } + return fragments; +} + +static DistributedTensorInfo makeDistributedTensorInfo(Value storage, RankedTensorType logicalType) { + DistributedTensorInfo info; + info.storage = storage; + info.logicalType = logicalType; + info.fragments = buildRowStripFragments(logicalType); + info.laneCount = logicalType.getDimSize(2); + info.channels = logicalType.getDimSize(1); + info.height = logicalType.getDimSize(2); + info.width = logicalType.getDimSize(3); + return info; +} + +static Value createPerChannelConstantFragment(DenseElementsAttr denseAttr, + RankedTensorType fragmentType, + ConversionPatternRewriter& rewriter) { + auto denseType = cast(denseAttr.getType()); + SmallVector channelValues; + channelValues.reserve(fragmentType.getDimSize(1)); + SmallVector flattened(denseAttr.getValues()); + if (denseType.getRank() == 1) { + channelValues = flattened; + } + else if (denseType.getRank() == 2) { + channelValues = flattened; + } + else { + for (int64_t channel = 0; channel < denseType.getDimSize(1); ++channel) + channelValues.push_back(flattened[channel]); + } + + SmallVector values; + values.reserve(fragmentType.getNumElements()); + for (int64_t n = 0; n < fragmentType.getDimSize(0); ++n) + for (int64_t channel = 0; channel < fragmentType.getDimSize(1); ++channel) + for (int64_t h = 0; h < fragmentType.getDimSize(2); ++h) + for (int64_t w = 0; w < fragmentType.getDimSize(3); ++w) + values.push_back(channelValues[channel]); + + auto attr = DenseElementsAttr::get(fragmentType, values); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), attr, fragmentType); +} + +static Value createFragmentConstant(const DistributedTensorStep& step, + RankedTensorType fragmentType, + ConversionPatternRewriter& rewriter) { + if (step.constantKind == DistributedTensorConstantKind::PerChannel) + return createPerChannelConstantFragment(step.constantAttr, fragmentType, rewriter); + + Attribute splatValue = step.constantAttr.getSplatValue(); + return getOrCreateConstant(rewriter, + rewriter.getInsertionBlock()->getParentOp(), + DenseElementsAttr::get(fragmentType, splatValue), + fragmentType); +} + +static Value createFragmentReciprocalConstant(const DistributedTensorStep& step, + RankedTensorType fragmentType, + ConversionPatternRewriter& rewriter) { + SmallVector values; + if (step.constantKind == DistributedTensorConstantKind::PerChannel) { + auto denseType = cast(step.constantAttr.getType()); + SmallVector channelValues; + for (const APFloat& value : step.constantAttr.getValues()) + channelValues.push_back(value); + values.reserve(fragmentType.getNumElements()); + for (int64_t n = 0; n < fragmentType.getDimSize(0); ++n) + for (int64_t channel = 0; channel < fragmentType.getDimSize(1); ++channel) + for (int64_t h = 0; h < fragmentType.getDimSize(2); ++h) + for (int64_t w = 0; w < fragmentType.getDimSize(3); ++w) { + APFloat reciprocal = channelValues[channel]; + APFloat one(reciprocal.getSemantics(), 1); + [[maybe_unused]] APFloat::opStatus status = one.divide(reciprocal, APFloat::rmNearestTiesToEven); + assert(!(status & APFloat::opInvalidOp) && "distributed conv div requires finite non-zero constant"); + values.push_back(one); + } + (void)denseType; + } + else { + APFloat reciprocal = cast(step.constantAttr).getSplatValue(); + APFloat one(reciprocal.getSemantics(), 1); + [[maybe_unused]] APFloat::opStatus status = one.divide(reciprocal, APFloat::rmNearestTiesToEven); + assert(!(status & APFloat::opInvalidOp) && "distributed conv div requires finite non-zero constant"); + values.assign(fragmentType.getNumElements(), one); + } + return getOrCreateConstant(rewriter, + rewriter.getInsertionBlock()->getParentOp(), + DenseFPElementsAttr::get(fragmentType, values), + fragmentType); +} + +[[maybe_unused]] static FailureOr createConvRowsForStrategy(const ConvLoweringState& state, + const ConvLoweringDecision& decision, + ConversionPatternRewriter& rewriter, + Location loc) { + auto wDenseAttr = getHostConstDenseElementsAttr(state.w); + PreparedConvInput preparedInput = standard::prepareInputForIm2Col(state, rewriter, loc); + Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); + Value biasMatrix; + DenseElementsAttr biasDenseAttr; + if (state.hasBias) { + gemmBias = state.b; + biasDenseAttr = getHostConstDenseElementsAttr(state.b); + biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc); + } + + switch (decision.strategy) { + case PimConvLoweringLegacy: + case PimConvLoweringPackedIm2Col: { + standard::ConvGemmPlan plan = standard::buildConvGemmPlan( + state, static_cast(wDenseAttr), !state.hasBias || static_cast(biasDenseAttr), 0, + state.batchSize * state.outHeight * state.outWidth); + Value weightMatrix = standard::createWeightMatrix(state.w, plan, rewriter, loc); + Value gemmInputRows = standard::createIm2colRows(state, preparedInput, plan, rewriter, loc); + Value gemmB = standard::buildPackedWeights(wDenseAttr, weightMatrix, state, plan, rewriter, loc); + Value gemmC = standard::buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, plan, rewriter, loc); + Value gemmRows = ONNXGemmOp::create(rewriter, + loc, + plan.gemmOutputRowsType, + gemmInputRows, + gemmB, + gemmC, + rewriter.getF32FloatAttr(1.0f), + rewriter.getF32FloatAttr(1.0f), + rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)) + .getY(); + return standard::maybeUnpackChunkRows(gemmRows, plan, rewriter, loc); + } + case PimConvLoweringStreamedPatch: + case PimConvLoweringOutputChannelTiled: + case PimConvLoweringTiled2D: + case PimConvLoweringStreamedPacked: { + standard::ConvGemmPlan seedPlan = standard::buildConvGemmPlan( + state, static_cast(wDenseAttr), !state.hasBias || static_cast(biasDenseAttr), 0, 1, + decision.strategy == PimConvLoweringStreamedPacked ? buildConvGeometry(state).pack : 1); + Value weightMatrix = standard::createWeightMatrix(state.w, seedPlan, rewriter, loc); + ConvGeometry geo = buildConvGeometry(state); + int64_t packFactor = decision.strategy == PimConvLoweringStreamedPacked ? geo.pack : 1; + uint64_t chunkPositions = chooseStreamChunkPositions(geo, packFactor); + return standard::createChunkedConvRows(state, + preparedInput, + weightMatrix, + gemmBias, + biasMatrix, + wDenseAttr, + biasDenseAttr, + packFactor, + chunkPositions, + rewriter, + loc); + } + default: + return failure(); + } +} + +[[maybe_unused]] static FailureOr createDistributedTensorFromRows(Value rows, + RankedTensorType logicalType, + ConversionPatternRewriter& rewriter, + Location loc) { + const int64_t width = logicalType.getDimSize(3); + const int64_t height = logicalType.getDimSize(2); + auto rowsType = cast(rows.getType()); + auto rowSliceType = + RankedTensorType::get({width, logicalType.getDimSize(1)}, logicalType.getElementType(), rowsType.getEncoding()); + auto channelWidthType = + RankedTensorType::get({logicalType.getDimSize(1), width}, logicalType.getElementType(), rowsType.getEncoding()); + auto fragmentType = getRowStripFragmentType(logicalType, width); + auto batchOp = createSpatComputeBatch( + rewriter, loc, TypeRange {logicalType}, height, {}, ValueRange {rows}, [&](detail::SpatComputeBatchBodyArgs args) { + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value rowStart = affineMulConst(rewriter, loc, args.lane, width, anchorOp); + SmallVector rowOffsets {rowStart, rewriter.getIndexAttr(0)}; + SmallVector rowSizes {rewriter.getIndexAttr(width), rewriter.getIndexAttr(logicalType.getDimSize(1))}; + Value rowSlice = tensor::ExtractSliceOp::create( + rewriter, loc, rowSliceType, args.inputs.front(), rowOffsets, rowSizes, getUnitStrides(rewriter, 2)); + Value channelWidth = ONNXTransposeOp::create( + rewriter, loc, channelWidthType, rowSlice, rewriter.getI64ArrayAttr({1, 0})).getResult(); + Value fragment = tensor::ExpandShapeOp::create(rewriter, + loc, + fragmentType, + channelWidth, + SmallVector {{0, 1}, {2, 3}}); + SmallVector outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, + rewriter.getIndexAttr(0)}; + SmallVector outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(logicalType.getDimSize(1)), + rewriter.getIndexAttr(1), rewriter.getIndexAttr(width)}; + createParallelInsertSliceIntoBatchOutput( + rewriter, loc, fragment, args.outputs.front(), outputOffsets, outputSizes, getUnitStrides(rewriter, 4)); + return success(); + }); + if (failed(batchOp)) + return failure(); + return makeDistributedTensorInfo(batchOp->getResult(0), logicalType); +} + +[[maybe_unused]] static FailureOr applyDistributedPreservingStep(const DistributedTensorInfo& inputInfo, + const DistributedTensorStep& step, + ConversionPatternRewriter& rewriter, + Location loc) { + auto logicalType = inputInfo.logicalType; + const int64_t width = logicalType.getDimSize(3); + auto fragmentType = getRowStripFragmentType(logicalType, width); + auto batchOp = createSpatComputeBatch(rewriter, + loc, + TypeRange {logicalType}, + inputInfo.laneCount, + {}, + ValueRange {inputInfo.storage}, + [&](detail::SpatComputeBatchBodyArgs args) { + SmallVector offsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), + args.lane, rewriter.getIndexAttr(0)}; + SmallVector sizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(logicalType.getDimSize(1)), + rewriter.getIndexAttr(1), rewriter.getIndexAttr(width)}; + Value fragment = tensor::ExtractSliceOp::create( + rewriter, loc, fragmentType, args.inputs.front(), offsets, sizes, getUnitStrides(rewriter, 4)); + switch (step.kind) { + case DistributedTensorOpKind::Relu: + fragment = spatial::SpatReluOp::create(rewriter, loc, fragmentType, fragment).getResult(); + break; + case DistributedTensorOpKind::Sigmoid: + fragment = spatial::SpatSigmoidOp::create(rewriter, loc, fragmentType, fragment).getResult(); + break; + case DistributedTensorOpKind::Add: { + Value constant = createFragmentConstant(step, fragmentType, rewriter); + fragment = + spatial::SpatVAddOp::create(rewriter, loc, fragmentType, fragment, constant).getResult(); + break; + } + case DistributedTensorOpKind::Sub: { + Value constant = createFragmentConstant(step, fragmentType, rewriter); + Value lhs = step.fragmentOnLhs ? fragment : constant; + Value rhs = step.fragmentOnLhs ? constant : fragment; + fragment = spatial::SpatVSubOp::create(rewriter, loc, fragmentType, lhs, rhs).getResult(); + break; + } + case DistributedTensorOpKind::Mul: { + Value constant = createFragmentConstant(step, fragmentType, rewriter); + fragment = + spatial::SpatVMulOp::create(rewriter, loc, fragmentType, fragment, constant).getResult(); + break; + } + case DistributedTensorOpKind::Div: { + Value constant = createFragmentReciprocalConstant(step, fragmentType, rewriter); + fragment = + spatial::SpatVMulOp::create(rewriter, loc, fragmentType, fragment, constant).getResult(); + break; + } + case DistributedTensorOpKind::Conv: + return failure(); + } + createParallelInsertSliceIntoBatchOutput( + rewriter, loc, fragment, args.outputs.front(), offsets, sizes, getUnitStrides(rewriter, 4)); + return success(); + }); + if (failed(batchOp)) + return failure(); + return makeDistributedTensorInfo(batchOp->getResult(0), logicalType); +} + static Value createCollectedConvOutput(ValueRange gemmRows, Type convType, RankedTensorType gemmOutType, @@ -2157,15 +2884,77 @@ static Value createCollectedConvOutput(ValueRange gemmRows, int64_t numPatches, int64_t numChannelsOut, int64_t packFactor, + ArrayRef distributedConsumers, ConversionPatternRewriter& rewriter, Location loc) { + auto materializeSplatTensor = [&](DenseElementsAttr denseAttr, RankedTensorType targetType) { + Attribute splatValue = denseAttr.getSplatValue(); + auto targetAttr = DenseElementsAttr::get(targetType, splatValue); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), targetAttr, targetType); + }; + + auto materializeReciprocalSplatTensor = [&](DenseFPElementsAttr denseAttr, RankedTensorType targetType) { + APFloat reciprocal = denseAttr.getSplatValue(); + APFloat one(reciprocal.getSemantics(), 1); + [[maybe_unused]] APFloat::opStatus status = one.divide(reciprocal, APFloat::rmNearestTiesToEven); + assert(!(status & APFloat::opInvalidOp) && "distributed conv div consumer requires finite non-zero scalar"); + return getOrCreateConstant( + rewriter, rewriter.getInsertionBlock()->getParentOp(), DenseFPElementsAttr::get(targetType, one), targetType); + }; + + auto applyDistributedConsumers = [&](Value fragment) { + Value current = fragment; + for (const DistributedTensorStep& step : distributedConsumers) { + auto fragmentType = cast(current.getType()); + switch (step.kind) { + case DistributedTensorOpKind::Relu: + current = spatial::SpatReluOp::create(rewriter, loc, fragmentType, current).getResult(); + break; + case DistributedTensorOpKind::Sigmoid: + current = spatial::SpatSigmoidOp::create(rewriter, loc, fragmentType, current).getResult(); + break; + case DistributedTensorOpKind::Add: { + Value splat = materializeSplatTensor(step.constantAttr, fragmentType); + current = spatial::SpatVAddOp::create(rewriter, loc, fragmentType, current, splat).getResult(); + break; + } + case DistributedTensorOpKind::Sub: { + Value splat = materializeSplatTensor(step.constantAttr, fragmentType); + Value lhs = step.fragmentOnLhs ? current : splat; + Value rhs = step.fragmentOnLhs ? splat : current; + current = spatial::SpatVSubOp::create(rewriter, loc, fragmentType, lhs, rhs).getResult(); + break; + } + case DistributedTensorOpKind::Mul: { + Value splat = materializeSplatTensor(step.constantAttr, fragmentType); + current = spatial::SpatVMulOp::create(rewriter, loc, fragmentType, current, splat).getResult(); + break; + } + case DistributedTensorOpKind::Div: { + auto reciprocalAttr = cast(step.constantAttr); + Value reciprocal = materializeReciprocalSplatTensor(reciprocalAttr, fragmentType); + current = spatial::SpatVMulOp::create(rewriter, loc, fragmentType, current, reciprocal).getResult(); + break; + } + case DistributedTensorOpKind::Conv: + llvm_unreachable("conv-consuming distributed chains should not materialize through createCollectedConvOutput"); + } + } + return current; + }; + auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) { + SmallVector transformedRows; + transformedRows.reserve(gemmRowArgs.size()); + for (Value row : gemmRowArgs) + transformedRows.push_back(applyDistributedConsumers(row)); + Value gemmOut; if (packFactor == 1) { - gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs); + gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, transformedRows); } else { - Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs); + Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, transformedRows); gemmOut = standard::unpackRowsFromParallelGemm( packedOutput, cast(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc); } @@ -2188,11 +2977,11 @@ static Value createCollectedConvOutput(ValueRange gemmRows, return collectComputeOp.getResult(0); } -static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, ONNXConvOpAdaptor convOpAdaptor) { +static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b) { ConvLoweringState state; - state.x = convOpAdaptor.getX(); - state.w = convOpAdaptor.getW(); - state.b = convOpAdaptor.getB(); + state.x = x; + state.w = w; + state.b = b; state.xType = cast(state.x.getType()); state.wType = cast(state.w.getType()); state.outType = cast(convOp.getY().getType()); @@ -2330,22 +3119,140 @@ static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, return state; } -static LogicalResult -rewriteUngroupedConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPatternRewriter& rewriter) { - rewriter.replaceOp(convOp, standard::rewriteConv(state, rewriter, convOp.getLoc())); - return success(); +static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, ONNXConvOpAdaptor convOpAdaptor) { + return analyzeConvLoweringState(convOp, convOpAdaptor.getX(), convOpAdaptor.getW(), convOpAdaptor.getB()); } static LogicalResult -rewriteDepthwiseConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPatternRewriter& rewriter) { - FailureOr result = depthwise::rewriteConv(convOp, state, rewriter, convOp.getLoc()); - if (failed(result)) +createConvValueForStrategy(ONNXConvOp convOp, + const ConvLoweringState& state, + const ConvLoweringDecision& decision, + const DistributedConvAnalysis& analysis, + ArrayRef distributedConsumers, + ConversionPatternRewriter& rewriter, + FailureOr& result) { + const ConvGeometry geo = buildConvGeometry(state); + const ConvStrategyEstimate estimate = estimateConvStrategy(geo, decision.strategy, analysis); + switch (decision.strategy) { + case PimConvLoweringDepthwise: { + result = depthwise::rewriteConv(convOp, state, rewriter, convOp.getLoc()); + if (failed(result)) + return failure(); + reportConvLoweringDecision( + convOp, geo, decision, estimate, /*batchSize=*/geo.p, /*numberOfBatches=*/1, /*usesComputeBatch=*/true, + /*usesBatchedInstructionEmission=*/true, std::nullopt); + return success(); + } + case PimConvLoweringLegacy: + case PimConvLoweringPackedIm2Col: { + reportConvLoweringDecision( + convOp, geo, decision, estimate, /*batchSize=*/geo.pack, /*numberOfBatches=*/1, /*usesComputeBatch=*/true, + /*usesBatchedInstructionEmission=*/true, std::nullopt); + result = standard::rewritePackedIm2ColConv(state, distributedConsumers, rewriter, convOp.getLoc()); + return success(); + } + case PimConvLoweringStreamedPatch: + case PimConvLoweringOutputChannelTiled: + case PimConvLoweringTiled2D: { + uint64_t chunkPositions = chooseStreamChunkPositions(geo, /*packFactor=*/1); + const int64_t batches = ceilIntegerDivide(geo.p, static_cast(chunkPositions)); + reportConvLoweringDecision(convOp, + geo, + decision, + estimate, + /*batchSize=*/1, + batches, + /*usesComputeBatch=*/true, + /*usesBatchedInstructionEmission=*/true, + chunkPositions); + result = standard::rewriteStreamedConv(state, distributedConsumers, rewriter, convOp.getLoc(), /*forcedPackFactor=*/1); + return success(); + } + case PimConvLoweringInputKTiled: { + const int64_t numKSlices = ceilIntegerDivide(geo.k, geo.xbarSize); + const uint64_t maxLanesPerBatch = + std::max(1, + static_cast(crossbarCountInCore.getValue()) + / static_cast(std::max(1, numKSlices * 4))); + const uint64_t rowChunkWidth = std::max( + 1, + std::min({chooseStreamChunkPositions(geo, /*packFactor=*/1), + maxLanesPerBatch, + static_cast(state.outWidth)})); + const int64_t batches = + state.batchSize * state.outHeight * ceilIntegerDivide(state.outWidth, static_cast(rowChunkWidth)); + reportConvLoweringDecision(convOp, + geo, + decision, + estimate, + /*batchSize=*/1, + batches, + /*usesComputeBatch=*/false, + /*usesBatchedInstructionEmission=*/false, + rowChunkWidth); + result = standard::rewriteInputKTiledConv(state, distributedConsumers, rewriter, convOp.getLoc()); + return success(); + } + case PimConvLoweringStreamedPacked: { + uint64_t chunkPositions = chooseStreamChunkPositions(geo, geo.pack); + const int64_t batches = ceilIntegerDivide(geo.p, static_cast(chunkPositions)); + reportConvLoweringDecision(convOp, + geo, + decision, + estimate, + /*batchSize=*/geo.pack, + batches, + /*usesComputeBatch=*/true, + /*usesBatchedInstructionEmission=*/true, + chunkPositions); + result = standard::rewriteStreamedConv(state, distributedConsumers, rewriter, convOp.getLoc(), geo.pack); + return success(); + } + case PimConvLoweringAuto: + break; + } + return convOp.emitOpError("unexpected auto strategy at Conv lowering dispatch"); +} + +static LogicalResult +rewriteSelectedConv(ONNXConvOp convOp, + const ConvLoweringState& state, + const ConvLoweringDecision& decision, + const DistributedConvAnalysis& analysis, + ConversionPatternRewriter& rewriter) { + FailureOr result = failure(); + if (failed(createConvValueForStrategy(convOp, state, decision, analysis, analysis.steps, rewriter, result))) return failure(); - rewriter.replaceOp(convOp, *result); + if (!analysis.hasLocalConsumers()) { + rewriter.replaceOp(convOp, *result); + return success(); + } + + assert(analysis.replacementOp && "conv rewrite expects a replacement op"); + rewriter.replaceOp(analysis.replacementOp, *result); + for (auto it = analysis.steps.rbegin(); it != analysis.steps.rend(); ++it) + if (it->op != analysis.replacementOp) + rewriter.eraseOp(it->op); + rewriter.eraseOp(convOp); return success(); } +static LogicalResult +rewriteUngroupedConv(ONNXConvOp convOp, + const ConvLoweringState& state, + const ConvLoweringDecision& decision, + const DistributedConvAnalysis& analysis, + ConversionPatternRewriter& rewriter) { + return rewriteSelectedConv(convOp, state, decision, analysis, rewriter); +} + +static LogicalResult +rewriteGroupedConv(ONNXConvOp convOp, + const ConvLoweringState& state, + const ConvLoweringDecision& decision, + ConversionPatternRewriter& rewriter); + static ConvLoweringState makeGroupedConvLoweringState( const ConvLoweringState& parent, Value groupX, Value groupW, Value groupB, RankedTensorType groupOutType) { ConvLoweringState state = parent; @@ -2372,7 +3279,10 @@ static ConvLoweringState makeGroupedConvLoweringState( } static LogicalResult -rewriteGroupedConv(ONNXConvOp convOp, const ConvLoweringState& state, ConversionPatternRewriter& rewriter) { +rewriteGroupedConv(ONNXConvOp convOp, + const ConvLoweringState& state, + const ConvLoweringDecision& decision, + ConversionPatternRewriter& rewriter) { SmallVector xSlices = sliceTensor(state.x, /*axis=*/1, state.numChannelsInPerGroup, rewriter, convOp.getLoc()); SmallVector wSlices = sliceTensor(state.w, /*axis=*/0, state.numChannelsOutPerGroup, rewriter, convOp.getLoc()); @@ -2408,7 +3318,13 @@ rewriteGroupedConv(ONNXConvOp convOp, const ConvLoweringState& state, Conversion Value groupW = wSlices[groupId]; Value groupB = state.hasBias ? bSlices[groupId] : noBias; ConvLoweringState groupState = makeGroupedConvLoweringState(state, groupX, groupW, groupB, groupOutType); - groupResults.push_back(standard::rewriteConv(groupState, rewriter, convOp.getLoc())); + FailureOr groupResult = failure(); + DistributedConvAnalysis groupAnalysis; + groupAnalysis.barrierKind = DistributedConvBarrierKind::GroupedConv; + groupAnalysis.barrierDetail = "grouped convolution still materializes densely"; + if (failed(createConvValueForStrategy(convOp, groupState, decision, groupAnalysis, {}, rewriter, groupResult))) + return failure(); + groupResults.push_back(*groupResult); } Value result; @@ -2437,16 +3353,48 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, if (failed(state)) return failure(); - if (isDepthwiseConv(state->group, state->numChannelsIn, state->numChannelsOut, state->numChannelsInPerGroup)) { - if (depthwise::canUseStructuredRewrite(*state)) - return rewriteDepthwiseConv(convOp, *state, rewriter); - return rewriteGroupedConv(convOp, *state, rewriter); + FailureOr requestedStrategy = resolveRequestedConvLoweringStrategy(convOp); + if (failed(requestedStrategy)) + return failure(); + + DistributedConvAnalysis distributedAnalysis; + if (state->group == 1) + distributedAnalysis = analyzeDistributedConvConsumers(convOp); + else { + distributedAnalysis.barrierKind = DistributedConvBarrierKind::GroupedConv; + distributedAnalysis.barrierDetail = "grouped convolution still materializes densely"; } - if (state->group == 1) - return rewriteUngroupedConv(convOp, *state, rewriter); + ConvGeometry geometry = buildConvGeometry(*state); + ConvLoweringDecision decision = chooseConvLoweringStrategy(geometry, *requestedStrategy, distributedAnalysis); + if (decision.strategy == PimConvLoweringDepthwise && !depthwise::canUseStructuredRewrite(*state) + && *requestedStrategy == PimConvLoweringAuto) { + decision = {PimConvLoweringLegacy, + "depthwise auto fallback when structured depthwise lowering is not representable", + /*isAuto=*/true, + "", + ""}; + } + if (failed(verifyForcedConvLoweringStrategy(convOp, geometry, decision.strategy))) + return failure(); - return rewriteGroupedConv(convOp, *state, rewriter); + if (decision.strategy == PimConvLoweringDepthwise && !depthwise::canUseStructuredRewrite(*state)) + return convOp.emitOpError("selected depthwise Conv lowering requires constant-derived weights and sliceable bias"); + + if (decision.strategy == PimConvLoweringDepthwise) { + distributedAnalysis.barrierKind = DistributedConvBarrierKind::Depthwise; + distributedAnalysis.barrierDetail = "depthwise lowering still materializes densely"; + recordDistributedConvOutcome(distributedAnalysis); + return rewriteSelectedConv(convOp, *state, decision, distributedAnalysis, rewriter); + } + + if (state->group == 1) { + recordDistributedConvOutcome(distributedAnalysis); + return rewriteUngroupedConv(convOp, *state, decision, distributedAnalysis, rewriter); + } + + recordDistributedConvOutcome(distributedAnalysis); + return rewriteGroupedConv(convOp, *state, decision, rewriter); } void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); }