#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <fstream>
#include <map>
#include <mutex>
#include <optional>
#include <string>
#include <vector>

#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Conversion pattern for `onnx.Conv`; matchAndRewrite is defined out of line.
struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXConvOp convOp,
                                ONNXConvOpAdaptor convOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;
};

/// Result of Conv lowering-strategy selection (see chooseConvLoweringStrategy).
struct ConvLoweringDecision {
  PimConvLoweringType strategy;
  std::string reason;               // Human-readable selection reason.
  bool isAuto = false;              // True when picked by the auto policy.
  std::string fallbackReason;       // Why auto fell back from a rejected candidate.
  std::string rejectedAutoStrategy; // Name of the candidate the auto policy rejected.
};

/// A Conv input value paired with its ranked tensor type.
struct PreparedConvInput {
  Value value;
  RankedTensorType type;
};

/// Cost/materialization estimate for one lowering strategy (report only).
struct ConvStrategyEstimate {
  uint64_t estimatedMvmCount = 0;
  uint64_t estimatedReductionVAddCount = 0;
  uint64_t estimatedOutputFragments = 1;
  bool perOutputPositionReduction = false;
  bool requiresFuncReturnMaterialization = false;
  bool giantCollectorConcatExpected = false;
  bool fullInputBroadcastExpected = false;
  uint64_t concatOperandCount = 0;
  std::string materializationKind = "none";
  std::string collectorCore = "-";
};

/// One row of the Conv lowering report (table cells plus per-conv details).
struct ConvReportEntry {
  uint64_t id;
  std::string where;
  std::string strategy;
  std::string mode;
  std::string inputShape;
  std::string weightShape;
  std::string outputShape;
  int64_t groups;
  int64_t k;
  int64_t c;
  int64_t p;
  int64_t xbarSize;
  int64_t pack;
  uint64_t im2colElements;
  uint64_t im2colBudget;
  std::string chunkText;
  int64_t batchSize;
  int64_t numberOfBatches;
  std::string spatialComputeBatch;
  std::string batchedInstructionEmission;
  std::string reason;
  std::string fallbackReason;
  std::string rejectedAutoStrategy;
  uint64_t estimatedMvmCount;
  uint64_t estimatedReductionVAddCount;
  uint64_t estimatedOutputFragments;
  std::string materializationRequiredAtReturn;
  std::string materializationKind;
  uint64_t concatOperandCount;
  std::string collectorCore;
  std::string giantCollectorConcatExpected;
  std::string fullInputBroadcastExpected;
};

/// Ops that may consume a distributed (fragmented) Conv result.
enum class DistributedTensorOpKind {
  Relu,
  Sigmoid,
  Add,
  Sub,
  Mul,
  Div,
  Conv,
};

/// Why a distributed chain stops and must be materialized.
enum class DistributedConvBarrierKind {
  Return,
  UnsupportedConsumer,
  Fanout,
  DeadValue,
  GroupedConv,
  Depthwise,
};

/// Shape class of the constant operand of a distributed binary op.
enum class DistributedTensorConstantKind {
  None,
  Splat,
  PerChannel,
};

enum class DistributedTensorLayoutKind {
  NchwRowStrip,
};

/// Placement of one fragment of a distributed tensor.
struct DistributedFragmentInfo {
  // NOTE(review): slice-parameter element type reconstructed as int64_t;
  // confirm against the users of these fields outside this chunk.
  SmallVector<int64_t> offsets;
  SmallVector<int64_t> sizes;
  SmallVector<int64_t> strides;
  int64_t producerLane = 0;
};

/// Describes a logical tensor stored as a set of fragments.
struct DistributedTensorInfo {
  Value storage;
  RankedTensorType logicalType;
  DistributedTensorLayoutKind layoutKind = DistributedTensorLayoutKind::NchwRowStrip;
  SmallVector<DistributedFragmentInfo> fragments;
  int64_t laneCount = 0;
  int64_t fragmentHeight = 1;
  int64_t channels = 0;
  int64_t height = 0;
  int64_t width = 0;

  bool isRowStripNchw() const { return layoutKind == DistributedTensorLayoutKind::NchwRowStrip; }
};

/// Maps SSA values to their distributed-tensor description.
struct DistributedTensorRegistry {
  llvm::DenseMap<Value, DistributedTensorInfo> infos;

  /// Records (or overwrites) the description of `value`.
  void bind(Value value, const DistributedTensorInfo& info) { infos[value] = info; }

  /// Returns the description of `value`, or nullptr if it was never bound.
  /// The pointer refers into the map; a later bind() of a new key may
  /// invalidate it.
  const DistributedTensorInfo* lookup(Value value) const {
    auto it = infos.find(value);
    return it == infos.end() ? nullptr : &it->second;
  }
};

/// One op in a distributed chain rooted at a Conv.
struct DistributedTensorStep {
  Operation* op = nullptr;
  DistributedTensorOpKind kind;
  DenseElementsAttr constantAttr;
  DistributedTensorConstantKind constantKind = DistributedTensorConstantKind::None;
  bool fragmentOnLhs = true;
  std::optional<ConvLoweringState> convState;
};

/// Result of walking the single-use consumer chain of a Conv.
struct DistributedConvAnalysis {
  SmallVector<DistributedTensorStep> steps;
  Operation* replacementOp = nullptr;
  DistributedConvBarrierKind barrierKind = DistributedConvBarrierKind::UnsupportedConsumer;
  std::string barrierDetail;

  bool hasLocalConsumers() const { return !steps.empty(); }

  bool hasDistributedConvConsumer() const {
    return llvm::any_of(steps, [](const DistributedTensorStep& step) {
      return step.kind == DistributedTensorOpKind::Conv;
    });
  }
};

/// Per-chain entry of the distributed-consumption report.
struct DistributedChainReportEntry {
  uint64_t chainId = 0;
  uint64_t chainLength = 0;
  std::string producerKind;
  std::string distributedOps;
  std::string materializationPoints;
  std::string firstMaterializationReason;
  uint64_t maxLiveFragments = 0;
  uint64_t maxFragmentFanout = 0;
  uint64_t patchBuilderCoreCount = 0;
  std::string centralJunctionDetected = "no";
  std::string convInputMaterializationKind = "dense_materialization_fallback";
  uint64_t localPatchFragments = 0;
  uint64_t remotePatchFragments = 0;
  uint64_t haloTransferCount = 0;
  uint64_t groupedTransferCount = 0;
  std::string fallbackReason;
};

/// Aggregate counters for the distributed-consumption report.
struct DistributedConvReportTotals {
  uint64_t totalConvs = 0;
  uint64_t distributedTensorsCreated = 0;
  uint64_t distributedValuesPropagated = 0;
  uint64_t distributedConsumersHandled = 0;
  uint64_t distributedConvInputsSeen = 0;
  uint64_t distributedConvInputsConsumed = 0;
  uint64_t materializationBarriersInserted = 0;
  std::map<std::string, uint64_t> fallbackReasons;
  std::map<std::string, uint64_t> barrierReasons;
  SmallVector<DistributedChainReportEntry> chains;
};

// Defined later in this file.
// NOTE(review): result types reconstructed; confirm against the definitions.
static Value createZeroGemmBias(RankedTensorType gemmResultType, PatternRewriter& rewriter);
static FailureOr<Value> createRowStripPackedRows(Value rows,
                                                 const ConvLoweringState& state,
                                                 PatternRewriter& rewriter,
                                                 Location loc);
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b);

/// Report name for a barrier kind.
static StringRef stringifyDistributedConvBarrierKind(DistributedConvBarrierKind kind) {
  switch (kind) {
  case DistributedConvBarrierKind::Return: return "func.return";
  case DistributedConvBarrierKind::UnsupportedConsumer: return "unsupported_consumer";
  case DistributedConvBarrierKind::Fanout: return "fanout";
  case DistributedConvBarrierKind::DeadValue: return "dead_value";
  case DistributedConvBarrierKind::GroupedConv: return "grouped_conv";
  case DistributedConvBarrierKind::Depthwise: return "depthwise_conv";
  }
  llvm_unreachable("unknown distributed conv barrier kind");
}

/// Command-line/report name for a Conv lowering strategy.
static StringRef stringifyConvLoweringStrategy(PimConvLoweringType strategy) {
  switch (strategy) {
  case PimConvLoweringAuto: return "auto";
  case PimConvLoweringLegacy: return "legacy";
  case PimConvLoweringDepthwise: return "depthwise";
  case PimConvLoweringPackedIm2Col: return "packed-im2col";
  case PimConvLoweringStreamedPatch: return "streamed-patch";
  case PimConvLoweringStreamedPacked: return "streamed-packed";
  case PimConvLoweringOutputChannelTiled: return "output-channel-tiled";
  case PimConvLoweringInputKTiled: return "input-k-tiled";
  case PimConvLoweringTiled2D: return "tiled-2d";
  }
  llvm_unreachable("unknown conv lowering strategy");
}

/// True when the Conv result flows straight into func.return with no
/// distributed consumers in between.
static bool requiresFuncReturnMaterialization(const DistributedConvAnalysis& analysis) {
  return !analysis.hasLocalConsumers() && analysis.barrierKind == DistributedConvBarrierKind::Return;
}

/// Estimates MVM/reduction counts and output fragmentation of `strategy` for
/// the report and for the auto policy's input-k-tiled rejection message.
static ConvStrategyEstimate estimateConvStrategy(const ConvGeometry& geo,
                                                 PimConvLoweringType strategy,
                                                 const DistributedConvAnalysis& analysis) {
  // At or above this many fragments, func.return needs one giant collector concat.
  static constexpr uint64_t kGiantConcatFragmentThreshold = 128;

  ConvStrategyEstimate estimate;
  estimate.requiresFuncReturnMaterialization = requiresFuncReturnMaterialization(analysis);
  // Only the input-k-tiled case below may refine this to a collector concat.
  estimate.materializationKind = estimate.requiresFuncReturnMaterialization ? "func.return" : "none";

  const uint64_t outputPositions = static_cast<uint64_t>(std::max<int64_t>(1, geo.p));

  switch (strategy) {
  case PimConvLoweringLegacy:
  case PimConvLoweringPackedIm2Col:
  case PimConvLoweringTiled2D:
    estimate.estimatedMvmCount = outputPositions;
    break;
  case PimConvLoweringStreamedPatch:
  case PimConvLoweringStreamedPacked:
  case PimConvLoweringOutputChannelTiled: {
    // Clamp so a degenerate chunk size can never divide by zero.
    const uint64_t chunkPositions =
        std::max<uint64_t>(1, static_cast<uint64_t>(chooseStreamChunkPositions(geo, /*packFactor=*/1)));
    estimate.estimatedMvmCount = outputPositions;
    estimate.estimatedOutputFragments = std::max<uint64_t>(
        1, static_cast<uint64_t>(ceilIntegerDivide(geo.p, static_cast<int64_t>(chunkPositions))));
    break;
  }
  case PimConvLoweringInputKTiled: {
    const int64_t numKSlices = ceilIntegerDivide(geo.k, geo.xbarSize);
    const uint64_t maxLanesPerBatch =
        std::max<uint64_t>(1,
                           static_cast<uint64_t>(crossbarCountInCore.getValue())
                               / static_cast<uint64_t>(std::max<int64_t>(1, numKSlices * 4)));
    const uint64_t rowChunkWidth = std::max<uint64_t>(
        1,
        std::min({static_cast<uint64_t>(chooseStreamChunkPositions(geo, /*packFactor=*/1)),
                  maxLanesPerBatch,
                  static_cast<uint64_t>(std::max<int64_t>(1, geo.outWidth))}));

    estimate.estimatedMvmCount = outputPositions * static_cast<uint64_t>(std::max<int64_t>(1, numKSlices));
    // One vadd per extra K slice, plus one for the bias, at every output position.
    estimate.estimatedReductionVAddCount =
        outputPositions * static_cast<uint64_t>(std::max<int64_t>(0, numKSlices - 1) + (geo.hasBias ? 1 : 0));
    estimate.estimatedOutputFragments =
        static_cast<uint64_t>(std::max<int64_t>(1, geo.batchSize))
        * static_cast<uint64_t>(std::max<int64_t>(1, geo.outHeight))
        * static_cast<uint64_t>(ceilIntegerDivide(geo.outWidth, static_cast<int64_t>(rowChunkWidth)));
    estimate.perOutputPositionReduction = numKSlices > 1;
    estimate.fullInputBroadcastExpected = estimate.estimatedOutputFragments > 1;

    if (estimate.requiresFuncReturnMaterialization
        && estimate.estimatedOutputFragments >= kGiantConcatFragmentThreshold) {
      estimate.giantCollectorConcatExpected = true;
      estimate.materializationKind = "single_collector_concat";
      estimate.concatOperandCount = estimate.estimatedOutputFragments;
      estimate.collectorCore = "scheduled_post_merge";
    }
    break;
  }
  case PimConvLoweringDepthwise:
  case PimConvLoweringAuto:
    break;
  }
  return estimate;
}

/// Formats dims as "[d0xd1x...]".
static std::string formatShape(ArrayRef<int64_t> dims) {
  std::string text;
  llvm::raw_string_ostream os(text);
  os << "[";
  for (size_t i = 0; i < dims.size(); ++i) {
    if (i != 0)
      os << "x";
    os << dims[i];
  }
  os << "]";
  os.flush();
  return text;
}

/// Replaces each run of whitespace with a single space and drops leading and
/// trailing whitespace.
static std::string collapseWhitespace(StringRef text) {
  std::string out;
  out.reserve(text.size());
  bool pendingSpace = false;
  for (char c : text) {
    // std::isspace requires a value representable as unsigned char.
    if (std::isspace(static_cast<unsigned char>(c))) {
      // Defer the separator so trailing whitespace is never emitted.
      pendingSpace = !out.empty();
      continue;
    }
    if (pendingSpace)
      out.push_back(' ');
    pendingSpace = false;
    out.push_back(c);
  }
  return out;
}

/// Truncates to at most `maxLen` characters, keeping the head and marking the
/// cut with "...".
static std::string abbreviate(StringRef text, size_t maxLen) {
  if (text.size() <= maxLen)
    return text.str();
  if (maxLen < 3) // No room for the ellipsis; also avoids size_t underflow.
    return text.take_front(maxLen).str();
  return (text.take_front(maxLen - 3) + "...").str();
}

/// Truncates to at most `maxLen` characters, keeping the tail and marking the
/// cut with "...".
static std::string abbreviateFromEnd(StringRef text, size_t maxLen) {
  if (text.size() <= maxLen)
    return text.str();
  if (maxLen < 3) // No room for the ellipsis; also avoids size_t underflow.
    return text.take_back(maxLen).str();
  return ("..." + text.take_back(maxLen - 3)).str();
}

/// Prints `loc` on one line and shortens it to `maxLen`. Text containing '/'
/// or '#' keeps its tail (typically the file:line part); other text keeps its
/// head.
static std::string summarizeLocation(Location loc, size_t maxLen = 44) {
  std::string text;
  llvm::raw_string_ostream os(text);
  loc.print(os);
  os.flush();
  std::string collapsed = collapseWhitespace(text);
  if (collapsed.size() <= maxLen)
    return collapsed;
  if (collapsed.find('/') != std::string::npos || collapsed.find('#') != std::string::npos)
    return abbreviateFromEnd(collapsed, maxLen);
  return abbreviate(collapsed, maxLen);
}

/// Pads `text` with spaces to `width` (left- or right-aligned); longer text is
/// returned unchanged.
static std::string alignCell(StringRef text, size_t width, bool rightAlign = false) {
  std::string cell = text.str();
  if (cell.size() < width) {
    size_t padding = width - cell.size();
    if (rightAlign)
      cell.insert(cell.begin(), padding, ' ');
    else
      cell.append(padding, ' ');
  }
  return cell;
}

/// True when `value` is a statically shaped ranked tensor of exactly `expectedType`.
static bool hasSameStaticTensorType(Value value, Type expectedType) {
  auto valueType = dyn_cast<RankedTensorType>(value.getType());
  auto expectedTensorType = dyn_cast<RankedTensorType>(expectedType);
  return valueType && expectedTensorType && valueType.hasStaticShape() && valueType == expectedTensorType;
}

/// True when `value` is a host constant with a splat payload. `denseAttr`
/// receives the constant (null if `value` is not a host constant).
static bool isSplatConstantValue(Value value, DenseElementsAttr& denseAttr) {
  denseAttr = getHostConstDenseElementsAttr(value);
  return static_cast<bool>(denseAttr) && denseAttr.isSplat();
}

/// True when `value` is a non-splat host constant with one element per channel,
/// shaped [C], [1, C] or [1, C, 1, 1]. Dimension 1 of `currentType` is taken as
/// the channel count. `denseAttr` receives the constant.
static bool isPerChannelConstantValue(Value value, RankedTensorType currentType, DenseElementsAttr& denseAttr) {
  denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr || denseAttr.isSplat())
    return false;
  // Guard the channel-dimension read below.
  if (currentType.getRank() < 2)
    return false;

  auto constantType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!constantType || !constantType.hasStaticShape())
    return false;

  const int64_t channels = currentType.getDimSize(1);
  switch (constantType.getRank()) {
  case 1: return constantType.getDimSize(0) == channels;
  case 2: return constantType.getDimSize(0) == 1 && constantType.getDimSize(1) == channels;
  case 4:
    return constantType.getDimSize(0) == 1 && constantType.getDimSize(1) == channels
        && constantType.getDimSize(2) == 1 && constantType.getDimSize(3) == 1;
  default: return false;
  }
}

/// Classifies a binary elementwise `user` of `currentValue`. It is accepted
/// when the other operand is a splat or per-channel constant. A fragment on the
/// rhs is only accepted when `allowFragmentOnRhs` is set. On rejection,
/// `failureDetail` explains why and std::nullopt is returned.
static std::optional<DistributedTensorStep> classifyDistributedBinaryConsumer(Operation* user,
                                                                              Value currentValue,
                                                                              Value lhs,
                                                                              Value rhs,
                                                                              DistributedTensorOpKind kind,
                                                                              bool allowFragmentOnRhs,
                                                                              std::string& failureDetail) {
  if (user->getNumResults() != 1 || !hasSameStaticTensorType(user->getResult(0), currentValue.getType())) {
    failureDetail = "result type mismatch";
    return std::nullopt;
  }
  auto currentType = cast<RankedTensorType>(currentValue.getType());

  // Builds the step when `constant` is a supported broadcast operand.
  auto makeStep = [&](Value constant,
                      bool fragmentOnLhs,
                      StringRef unsupportedDetail) -> std::optional<DistributedTensorStep> {
    DenseElementsAttr constantAttr;
    DistributedTensorConstantKind constantKind = DistributedTensorConstantKind::None;
    if (isSplatConstantValue(constant, constantAttr))
      constantKind = DistributedTensorConstantKind::Splat;
    else if (isPerChannelConstantValue(constant, currentType, constantAttr))
      constantKind = DistributedTensorConstantKind::PerChannel;
    else {
      failureDetail = unsupportedDetail.str();
      return std::nullopt;
    }
    return DistributedTensorStep {user, kind, constantAttr, constantKind, fragmentOnLhs, std::nullopt};
  };

  // The lhs position wins when the value feeds both operands.
  if (lhs == currentValue)
    return makeStep(rhs, /*fragmentOnLhs=*/true, "unsupported rhs broadcast");
  if (rhs == currentValue && allowFragmentOnRhs)
    return makeStep(lhs, /*fragmentOnLhs=*/false, "unsupported lhs broadcast");

  failureDetail = "conv result is not the supported binary operand";
  return std::nullopt;
}

/// True when `acquired` fully contains `needed`.
static bool covers(RowInterval acquired, RowInterval needed) {
  return acquired.begin <= needed.begin && acquired.end >= needed.end;
}

/// Checks whether a Conv can read a row-strip distributed input directly. Only
/// batch 1, group 1, non-depthwise, stride 1, dilation 1, symmetric pad 1, 3x3
/// kernel, same-size output and constant weights/bias are accepted.
/// `failureReason` names the first failed check, or is empty on success.
static bool canConsumeRowStripHwcInput(const ConvLoweringState& state, StringRef& failureReason) {
  if (state.batchSize != 1) {
    failureReason = "unsupported_batch";
    return false;
  }
  if (state.group != 1) {
    failureReason = "unsupported_groups";
    return false;
  }
  if (isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup)) {
    failureReason = "unsupported_depthwise";
    return false;
  }
  if (state.strideHeight != 1 || state.strideWidth != 1) {
    failureReason = "unsupported_stride";
    return false;
  }
  if (state.dilationHeight != 1 || state.dilationWidth != 1) {
    failureReason = "unsupported_dilation";
    return false;
  }
  if (state.padHeightBegin != state.padHeightEnd || state.padWidthBegin != state.padWidthEnd) {
    failureReason = "unsupported_padding";
    return false;
  }
  if (state.padHeightBegin != 1 || state.padWidthBegin != 1) {
    failureReason = "unsupported_padding";
    return false;
  }
  if (state.wHeight != 3 || state.wWidth != 3) {
    failureReason = "unsupported_kernel";
    return false;
  }
  if (state.outHeight != state.xHeight || state.outWidth != state.xWidth) {
    failureReason = "unsupported_output_shape";
    return false;
  }
  if (!getHostConstDenseElementsAttr(state.w)) {
    failureReason = "non_constant_weights";
    return false;
  }
  if (state.hasBias && !getHostConstDenseElementsAttr(state.b)) {
    failureReason = "non_constant_bias";
    return false;
  }
  failureReason = "";
  return true;
}

/// Report name for a distributed op kind.
static std::string stringifyDistributedTensorOpKind(DistributedTensorOpKind kind) {
  switch (kind) {
  case DistributedTensorOpKind::Relu: return "Relu";
  case DistributedTensorOpKind::Sigmoid: return "Sigmoid";
  case DistributedTensorOpKind::Add: return "Add";
  case DistributedTensorOpKind::Sub: return "Sub";
  case DistributedTensorOpKind::Mul: return "Mul";
  case DistributedTensorOpKind::Div: return "Div";
  case DistributedTensorOpKind::Conv: return "Conv";
  }
  llvm_unreachable("unknown distributed tensor op kind");
}

/// Walks the single-use consumer chain of `convOp`'s result while each
/// consumer can operate on fragments. Returns the accepted steps, the last
/// accepted op (`replacementOp`) and the barrier that ended the walk.
[[maybe_unused]] static DistributedConvAnalysis analyzeDistributedConvConsumers(ONNXConvOp convOp) {
  DistributedConvAnalysis analysis;
  analysis.replacementOp = convOp;
  Value currentValue = convOp.getResult();

  auto setBarrier = [&](DistributedConvBarrierKind kind, std::string detail) {
    analysis.barrierKind = kind;
    analysis.barrierDetail = std::move(detail);
  };

  while (true) {
    if (currentValue.use_empty()) {
      setBarrier(DistributedConvBarrierKind::DeadValue, "result has no users");
      return analysis;
    }
    if (!currentValue.hasOneUse()) {
      setBarrier(DistributedConvBarrierKind::Fanout, "value has multiple users");
      return analysis;
    }

    Operation* user = *currentValue.getUsers().begin();
    if (isa<func::ReturnOp>(user)) {
      setBarrier(DistributedConvBarrierKind::Return, "materialize at func.return");
      return analysis;
    }

    std::optional<DistributedTensorStep> step;
    std::string failureDetail;
    if (auto reluOp = dyn_cast<ONNXReluOp>(user)) {
      if (hasSameStaticTensorType(reluOp.getResult(), currentValue.getType()))
        step = DistributedTensorStep {
            user, DistributedTensorOpKind::Relu, {}, DistributedTensorConstantKind::None, true, std::nullopt};
      else
        failureDetail = "relu result type mismatch";
    }
    else if (auto sigmoidOp = dyn_cast<ONNXSigmoidOp>(user)) {
      if (hasSameStaticTensorType(sigmoidOp.getResult(), currentValue.getType()))
        step = DistributedTensorStep {
            user, DistributedTensorOpKind::Sigmoid, {}, DistributedTensorConstantKind::None, true, std::nullopt};
      else
        failureDetail = "sigmoid result type mismatch";
    }
    else if (auto addOp = dyn_cast<ONNXAddOp>(user)) {
      step = classifyDistributedBinaryConsumer(user,
                                               currentValue,
                                               addOp.getA(),
                                               addOp.getB(),
                                               DistributedTensorOpKind::Add,
                                               /*allowFragmentOnRhs=*/true,
                                               failureDetail);
    }
    else if (auto subOp = dyn_cast<ONNXSubOp>(user)) {
      step = classifyDistributedBinaryConsumer(user,
                                               currentValue,
                                               subOp.getA(),
                                               subOp.getB(),
                                               DistributedTensorOpKind::Sub,
                                               /*allowFragmentOnRhs=*/true,
                                               failureDetail);
    }
    else if (auto mulOp = dyn_cast<ONNXMulOp>(user)) {
      step = classifyDistributedBinaryConsumer(user,
                                               currentValue,
                                               mulOp.getA(),
                                               mulOp.getB(),
                                               DistributedTensorOpKind::Mul,
                                               /*allowFragmentOnRhs=*/true,
                                               failureDetail);
    }
    else if (auto divOp = dyn_cast<ONNXDivOp>(user)) {
      // Division is not commutative, so the fragment must be the dividend.
      step = classifyDistributedBinaryConsumer(user,
                                               currentValue,
                                               divOp.getA(),
                                               divOp.getB(),
                                               DistributedTensorOpKind::Div,
                                               /*allowFragmentOnRhs=*/false,
                                               failureDetail);
      // NOTE(review): only the floating-point element type is checked here;
      // per-channel divisors pass despite the "splat" wording of the message.
      if (step && !isa<DenseFPElementsAttr>(step->constantAttr)) {
        failureDetail = "div requires floating-point splat constant";
        step.reset();
      }
    }
    else if (isa<ONNXConvOp>(user)) {
      failureDetail =
          "onnx.Conv distributed consumer blocked by dim0-only whole-batch materialization in MergeComputeNodes";
    }
    else {
      failureDetail = (user->getName().getStringRef() + " is not distributed-aware yet").str();
    }

    if (!step) {
      setBarrier(DistributedConvBarrierKind::UnsupportedConsumer, std::move(failureDetail));
      return analysis;
    }

    analysis.replacementOp = user;
    analysis.steps.push_back(*step);
    currentValue = user->getResult(0);
  }
}

/// Rewrites the whole "conv_distributed_consumption_report" from `totals`.
/// Does nothing if the report file cannot be opened.
static void rewriteDistributedConvReport(const DistributedConvReportTotals& totals) {
  std::fstream reportFile = openReportFile("conv_distributed_consumption_report");
  if (!reportFile.is_open())
    return;

  reportFile << "# PIM Conv Distributed Consumption Report\n\n";
  reportFile << "Totals:\n";
  reportFile << "- convs_seen: " << totals.totalConvs << "\n";
  reportFile << "- distributed_tensors_created: " << totals.distributedTensorsCreated << "\n";
  reportFile << "- distributed_values_propagated: " << totals.distributedValuesPropagated << "\n";
  reportFile << "- distributed_consumers_handled_locally: " << totals.distributedConsumersHandled << "\n";
  reportFile << "- distributed_conv_inputs_seen: " << totals.distributedConvInputsSeen << "\n";
  reportFile << "- distributed_conv_inputs_consumed: " << totals.distributedConvInputsConsumed << "\n";
  reportFile << "- materialization_barriers_inserted: " << totals.materializationBarriersInserted << "\n\n";

  if (!totals.barrierReasons.empty()) {
    reportFile << "Materialization barriers:\n";
    for (const auto& [reason, count] : totals.barrierReasons)
      reportFile << "- " << reason << ": " << count << "\n";
    reportFile << "\n";
  }

  if (!totals.fallbackReasons.empty()) {
    reportFile << "Fallback / no-distribution reasons:\n";
    for (const auto& [reason, count] : totals.fallbackReasons)
      reportFile << "- " << reason << ": " << count << "\n";
    reportFile << "\n";
  }

  if (totals.chains.empty())
    return;
  reportFile << "Chains:\n";
  for (const DistributedChainReportEntry& chain : totals.chains) {
    reportFile << "- chain_id: " << chain.chainId << "\n";
    reportFile << " chain_length: " << chain.chainLength << "\n";
    reportFile << " producer_kind: " << chain.producerKind << "\n";
    reportFile << " distributed_ops: " << chain.distributedOps << "\n";
    reportFile << " materialization_points: " << chain.materializationPoints << "\n";
    reportFile << " first_materialization_reason: " << chain.firstMaterializationReason << "\n";
    reportFile << " max_live_fragments: " << chain.maxLiveFragments << "\n";
    reportFile << " max_fragment_fanout: " << chain.maxFragmentFanout << "\n";
    reportFile << " patch_builder_core_count: " << chain.patchBuilderCoreCount << "\n";
    reportFile << " central_junction_detected: " << chain.centralJunctionDetected << "\n";
    reportFile << " conv_input_materialization_kind: " << chain.convInputMaterializationKind << "\n";
    reportFile << " local_patch_fragments: " << chain.localPatchFragments << "\n";
    reportFile << " remote_patch_fragments: " << chain.remotePatchFragments << "\n";
    reportFile << " halo_transfer_count: " << chain.haloTransferCount << "\n";
    reportFile << " grouped_transfer_count: " << chain.groupedTransferCount << "\n";
    if (!chain.fallbackReason.empty())
      reportFile << " fallback_reason: " << chain.fallbackReason << "\n";
  }
}

/// Folds one Conv analysis into process-wide totals, keeps the most recent
/// chains, and rewrites the report. Guarded by a mutex.
[[maybe_unused]] static void recordDistributedConvOutcome(const DistributedConvAnalysis& analysis) {
  static constexpr size_t kMaxReportedChains = 16;
  static std::mutex reportMutex;
  static DistributedConvReportTotals totals;

  std::string barrierKey = stringifyDistributedConvBarrierKind(analysis.barrierKind).str();
  if (!analysis.barrierDetail.empty())
    barrierKey += ": " + analysis.barrierDetail;

  std::lock_guard<std::mutex> guard(reportMutex);
  totals.totalConvs++;
  if (analysis.hasLocalConsumers()) {
    const uint64_t convSteps = static_cast<uint64_t>(llvm::count_if(
        analysis.steps, [](const DistributedTensorStep& step) { return step.kind == DistributedTensorOpKind::Conv; }));
    totals.distributedTensorsCreated++;
    totals.distributedValuesPropagated += analysis.steps.size();
    totals.distributedConsumersHandled += analysis.steps.size() - convSteps;
    totals.distributedConvInputsSeen += convSteps;
    totals.distributedConvInputsConsumed += convSteps;
    totals.materializationBarriersInserted++;
    totals.barrierReasons[barrierKey]++;
  }
  else {
    totals.fallbackReasons[barrierKey]++;
  }

  DistributedChainReportEntry chain;
  chain.chainId = totals.totalConvs;
  chain.chainLength = analysis.steps.size() + 1; // +1 for the producing Conv.
  chain.producerKind = "Conv";
  chain.materializationPoints = stringifyDistributedConvBarrierKind(analysis.barrierKind).str();
  chain.firstMaterializationReason = analysis.barrierDetail;
  chain.maxFragmentFanout = analysis.barrierKind == DistributedConvBarrierKind::Fanout ? 2 : 1;
  for (const DistributedTensorStep& step : analysis.steps) {
    if (!chain.distributedOps.empty())
      chain.distributedOps += ", ";
    chain.distributedOps += stringifyDistributedTensorOpKind(step.kind);
  }
  if (analysis.hasDistributedConvConsumer()) {
    chain.convInputMaterializationKind = "distributed_with_halo_exchange";
    chain.patchBuilderCoreCount = 1;
  }

  // Keep only the most recent chains (drop the oldest first).
  if (totals.chains.size() >= kMaxReportedChains)
    totals.chains.erase(totals.chains.begin());
  totals.chains.push_back(std::move(chain));
  rewriteDistributedConvReport(totals);
}

/// Builds a "+----+---+" table divider; each column gets width + 2 dashes.
static std::string makeDivider(ArrayRef<size_t> widths) {
  std::string divider = "+";
  for (size_t width : widths) {
    divider.append(width + 2, '-');
    divider.push_back('+');
  }
  return divider;
}

/// Writes the report title and the column legend.
static void printConvReportLegend(std::fstream& reportFile) {
  reportFile << "# PIM Conv Lowering Report\n\n";
  reportFile << "Legend:\n";
  reportFile << "- `id`: sequential Conv report entry index within this compiler invocation.\n";
  reportFile << "- `where`: summarized MLIR location of the Conv op.\n";
  reportFile << "- `mode`: whether the selected strategy came from `auto` policy or a forced compiler option.\n";
  reportFile << "- `strategy`: selected Conv lowering algorithm.\n";
  reportFile << "- `input`, `weight`, `output`: tensor shapes of the Conv operands/result.\n";
  reportFile << "- `groups`: ONNX Conv group count.\n";
  reportFile << "- `K`: logical reduction size, `CinPerGroup * Kh * Kw`.\n";
  reportFile << "- `C`: logical output-channel width handled by the selected strategy.\n";
  reportFile << "- `P`: total output positions, `N * Hout * Wout`.\n";
  reportFile << "- `X`: crossbar size.\n";
  reportFile << "- `pack`: packed spatial positions per MVM group, `floor(X / max(K, C))`.\n";
  reportFile << "- `im2col`: total explicit im2col element count, `P * K`.\n";
  reportFile << "- `im2col_budget`: maximum im2col element budget allowed by the compiler option.\n";
  reportFile << "- `stream_chunk`: output positions materialized per streamed chunk, or `-` when not applicable.\n";
  reportFile << "- `batch_size`: logical batch size passed into the compute lowering for this Conv form.\n";
  reportFile << "- `batches`: number of repeated compute batches emitted for the Conv.\n";
  reportFile << "- `spatial_compute_batch`: whether ONNX-to-Spatial used `spat.compute_batch` for this Conv.\n";
  reportFile << "- `batched_instruction_emission`: whether the lowering is expected to reach batched PIM emission.\n";
  reportFile << "- `reason`: strategy-selection reason.\n";
  reportFile << "- Per-conv details below the table include profitability estimates and materialization diagnostics.\n";
  reportFile << "- placeholders like `[7]`: value too long for the table cell; see the appendix at the end.\n\n";
}

/// A table value too long for its cell; printed in the appendix under
/// placeholder `[placeholderId]` for every row in `rowIds`.
struct ConvReportOverflow {
  uint64_t placeholderId;
  SmallVector<uint64_t> rowIds;
  std::string column;
  std::string value;
};

/// One-sentence description of a strategy for the report footer.
static StringRef describeConvLoweringStrategy(PimConvLoweringType strategy) {
  switch (strategy) {
  case PimConvLoweringAuto: return "Automatic policy selection.";
  case PimConvLoweringLegacy: return "Legacy Conv lowering path kept for compatibility.";
  case PimConvLoweringDepthwise:
    return "Specialized depthwise lowering that avoids generic cross-channel GEMM mixing.";
  case PimConvLoweringPackedIm2Col:
    return "Explicit im2col plus packed GEMM lowering for Conv shapes that fit well in one crossbar.";
  case PimConvLoweringStreamedPatch:
    return "Chunked per-patch streaming Conv lowering without global im2col materialization.";
  case PimConvLoweringStreamedPacked:
    return "Chunked streamed Conv lowering that still packs multiple output positions per MVM group.";
  case PimConvLoweringOutputChannelTiled:
    return "Conv lowering that splits output channels across tiles when C exceeds one crossbar width.";
  case PimConvLoweringInputKTiled:
    return "Conv lowering that splits the reduction dimension K across tiles and accumulates partial sums.";
  case PimConvLoweringTiled2D:
    return "Conv lowering that tiles both K and output channels because neither dimension fits one crossbar.";
  }
  llvm_unreachable("unknown conv lowering strategy");
}

/// Returns `text` padded to `width`. Text that is too long is replaced by a
/// "[n]" placeholder. Identical (column, value) pairs share one placeholder,
/// and each new placeholder is recorded in `overflows` for the appendix.
static std::string fitConvReportCell(StringRef text,
                                     size_t width,
                                     uint64_t rowId,
                                     StringRef column,
                                     std::vector<ConvReportOverflow>& overflows,
                                     uint64_t& nextPlaceholderId,
                                     bool rightAlign = false) {
  if (text.size() <= width)
    return alignCell(text, width, rightAlign);

  auto existing = llvm::find_if(overflows, [&](const ConvReportOverflow& overflow) {
    return overflow.column == column && overflow.value == text;
  });

  uint64_t placeholderId;
  if (existing != overflows.end()) {
    if (!llvm::is_contained(existing->rowIds, rowId))
      existing->rowIds.push_back(rowId);
    placeholderId = existing->placeholderId;
  }
  else {
    placeholderId = nextPlaceholderId++;
    overflows.push_back({placeholderId, {rowId}, column.str(), text.str()});
  }
  return alignCell("[" + std::to_string(placeholderId) + "]", width, rightAlign);
}

/// Maps a name from stringifyConvLoweringStrategy back to its enumerator;
/// unknown names map to PimConvLoweringAuto.
static PimConvLoweringType parseConvLoweringStrategyName(StringRef name) {
  for (PimConvLoweringType candidate : {
           PimConvLoweringAuto,
           PimConvLoweringLegacy,
           PimConvLoweringDepthwise,
           PimConvLoweringPackedIm2Col,
           PimConvLoweringStreamedPatch,
           PimConvLoweringStreamedPacked,
           PimConvLoweringOutputChannelTiled,
           PimConvLoweringInputKTiled,
           PimConvLoweringTiled2D,
       }) {
    if (name == stringifyConvLoweringStrategy(candidate))
      return candidate;
  }
  return PimConvLoweringAuto;
}

/// Writes the fixed-width summary table, the overflow appendix, per-conv
/// details, and descriptions of the strategies that were used.
static void writeConvReportTable(std::fstream& reportFile, ArrayRef<ConvReportEntry> entries) {
  static constexpr size_t kIdWidth = 4;
  static constexpr size_t kWhereWidth = 24;
  static constexpr size_t kModeWidth = 6;
  static constexpr size_t kStrategyWidth = 20;
  static constexpr size_t kShapeWidth = 14;
  static constexpr size_t kGroupsWidth = 3;
  static constexpr size_t kSmallWidth = 5;
  static constexpr size_t kPWidth = 10;
  static constexpr size_t kIm2colWidth = 10;
  static constexpr size_t kChunkWidth = 8;
  static constexpr size_t kFlagWidth = 3;
  static constexpr size_t kReasonWidth = 24;

  // Column order must match the header and row cells printed below.
  const SmallVector<size_t> widths = {
      kIdWidth,     kWhereWidth,  kModeWidth,  kStrategyWidth, kShapeWidth, kShapeWidth,  kShapeWidth,
      kGroupsWidth, kSmallWidth,  kSmallWidth, kPWidth,        kSmallWidth, kSmallWidth,  kIm2colWidth,
      kIm2colWidth, kChunkWidth,  kSmallWidth, kSmallWidth,    kFlagWidth,  kFlagWidth,   kReasonWidth,
  };
  const std::string divider = makeDivider(widths);
  std::vector<ConvReportOverflow> overflows;
  uint64_t nextPlaceholderId = 1;

  auto printRow = [&](ArrayRef<std::string> cells) {
    reportFile << "|";
    for (const std::string& cell : cells)
      reportFile << " " << cell << " |";
    reportFile << "\n";
  };

  reportFile << divider << "\n";
  printRow({
      alignCell("id", kIdWidth, true),
      alignCell("where", kWhereWidth),
      alignCell("mode", kModeWidth),
      alignCell("strategy", kStrategyWidth),
      alignCell("input", kShapeWidth),
      alignCell("weight", kShapeWidth),
      alignCell("output", kShapeWidth),
      alignCell("grp", kGroupsWidth, true),
      alignCell("K", kSmallWidth, true),
      alignCell("C", kSmallWidth, true),
      alignCell("P", kPWidth, true),
      alignCell("X", kSmallWidth, true),
      alignCell("pack", kSmallWidth, true),
      alignCell("im2col", kIm2colWidth, true),
      alignCell("budget", kIm2colWidth, true),
      alignCell("chunk", kChunkWidth, true),
      alignCell("batch", kSmallWidth, true),
      alignCell("nbat", kSmallWidth, true),
      alignCell("scb", kFlagWidth),
      alignCell("bie", kFlagWidth),
      alignCell("reason", kReasonWidth),
  });
  reportFile << divider << "\n";

  for (const ConvReportEntry& entry : entries) {
    auto fit = [&](StringRef text, size_t width, StringRef column, bool rightAlign = false) {
      return fitConvReportCell(text, width, entry.id, column, overflows, nextPlaceholderId, rightAlign);
    };
    auto number = [](auto value, size_t width) { return alignCell(std::to_string(value), width, true); };

    printRow({
        number(entry.id, kIdWidth),
        fit(entry.where, kWhereWidth, "where"),
        alignCell(entry.mode, kModeWidth),
        fit(entry.strategy, kStrategyWidth, "strategy"),
        fit(entry.inputShape, kShapeWidth, "input"),
        fit(entry.weightShape, kShapeWidth, "weight"),
        fit(entry.outputShape, kShapeWidth, "output"),
        number(entry.groups, kGroupsWidth),
        number(entry.k, kSmallWidth),
        number(entry.c, kSmallWidth),
        number(entry.p, kPWidth),
        number(entry.xbarSize, kSmallWidth),
        number(entry.pack, kSmallWidth),
        number(entry.im2colElements, kIm2colWidth),
        number(entry.im2colBudget, kIm2colWidth),
        fit(entry.chunkText, kChunkWidth, "stream_chunk", /*rightAlign=*/true),
        number(entry.batchSize, kSmallWidth),
        number(entry.numberOfBatches, kSmallWidth),
        alignCell(entry.spatialComputeBatch, kFlagWidth),
        alignCell(entry.batchedInstructionEmission, kFlagWidth),
        fit(entry.reason, kReasonWidth, "reason"),
    });
  }
  reportFile << divider << "\n";

  if (overflows.empty()) {
    reportFile << "\n";
  }
  else {
    reportFile << "\nAppendix:\n";
    for (const ConvReportOverflow& overflow : overflows) {
      reportFile << " [" << overflow.placeholderId << "] rows ";
      for (size_t i = 0; i < overflow.rowIds.size(); ++i) {
        if (i != 0)
          reportFile << ", ";
        reportFile << overflow.rowIds[i];
      }
      reportFile << ", " << overflow.column << ": " << overflow.value << "\n";
    }
    reportFile << "\n";
  }

  reportFile << "Per-Conv Details:\n";
  for (const ConvReportEntry& entry : entries) {
    reportFile << "- Conv " << entry.id << ": mode=" << entry.mode << ", strategy=" << entry.strategy
               << ", reason=" << entry.reason << "\n";
    reportFile << " K=" << entry.k << ", Cout=" << entry.c << ", output_positions=" << entry.p
               << ", estimated_mvm_count=" << entry.estimatedMvmCount
               << ", estimated_reduction_vadd_count=" << entry.estimatedReductionVAddCount
               << ", estimated_output_fragments=" << entry.estimatedOutputFragments << "\n";
    reportFile << " materialization_required_at_func_return=" << entry.materializationRequiredAtReturn
               << ", materialization_kind=" << entry.materializationKind
               << ", giant_collector_concat_expected=" << entry.giantCollectorConcatExpected
               << ", concat_operand_count=" << entry.concatOperandCount << ", collector_core=" << entry.collectorCore
               << ", full_input_broadcast_expected=" << entry.fullInputBroadcastExpected << "\n";
    if (!entry.rejectedAutoStrategy.empty())
      reportFile << " rejected_auto_strategy=" << entry.rejectedAutoStrategy << "\n";
    if (!entry.fallbackReason.empty())
      reportFile << " fallback_reason=" << entry.fallbackReason << "\n";
  }
  reportFile << "\n";

  // Distinct strategies in order of first appearance.
  llvm::SmallVector<PimConvLoweringType> usedStrategies;
  for (const ConvReportEntry& entry : entries) {
    PimConvLoweringType strategy = parseConvLoweringStrategyName(entry.strategy);
    if (!llvm::is_contained(usedStrategies, strategy))
      usedStrategies.push_back(strategy);
  }
  reportFile << "Strategies used in this report:\n";
  for (PimConvLoweringType strategy : usedStrategies)
    reportFile << "- `" << stringifyConvLoweringStrategy(strategy).str()
               << "`: " << describeConvLoweringStrategy(strategy).str() << "\n";
}

/// Rewrites the whole "conv_lowering_report" from `entries`. Does nothing if
/// the report file cannot be opened.
static void rewriteConvLoweringReport(ArrayRef<ConvReportEntry> entries) {
  std::fstream reportFile = openReportFile("conv_lowering_report");
  if (!reportFile.is_open())
    return;
  printConvReportLegend(reportFile);
  writeConvReportTable(reportFile, entries);
}

/// Resolves the requested strategy from the compiler options.
/// --use-experimental-conv-impl forces packed-im2col and is only compatible
/// with --pim-conv-lowering=auto or packed-im2col; any other combination emits
/// an op error and fails.
[[maybe_unused]] static FailureOr<PimConvLoweringType> resolveRequestedConvLoweringStrategy(ONNXConvOp convOp) {
  if (!useExperimentalConvImpl)
    return pimConvLowering.getValue();
  if (pimConvLowering != PimConvLoweringAuto && pimConvLowering != PimConvLoweringPackedIm2Col) {
    convOp.emitOpError() << "--use-experimental-conv-impl conflicts with --pim-conv-lowering="
                         << stringifyConvLoweringStrategy(pimConvLowering);
    return failure();
  }
  return PimConvLoweringPackedIm2Col;
}

static ConvLoweringDecision chooseConvLoweringStrategy(const ConvGeometry& geo,
                                                       PimConvLoweringType requested,
                                                       const DistributedConvAnalysis& analysis) {
  if (requested != PimConvLoweringAuto)
    return {requested, "forced by compiler option", /*isAuto=*/false, "", ""};

  // Transform-based convolution is intentionally not selected for this ISA:
  // it would require explicit transform
sequences and staging traffic on top of // the same crossbar MVM primitive, which is not attractive here. if (geo.isDepthwise) return {PimConvLoweringDepthwise, "depthwise convolution", /*isAuto=*/true, "", ""}; if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2 && geo.im2colElements <= pimConvIm2colMaxElements) return {PimConvLoweringPackedIm2Col, "fits crossbar, packing useful, and global im2col fits budget", /*isAuto=*/true, "", ""}; if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2 && geo.im2colElements > pimConvIm2colMaxElements) return {PimConvLoweringStreamedPacked, "fits crossbar and packing useful, but global im2col exceeds budget", /*isAuto=*/true, "", ""}; if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize) return {PimConvLoweringStreamedPatch, "fits crossbar but packing is not useful", /*isAuto=*/true, "", ""}; if (geo.k <= geo.xbarSize && geo.c > geo.xbarSize) return {PimConvLoweringOutputChannelTiled, "output channels exceed one crossbar width", /*isAuto=*/true, "", ""}; if (geo.k > geo.xbarSize && geo.c <= geo.xbarSize) { ConvStrategyEstimate estimate = estimateConvStrategy(geo, PimConvLoweringInputKTiled, analysis); std::string fallbackReason = "auto rejects input-k-tiled because the reduction-heavy path is force-only for now"; if (estimate.requiresFuncReturnMaterialization && estimate.perOutputPositionReduction && estimate.giantCollectorConcatExpected) { fallbackReason += "; func.return would materialize " + std::to_string(estimate.concatOperandCount) + " output fragments through a single collector concat after per-position reductions"; } if (estimate.fullInputBroadcastExpected) fallbackReason += "; the current lowering also broadcasts the padded input to many workers"; return {PimConvLoweringLegacy, "fall back to legacy explicit-im2col for the current auto policy", /*isAuto=*/true, fallbackReason, stringifyConvLoweringStrategy(PimConvLoweringInputKTiled).str()}; } return {PimConvLoweringTiled2D, "both reduction K 
and output channels exceed one crossbar", /*isAuto=*/true, "", ""}; } [[maybe_unused]] static LogicalResult verifyForcedConvLoweringStrategy(ONNXConvOp convOp, const ConvGeometry& geo, PimConvLoweringType strategy) { switch (strategy) { case PimConvLoweringAuto: case PimConvLoweringLegacy: return success(); case PimConvLoweringDepthwise: if (geo.isDepthwise) return success(); return convOp.emitOpError("forced depthwise Conv lowering requires a depthwise convolution"); case PimConvLoweringPackedIm2Col: if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2 && geo.im2colElements <= pimConvIm2colMaxElements) return success(); return convOp.emitOpError("forced packed-im2col Conv lowering requires K/C to fit, pack >= 2, and im2col within budget"); case PimConvLoweringStreamedPatch: if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize) return success(); return convOp.emitOpError("forced streamed-patch Conv lowering requires K and C to each fit one crossbar"); case PimConvLoweringStreamedPacked: if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2) return success(); return convOp.emitOpError("forced streamed-packed Conv lowering requires K/C to fit and pack >= 2"); case PimConvLoweringOutputChannelTiled: if (geo.k <= geo.xbarSize && geo.c > geo.xbarSize) return success(); return convOp.emitOpError("forced output-channel-tiled Conv lowering requires K <= X and C > X"); case PimConvLoweringInputKTiled: if (geo.k > geo.xbarSize && geo.c <= geo.xbarSize) return success(); return convOp.emitOpError("forced input-k-tiled Conv lowering requires K > X and C <= X"); case PimConvLoweringTiled2D: if (geo.k > geo.xbarSize && geo.c > geo.xbarSize) return success(); return convOp.emitOpError("forced tiled-2d Conv lowering requires K > X and C > X"); } llvm_unreachable("unknown conv lowering strategy"); } static void reportConvLoweringDecision(ONNXConvOp convOp, const ConvGeometry& geo, const ConvLoweringDecision& decision, const ConvStrategyEstimate& estimate, 
int64_t batchSize, int64_t numberOfBatches, bool usesComputeBatch, bool usesBatchedInstructionEmission, std::optional streamChunkPositions = std::nullopt) { if (!pimReportConvLowering) return; const std::string location = summarizeLocation(convOp.getLoc()); const std::string strategy = stringifyConvLoweringStrategy(decision.strategy).str(); const std::string mode = decision.isAuto ? "auto" : "forced"; const std::string inputShape = formatShape(cast(convOp.getX().getType()).getShape()); const std::string weightShape = formatShape(cast(convOp.getW().getType()).getShape()); const std::string outputShape = formatShape(cast(convOp.getY().getType()).getShape()); const std::string chunkText = streamChunkPositions ? std::to_string(*streamChunkPositions) : "-"; const std::string scbText = usesComputeBatch ? "yes" : "no"; const std::string bieText = usesBatchedInstructionEmission ? "yes" : "no"; static uint64_t reportIndex = 0; const uint64_t currentIndex = ++reportIndex; static std::mutex reportMutex; static std::vector reportEntries; std::lock_guard lock(reportMutex); reportEntries.push_back({ currentIndex, location, strategy, mode, inputShape, weightShape, outputShape, geo.group, geo.k, geo.c, geo.p, geo.xbarSize, geo.pack, geo.im2colElements, pimConvIm2colMaxElements, chunkText, batchSize, numberOfBatches, scbText, bieText, decision.reason, decision.fallbackReason, decision.rejectedAutoStrategy, estimate.estimatedMvmCount, estimate.estimatedReductionVAddCount, estimate.estimatedOutputFragments, estimate.requiresFuncReturnMaterialization ? "yes" : "no", estimate.materializationKind, estimate.concatOperandCount, estimate.collectorCore, estimate.giantCollectorConcatExpected ? "yes" : "no", estimate.fullInputBroadcastExpected ? 
"yes" : "no", }); rewriteConvLoweringReport(reportEntries); } static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location loc) { auto biasType = cast(bias.getType()); if (biasType.getRank() != 1) return bias; auto expandedBiasType = RankedTensorType::get({1, biasType.getDimSize(0)}, biasType.getElementType()); return tensor::ExpandShapeOp::create(rewriter, loc, expandedBiasType, bias, SmallVector { {0, 1} }); } static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) { assert(value > 0 && "expected positive value"); limit = std::min(value, limit); for (int64_t candidate = limit; candidate >= 1; --candidate) if (value % candidate == 0) return candidate; return 1; } static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, ArrayRef lowPadValues, ArrayRef highPadValues, PatternRewriter& rewriter, Location loc) { auto valueType = cast(value.getType()); if (valueType == resultType) return value; SmallVector lowPads; SmallVector highPads; lowPads.reserve(lowPadValues.size()); highPads.reserve(highPadValues.size()); for (auto lowPad : lowPadValues) lowPads.push_back(rewriter.getIndexAttr(lowPad)); for (auto highPad : highPadValues) highPads.push_back(rewriter.getIndexAttr(highPad)); auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads); auto* padBlock = new Block(); for (int64_t dim = 0, rank = resultType.getRank(); dim < rank; ++dim) padBlock->addArgument(rewriter.getIndexType(), loc); padOp.getRegion().push_back(padBlock); rewriter.setInsertionPointToStart(padBlock); auto zero = getOrCreateConstant( rewriter, padOp.getOperation(), rewriter.getZeroAttr(resultType.getElementType()), resultType.getElementType()); tensor::YieldOp::create(rewriter, loc, zero); rewriter.setInsertionPointAfter(padOp); return padOp.getResult(); } static Value createConvInputPatch(Value input, RankedTensorType patchType, Value batchIndex, Value channelOffset, Value inputHeightOffset, Value inputWidthOffset, 
int64_t dilationHeight, int64_t dilationWidth, PatternRewriter& rewriter, Location loc) { const int64_t patchChannels = patchType.getDimSize(1); const int64_t kernelHeight = patchType.getDimSize(2); const int64_t kernelWidth = patchType.getDimSize(3); if (dilationHeight == 1 && dilationWidth == 1) { SmallVector offsets {batchIndex, channelOffset, inputHeightOffset, inputWidthOffset}; SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchChannels), rewriter.getIndexAttr(kernelHeight), rewriter.getIndexAttr(kernelWidth)}; return tensor::ExtractSliceOp::create(rewriter, loc, patchType, input, offsets, sizes, getUnitStrides(rewriter, 4)); } Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); auto elementType = patchType.getElementType(); auto pixelType = RankedTensorType::get({1, patchChannels, 1, 1}, elementType, patchType.getEncoding()); Value patch = tensor::EmptyOp::create(rewriter, loc, patchType.getShape(), elementType); for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { Value sourceHeightOffset = affineAddConst(rewriter, loc, inputHeightOffset, kernelH * dilationHeight, anchorOp); for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { Value sourceWidthOffset = affineAddConst(rewriter, loc, inputWidthOffset, kernelW * dilationWidth, anchorOp); SmallVector sourceOffsets {batchIndex, channelOffset, sourceHeightOffset, sourceWidthOffset}; SmallVector sourceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchChannels), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; Value sourcePixel = tensor::ExtractSliceOp::create( rewriter, loc, pixelType, input, sourceOffsets, sourceSizes, getUnitStrides(rewriter, 4)); SmallVector targetOffsets { rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(kernelH), rewriter.getIndexAttr(kernelW)}; patch = tensor::InsertSliceOp::create( rewriter, loc, sourcePixel, patch, targetOffsets, sourceSizes, getUnitStrides(rewriter, 4)); } } return patch; } 
static Value createCollectedConvOutput(ValueRange gemmRows, Type convType, RankedTensorType gemmOutType, RankedTensorType nhwcType, RankedTensorType outType, int64_t numPatches, int64_t numChannelsOut, int64_t packFactor, ArrayRef distributedConsumers, PatternRewriter& rewriter, Location loc); static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b); namespace depthwise { struct Tiling { int64_t outputMultiplier; int64_t kernelElements; int64_t channelsPerTile; int64_t tileInputRows; int64_t tileOutputChannels; int64_t numChannelTiles; int64_t spatialPatchesPerBatch; int64_t totalPatches; }; static std::optional computeTiling(int64_t batchSize, int64_t numChannelsIn, int64_t numChannelsOut, int64_t wHeight, int64_t wWidth, int64_t outHeight, int64_t outWidth) { const int64_t kernelElements = wHeight * wWidth; const int64_t outputMultiplier = numChannelsOut / numChannelsIn; const int64_t xbarDim = static_cast(crossbarSize.getValue()); if (kernelElements <= 0 || outputMultiplier <= 0 || kernelElements > xbarDim || outputMultiplier > xbarDim) return std::nullopt; const int64_t maxChannelsPerTile = std::min(xbarDim / kernelElements, xbarDim / outputMultiplier); if (maxChannelsPerTile <= 0) return std::nullopt; const int64_t channelsPerTile = findLargestDivisorAtMost(numChannelsIn, maxChannelsPerTile); const int64_t tileInputRows = channelsPerTile * kernelElements; const int64_t tileOutputChannels = channelsPerTile * outputMultiplier; if (tileInputRows > xbarDim || tileOutputChannels > xbarDim) return std::nullopt; return Tiling { outputMultiplier, kernelElements, channelsPerTile, tileInputRows, tileOutputChannels, numChannelsIn / channelsPerTile, outHeight * outWidth, batchSize * outHeight * outWidth, }; } static Value buildPackedWeights(DenseElementsAttr wDenseAttr, RankedTensorType wType, const Tiling& tiling, PatternRewriter& rewriter, Location loc) { auto packedWeightType = RankedTensorType::get( {tiling.numChannelTiles, 
tiling.tileInputRows, tiling.tileOutputChannels}, wType.getElementType()); SmallVector packedValues(packedWeightType.getNumElements(), cast(rewriter.getZeroAttr(wType.getElementType()))); SmallVector sourceValues(wDenseAttr.getValues()); for (int64_t tileIndex = 0; tileIndex < tiling.numChannelTiles; ++tileIndex) { const int64_t channelBase = tileIndex * tiling.channelsPerTile; for (int64_t localChannel = 0; localChannel < tiling.channelsPerTile; ++localChannel) { const int64_t globalChannel = channelBase + localChannel; for (int64_t kernelIndex = 0; kernelIndex < tiling.kernelElements; ++kernelIndex) { const int64_t kernelH = kernelIndex / wType.getDimSize(3); const int64_t kernelW = kernelIndex % wType.getDimSize(3); const int64_t targetRow = localChannel * tiling.kernelElements + kernelIndex; for (int64_t multiplierIndex = 0; multiplierIndex < tiling.outputMultiplier; ++multiplierIndex) { const int64_t globalOutChannel = globalChannel * tiling.outputMultiplier + multiplierIndex; const int64_t sourceFlatIndex = ((globalOutChannel * wType.getDimSize(1) * wType.getDimSize(2)) + kernelH) * wType.getDimSize(3) + kernelW; const int64_t targetCol = localChannel * tiling.outputMultiplier + multiplierIndex; const int64_t targetFlatIndex = ((tileIndex * tiling.tileInputRows) + targetRow) * tiling.tileOutputChannels + targetCol; packedValues[targetFlatIndex] = sourceValues[sourceFlatIndex]; } } } } auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType); } static Value createPaddedInput(Value input, RankedTensorType inputType, int64_t padHeightBegin, int64_t padHeightEnd, int64_t padWidthBegin, int64_t padWidthEnd, PatternRewriter& rewriter, Location loc) { if (padHeightBegin == 0 && padHeightEnd == 0 && padWidthBegin == 0 && padWidthEnd == 0) return input; auto paddedInputType = RankedTensorType::get({inputType.getDimSize(0), 
inputType.getDimSize(1), inputType.getDimSize(2) + padHeightBegin + padHeightEnd, inputType.getDimSize(3) + padWidthBegin + padWidthEnd}, inputType.getElementType()); auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) { Value padded = createZeroPaddedTensor(computeInput, paddedInputType, {0, 0, padHeightBegin, padWidthBegin}, {0, 0, padHeightEnd, padWidthEnd}, rewriter, loc); spatial::SpatYieldOp::create(rewriter, loc, padded); }); return computeOp.getResult(0); } static Value createInputTile(Value input, Value patchIndex, Value channelTileIndex, RankedTensorType inputTileType, const Tiling& tiling, int64_t strideHeight, int64_t strideWidth, int64_t dilationHeight, int64_t dilationWidth, int64_t outWidth, PatternRewriter& rewriter, Location loc) { Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Value batchIndex = affineFloorDivConst(rewriter, loc, patchIndex, tiling.spatialPatchesPerBatch, anchorOp); Value batchPatchIndex = affineModConst(rewriter, loc, patchIndex, tiling.spatialPatchesPerBatch, anchorOp); Value outHeightIndex = affineFloorDivConst(rewriter, loc, batchPatchIndex, outWidth, anchorOp); Value outWidthIndex = affineModConst(rewriter, loc, batchPatchIndex, outWidth, anchorOp); Value inputHeightOffset = strideHeight == 1 ? outHeightIndex : affineMulConst(rewriter, loc, outHeightIndex, strideHeight, anchorOp); Value inputWidthOffset = strideWidth == 1 ? outWidthIndex : affineMulConst(rewriter, loc, outWidthIndex, strideWidth, anchorOp); Value channelOffset = tiling.channelsPerTile == 1 ? 
channelTileIndex : affineMulConst(rewriter, loc, channelTileIndex, tiling.channelsPerTile, anchorOp); Value tile4D = createConvInputPatch(input, inputTileType, batchIndex, channelOffset, inputHeightOffset, inputWidthOffset, dilationHeight, dilationWidth, rewriter, loc); auto collapsedType = RankedTensorType::get({1, tiling.tileInputRows}, inputTileType.getElementType()); return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, tile4D, SmallVector { {0}, {1, 2, 3} }); } static Value createWeightTile(Value packedWeights, Value channelTileIndex, RankedTensorType packedWeightType, const Tiling& tiling, PatternRewriter& rewriter, Location loc) { SmallVector offsets {channelTileIndex, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tiling.tileInputRows), rewriter.getIndexAttr(tiling.tileOutputChannels)}; auto sliceType = RankedTensorType::get({1, tiling.tileInputRows, tiling.tileOutputChannels}, packedWeightType.getElementType()); Value slice = tensor::ExtractSliceOp::create( rewriter, loc, sliceType, packedWeights, offsets, sizes, getUnitStrides(rewriter, 3)); auto collapsedType = RankedTensorType::get({tiling.tileInputRows, tiling.tileOutputChannels}, packedWeightType.getElementType()); return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, slice, SmallVector { {0, 1}, {2} }); } static Value createBiasTile( Value bias, Value channelTileIndex, const Tiling& tiling, PatternRewriter& rewriter, Location loc) { auto biasType = cast(bias.getType()); auto biasTileType = RankedTensorType::get({1, tiling.tileOutputChannels}, biasType.getElementType()); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Value channelOffset = tiling.tileOutputChannels == 1 ? 
channelTileIndex : affineMulConst(rewriter, loc, channelTileIndex, tiling.tileOutputChannels, anchorOp); SmallVector offsets {rewriter.getIndexAttr(0), channelOffset}; SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tiling.tileOutputChannels)}; return tensor::ExtractSliceOp::create(rewriter, loc, biasTileType, bias, offsets, sizes, getUnitStrides(rewriter, 2)); } static Value insertOutputTile(Value rowTile, Value rowAccumulator, Value channelTileIndex, const Tiling& tiling, PatternRewriter& rewriter, Location loc) { Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Value channelOffset = tiling.tileOutputChannels == 1 ? channelTileIndex : affineMulConst(rewriter, loc, channelTileIndex, tiling.tileOutputChannels, anchorOp); SmallVector offsets {rewriter.getIndexAttr(0), channelOffset}; SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tiling.tileOutputChannels)}; return tensor::InsertSliceOp::create( rewriter, loc, rowTile, rowAccumulator, offsets, sizes, getUnitStrides(rewriter, 2)); } static FailureOr reconstructDepthwiseGemmRows(Value pieces, RankedTensorType piecesType, RankedTensorType gemmOutType, const Tiling& tiling, PatternRewriter& rewriter, Location loc) { auto collectedOp = createSpatCompute<1>(rewriter, loc, TypeRange {gemmOutType}, {}, pieces, [&](Value piecesArg) { auto rowType = RankedTensorType::get({1, gemmOutType.getDimSize(1)}, gemmOutType.getElementType()); Value outputInit = tensor::EmptyOp::create(rewriter, loc, gemmOutType.getShape(), gemmOutType.getElementType()); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, tiling.totalPatches); Value cNumChannelTiles = getOrCreateIndexConstant(rewriter, anchorOp, tiling.numChannelTiles); auto patchLoop = buildNormalizedScfFor( rewriter, loc, c0, 
cNumPatches, c1, ValueRange {outputInit}, [&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange patchIterArgs, SmallVectorImpl& patchYielded) { Value outputAcc = patchIterArgs.front(); Value rowInit = tensor::EmptyOp::create(rewriter, nestedLoc, rowType.getShape(), rowType.getElementType()); auto tileLoop = buildNormalizedScfFor( rewriter, nestedLoc, c0, cNumChannelTiles, c1, ValueRange {rowInit}, [&](OpBuilder&, Location tileLoc, Value channelTileIndex, ValueRange tileIterArgs, SmallVectorImpl& tileYielded) { Value rowAcc = tileIterArgs.front(); MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); AffineExpr d1 = getAffineDimExpr(1, context); Value laneIndex = createOrFoldAffineApply( rewriter, tileLoc, (d0 * tiling.totalPatches) + d1, ValueRange {channelTileIndex, patchIndex}, anchorOp); auto rowTileType = RankedTensorType::get({1, tiling.tileOutputChannels}, piecesType.getElementType()); SmallVector pieceOffsets {laneIndex, rewriter.getIndexAttr(0)}; SmallVector pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tiling.tileOutputChannels)}; Value rowTile = tensor::ExtractSliceOp::create( rewriter, tileLoc, rowTileType, piecesArg, pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2)); Value rowNext = insertOutputTile(rowTile, rowAcc, channelTileIndex, tiling, rewriter, tileLoc); tileYielded.push_back(rowNext); return success(); }); if (failed(tileLoop)) return failure(); SmallVector rowOffsets {patchIndex, rewriter.getIndexAttr(0)}; SmallVector rowSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(gemmOutType.getDimSize(1))}; Value outputNext = tensor::InsertSliceOp::create(rewriter, nestedLoc, tileLoop->results.front(), outputAcc, rowOffsets, rowSizes, getUnitStrides(rewriter, 2)) .getResult(); patchYielded.push_back(outputNext); return success(); }); if (failed(patchLoop)) return failure(); spatial::SpatYieldOp::create(rewriter, loc, patchLoop->results.front()); return success(); }); if 
(failed(collectedOp)) return failure(); return collectedOp->getResult(0); } static bool canUseStructuredRewrite(const ConvLoweringState& state) { if (!getHostConstDenseElementsAttr(state.w)) return false; auto tiling = computeTiling(state.batchSize, state.numChannelsIn, state.numChannelsOut, state.wHeight, state.wWidth, state.outHeight, state.outWidth); if (!tiling) return false; if (!state.hasBias) return true; auto biasType = dyn_cast(state.b.getType()); if (!biasType) return false; if (biasType.getRank() == 1) return biasType.getDimSize(0) == state.numChannelsOut; if (biasType.getRank() != 2) return false; return biasType.getDimSize(0) == 1 && biasType.getDimSize(1) == state.numChannelsOut; } static FailureOr rewriteConv(Operation* convOp, const ConvLoweringState& state, PatternRewriter& rewriter, Location loc) { auto wDenseAttr = getHostConstDenseElementsAttr(state.w); if (!wDenseAttr) { convOp->emitOpError("requires constant-derived weights for structured depthwise Spatial lowering"); return failure(); } auto tiling = computeTiling(state.xType.getDimSize(0), state.xType.getDimSize(1), state.outType.getDimSize(1), state.wType.getDimSize(2), state.wType.getDimSize(3), state.outType.getDimSize(2), state.outType.getDimSize(3)); if (!tiling) { convOp->emitOpError("failed to derive a structured depthwise tiling that fits Spatial weighted VMM lowering"); return failure(); } Value paddedInput = createPaddedInput(state.x, state.xType, state.padHeightBegin, state.padHeightEnd, state.padWidthBegin, state.padWidthEnd, rewriter, loc); Value packedWeights = buildPackedWeights(wDenseAttr, state.wType, *tiling, rewriter, loc); Value expandedBias; SmallVector batchInputs {paddedInput}; if (state.hasBias) { expandedBias = expandBiasIfNeeded(state.b, rewriter, loc); auto biasType = dyn_cast(expandedBias.getType()); if (!biasType || biasType.getRank() != 2 || biasType.getDimSize(0) != 1 || biasType.getDimSize(1) != state.outType.getDimSize(1)) { convOp->emitOpError("requires bias 
sliceable as tensor<1xCout> for structured depthwise Spatial lowering"); return failure(); } batchInputs.push_back(expandedBias); } auto gemmOutType = RankedTensorType::get({tiling->totalPatches, state.outType.getDimSize(1)}, state.outType.getElementType()); auto piecesType = RankedTensorType::get({tiling->totalPatches * tiling->numChannelTiles, tiling->tileOutputChannels}, state.outType.getElementType()); auto paddedInputType = cast(paddedInput.getType()); auto inputTileType = RankedTensorType::get({1, tiling->channelsPerTile, state.wType.getDimSize(2), state.wType.getDimSize(3)}, paddedInputType.getElementType()); SmallVector batchWeights; if (tiling->numChannelTiles == 1) { Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); batchWeights.push_back(createWeightTile(packedWeights, c0, cast(packedWeights.getType()), *tiling, rewriter, loc)); } else { batchWeights.push_back(packedWeights); } auto batchOp = createSpatComputeBatch( rewriter, loc, TypeRange {piecesType}, tiling->totalPatches * tiling->numChannelTiles, batchWeights, batchInputs, [&](detail::SpatComputeBatchBodyArgs args) { auto pickInputByRank = [&](int64_t rank) -> Value { for (Value input : args.inputs) { auto inputType = dyn_cast(input.getType()); if (inputType && inputType.getRank() == rank) return input; } return Value(); }; Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Value patchIndex = tiling->numChannelTiles == 1 ? args.lane : affineModConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp); Value channelTileIndex = tiling->numChannelTiles == 1 ? 
getOrCreateIndexConstant(rewriter, anchorOp, 0) : affineFloorDivConst(rewriter, loc, args.lane, tiling->totalPatches, anchorOp); Value paddedInputArg = pickInputByRank(/*rank=*/4); if (!paddedInputArg) { convOp->emitOpError("structured depthwise batch body requires a rank-4 padded input block argument"); return failure(); } Value inputTile = createInputTile(paddedInputArg, patchIndex, channelTileIndex, inputTileType, *tiling, state.strideHeight, state.strideWidth, state.dilationHeight, state.dilationWidth, state.outType.getDimSize(3), rewriter, loc); Value weightTile = tiling->numChannelTiles == 1 ? args.weights.front() : createWeightTile(args.weights.front(), channelTileIndex, cast(args.weights.front().getType()), *tiling, rewriter, loc); auto rowTileType = RankedTensorType::get({1, tiling->tileOutputChannels}, state.outType.getElementType()); Value rowTile = spatial::SpatVMMOp::create(rewriter, loc, rowTileType, weightTile, inputTile).getResult(); if (args.inputs.size() > 1) { Value biasArg = pickInputByRank(/*rank=*/2); if (!biasArg) { convOp->emitOpError("structured depthwise batch body requires a rank-2 bias block argument when bias is present"); return failure(); } Value biasTile = tiling->numChannelTiles == 1 ? 
biasArg : createBiasTile(biasArg, channelTileIndex, *tiling, rewriter, loc); rowTile = spatial::SpatVAddOp::create(rewriter, loc, rowTileType, rowTile, biasTile).getResult(); } SmallVector outputOffsets {args.lane, rewriter.getIndexAttr(0)}; SmallVector outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(tiling->tileOutputChannels)}; createParallelInsertSliceIntoBatchOutput( rewriter, loc, rowTile, args.outputs.front(), outputOffsets, outputSizes, getUnitStrides(rewriter, 2)); return success(); }); if (failed(batchOp)) return failure(); auto nhwcType = RankedTensorType::get( {state.xType.getDimSize(0), state.outType.getDimSize(2), state.outType.getDimSize(3), state.outType.getDimSize(1)}, state.outType.getElementType()); Value collectedRows = batchOp->getResult(0); if (tiling->numChannelTiles != 1) { auto reconstructedRows = reconstructDepthwiseGemmRows(batchOp->getResult(0), piecesType, gemmOutType, *tiling, rewriter, loc); if (failed(reconstructedRows)) return failure(); collectedRows = *reconstructedRows; } return createCollectedConvOutput(ValueRange {collectedRows}, state.outType, gemmOutType, nhwcType, state.outType, tiling->totalPatches, state.outType.getDimSize(1), /*packFactor=*/1, {}, rewriter, loc); } } // namespace depthwise namespace standard { struct ConvGemmPlan { int64_t patchSize; int64_t numPatchesPerBatch; int64_t globalNumPatches; int64_t chunkStart; int64_t chunkNumPatches; int64_t maxParallelPixels; int64_t effectiveMaxParallelPixels; int64_t packedNumRows; RankedTensorType im2colType; RankedTensorType im2colRowType; RankedTensorType gemmInputRowsType; RankedTensorType wFlatType; RankedTensorType wTransType; RankedTensorType gemmOutType; RankedTensorType gemmOutputRowsType; RankedTensorType nhwcType; }; static ConvGemmPlan buildConvGemmPlan(const ConvLoweringState& state, bool canPackWeightsAsConstants, bool canPackBiasAsConstants, int64_t chunkStart, int64_t chunkNumPatches, std::optional forcedPackFactor = std::nullopt); static 
// Prepares the rank-4 [N, C, H, W] Conv input for im2col: when any spatial
// padding is set, H and W are zero-padded inside a single-result
// SpatCompute. Otherwise the original input value and type are returned as-is.
PreparedConvInput prepareInputForIm2Col(const ConvLoweringState& state, PatternRewriter& rewriter, Location loc) {
  if (state.padHeightBegin == 0 && state.padHeightEnd == 0 && state.padWidthBegin == 0 && state.padWidthEnd == 0)
    return {state.x, state.xType};
  auto paddedType = RankedTensorType::get({state.batchSize,
                                           state.numChannelsIn,
                                           state.xHeight + state.padHeightBegin + state.padHeightEnd,
                                           state.xWidth + state.padWidthBegin + state.padWidthEnd},
                                          state.xType.getElementType());
  auto paddedInputOp =
    createSpatCompute<1>(rewriter, loc, TypeRange {paddedType}, {}, state.x, [&](Value inputArg) {
      Value paddedInput = createZeroPaddedTensor(inputArg,
                                                 paddedType,
                                                 {0, 0, state.padHeightBegin, state.padWidthBegin},
                                                 {0, 0, state.padHeightEnd, state.padWidthEnd},
                                                 rewriter,
                                                 loc);
      spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
    });
  return {paddedInputOp.getResult(0), paddedType};
}

// Zero-pads the row dimension (dim 0) of a rank-2 tensor up to `paddedRows`,
// preserving the encoding. Returns `rows` unchanged when it already has that
// many rows.
static Value
createPaddedRows(Value rows, RankedTensorType rowsType, int64_t paddedRows, PatternRewriter& rewriter, Location loc) {
  if (rowsType.getDimSize(0) == paddedRows)
    return rows;
  auto paddedType =
    RankedTensorType::get({paddedRows, rowsType.getDimSize(1)}, rowsType.getElementType(), rowsType.getEncoding());
  return createZeroPaddedTensor(
    rows, paddedType, {0, 0}, {paddedRows - rowsType.getDimSize(0), 0}, rewriter, loc);
}

// Packs `packFactor` consecutive rows of a rank-2 [rows, width] tensor into
// one row of width packFactor * width. Rows are first zero-padded to a
// multiple of packFactor, then reshaped [R, W] -> [R/p, p, W] -> [R/p, p*W].
// packFactor == 1 is a no-op.
static Value packRowsForParallelGemm(
  Value rows, RankedTensorType rowsType, int64_t packFactor, PatternRewriter& rewriter, Location loc) {
  if (packFactor == 1)
    return rows;
  const int64_t paddedNumRows = ceilIntegerDivide(rowsType.getDimSize(0), packFactor) * packFactor;
  const int64_t packedNumRows = paddedNumRows / packFactor;
  const int64_t rowWidth = rowsType.getDimSize(1);
  auto groupedType =
    RankedTensorType::get({packedNumRows, packFactor, rowWidth}, rowsType.getElementType(), rowsType.getEncoding());
  auto packedType =
    RankedTensorType::get({packedNumRows, packFactor * rowWidth}, rowsType.getElementType(), rowsType.getEncoding());
  Value padded = createPaddedRows(rows, rowsType, paddedNumRows, rewriter, loc);
  Value grouped = tensor::ExpandShapeOp::create(rewriter, loc, groupedType, padded, SmallVector {
                                                  {0, 1},
                                                  {2}
  });
  return tensor::CollapseShapeOp::create(rewriter, loc, packedType, grouped, SmallVector {
                                           {0},
                                           {1, 2}
  });
}

// Inverse of packRowsForParallelGemm. Splits each packed row back into
// `packFactor` rows of width `rowWidth` and drops trailing padding rows so the
// result has `unpackedRows` rows. packFactor == 1 is a no-op.
static Value unpackRowsFromParallelGemm(Value packedRows,
                                        RankedTensorType packedRowsType,
                                        int64_t unpackedRows,
                                        int64_t rowWidth,
                                        int64_t packFactor,
                                        PatternRewriter& rewriter,
                                        Location loc) {
  if (packFactor == 1)
    return packedRows;
  const int64_t packedNumRows = packedRowsType.getDimSize(0);
  const int64_t paddedNumRows = packedNumRows * packFactor;
  auto expandedType = RankedTensorType::get(
    {packedNumRows, packFactor, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
  auto paddedType =
    RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
  auto unpackedType =
    RankedTensorType::get({unpackedRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
  Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, packedRows, SmallVector {
                                                   {0},
                                                   {1, 2}
  });
  Value padded = tensor::CollapseShapeOp::create(rewriter, loc, paddedType, expanded, SmallVector {
                                                   {0, 1},
                                                   {2}
  });
  if (paddedNumRows == unpackedRows)
    return padded;
  // Keep only the first `unpackedRows` rows (drop packing padding).
  SmallVector offsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
  SmallVector sizes {rewriter.getIndexAttr(unpackedRows), rewriter.getIndexAttr(rowWidth)};
  return tensor::ExtractSliceOp::create(rewriter, loc, unpackedType, padded, offsets, sizes, getUnitStrides(rewriter, 2));
}

// Builds the GEMM weight matrix from rank-4 Conv weights. It collapses dims
// {1, 2, 3} into plan.wFlatType, then transposes (perm [1, 0]) into
// plan.wTransType. When isCompileTimeComputable(weights) holds, the ops are
// emitted directly (presumably so they fold at compile time — confirm).
// Otherwise they are wrapped in a single-result SpatCompute.
static Value createWeightMatrix(
  Value weights, const ConvGemmPlan& plan, PatternRewriter& rewriter, Location loc) {
  auto buildWeightMatrix = [&](Value weight) -> Value {
    Value flattened = tensor::CollapseShapeOp::create(rewriter, loc, plan.wFlatType, weight, SmallVector {
                                                        {0},
                                                        {1, 2, 3}
    });
    return ONNXTransposeOp::create(rewriter, loc, plan.wTransType, flattened, rewriter.getI64ArrayAttr({1, 0}))
      .getResult();
  };
  if (isCompileTimeComputable(weights))
    return buildWeightMatrix(weights);
  auto computeOp =
    createSpatCompute<1>(rewriter, loc, TypeRange {plan.wTransType}, {}, ValueRange {weights}, [&](Value weight) {
      spatial::SpatYieldOp::create(rewriter, loc, buildWeightMatrix(weight));
    });
  return computeOp.getResult(0);
}

// Zero-pads a rank-2 matrix at the high end of both dimensions up to
// `paddedType`. Returns `matrix` unchanged when the types already match.
static Value createPaddedConvMatrix(
  Value matrix, RankedTensorType sourceType, RankedTensorType paddedType, PatternRewriter& rewriter, Location loc) {
  if (sourceType == paddedType)
    return matrix;
  return createZeroPaddedTensor(matrix,
                                paddedType,
                                {0, 0},
                                {paddedType.getDimSize(0) - sourceType.getDimSize(0),
                                 paddedType.getDimSize(1) - sourceType.getDimSize(1)},
                                rewriter,
                                loc);
}

// Host-side zero padding of a row-major rank-2 dense constant. Copies the
// [sourceRows x sourceCols] values into the top-left corner of a zero-filled
// `paddedType` constant. NOTE(review): does not check that paddedType is at
// least as large as sourceType in both dimensions.
static Value createPaddedConstantMatrix(DenseElementsAttr sourceAttr,
                                        RankedTensorType sourceType,
                                        RankedTensorType paddedType,
                                        PatternRewriter& rewriter) {
  SmallVector paddedValues(
    paddedType.getNumElements(), cast(rewriter.getZeroAttr(paddedType.getElementType())));
  SmallVector sourceValues(sourceAttr.getValues());
  const int64_t sourceRows = sourceType.getDimSize(0);
  const int64_t sourceCols = sourceType.getDimSize(1);
  const int64_t paddedCols = paddedType.getDimSize(1);
  for (int64_t row = 0; row < sourceRows; ++row)
    for (int64_t col = 0; col < sourceCols; ++col)
      paddedValues[row * paddedCols + col] = sourceValues[row * sourceCols + col];
  auto paddedAttr = DenseElementsAttr::get(paddedType, paddedValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), paddedAttr, paddedType);
}

// Builds the [paddedK, paddedC] weight constant for input-K tiling: the
// flattened, transposed weights with zero padding. Entry
// (row = (ic * wHeight + kh) * wWidth + kw, col = oc) = w[oc, ic, kh, kw].
// NOTE(review): assumes sourceAttr is laid out
// [numChannelsOut, numChannelsIn, wHeight, wWidth] (no grouping), and that
// numChannelsIn * wHeight * wWidth <= paddedK and numChannelsOut <= paddedC.
// None of this is checked here — confirm against callers.
static Value createPaddedInputKTiledWeightConstant(DenseElementsAttr sourceAttr,
                                                   const ConvLoweringState& state,
                                                   int64_t paddedK,
                                                   int64_t paddedC,
                                                   PatternRewriter& rewriter) {
  auto paddedType = RankedTensorType::get({paddedK, paddedC}, state.wType.getElementType());
  SmallVector sourceValues(sourceAttr.getValues());
  SmallVector paddedValues(
    paddedType.getNumElements(), cast(rewriter.getZeroAttr(paddedType.getElementType())));
  for (int64_t outChannel = 0; outChannel < state.numChannelsOut; ++outChannel) {
    for (int64_t inChannel = 0; inChannel < state.numChannelsIn; ++inChannel) {
      for (int64_t kernelH = 0; kernelH < state.wHeight; ++kernelH) {
        for (int64_t kernelW = 0; kernelW < state.wWidth; ++kernelW) {
          const int64_t sourceFlatIndex =
            (((outChannel * state.numChannelsIn) + inChannel) * state.wHeight + kernelH) * state.wWidth + kernelW;
          const int64_t patchIndex = ((inChannel * state.wHeight) + kernelH) * state.wWidth + kernelW;
          paddedValues[patchIndex * paddedC + outChannel] = sourceValues[sourceFlatIndex];
        }
      }
    }
  }
  auto paddedAttr = DenseElementsAttr::get(paddedType, paddedValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), paddedAttr, paddedType);
}

// rewriteInputKTiledConv: its body continues past the end of this block.
static FailureOr rewriteInputKTiledConv(const ConvLoweringState& state,
                                        ArrayRef distributedConsumers,
                                        PatternRewriter& rewriter,
                                        Location loc) {
  PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc);
  ConvGeometry geo = buildConvGeometry(state);
  const int64_t xbarDim = geo.xbarSize;
  const int64_t numKSlices = ceilIntegerDivide(geo.k, xbarDim);
  const int64_t paddedK = numKSlices * xbarDim;
  const uint64_t maxLanesPerBatch =
    std::max(1,
             static_cast(crossbarCountInCore.getValue()) / static_cast(std::max(1, numKSlices * 4)));
  const uint64_t rowChunkWidth = std::max(
    1,
    std::min({chooseStreamChunkPositions(geo, /*packFactor=*/1), maxLanesPerBatch, static_cast(state.outWidth)}));
  const auto elementType = state.outType.getElementType();
  auto wDenseAttr = getHostConstDenseElementsAttr(state.w);
  if (!wDenseAttr)
    return failure();
  Value paddedWeight = createPaddedInputKTiledWeightConstant(wDenseAttr, state, paddedK, xbarDim, rewriter);
  Value paddedBias;
  RankedTensorType paddedBiasType;
  if (state.hasBias) {
    Value biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc);
    auto biasMatrixType = cast(biasMatrix.getType());
    paddedBiasType = RankedTensorType::get({1, xbarDim}, elementType);
    if
(auto biasDenseAttr = getHostConstDenseElementsAttr(state.b)) paddedBias = createPaddedConstantMatrix(biasDenseAttr, biasMatrixType, paddedBiasType, rewriter); else paddedBias = materializeOrComputeUnary( biasMatrix, paddedBiasType, rewriter, loc, [&](Value biasValue) { return createPaddedConvMatrix(biasValue, biasMatrixType, paddedBiasType, rewriter, loc); }); } SmallVector chunkRows; const int64_t totalPatches = state.batchSize * state.outHeight * state.outWidth; chunkRows.reserve( state.batchSize * state.outHeight * ceilIntegerDivide(state.outWidth, static_cast(rowChunkWidth))); for (int64_t batchIndex = 0; batchIndex < state.batchSize; ++batchIndex) { for (int64_t outHeightIndex = 0; outHeightIndex < state.outHeight; ++outHeightIndex) { for (int64_t outWidthChunkStart = 0; outWidthChunkStart < state.outWidth; outWidthChunkStart += static_cast(rowChunkWidth)) { const int64_t chunkNumPatches = std::min(static_cast(rowChunkWidth), state.outWidth - outWidthChunkStart); auto chunkRowsType = RankedTensorType::get({chunkNumPatches, state.numChannelsOut}, elementType); auto paddedRowType = RankedTensorType::get({1, xbarDim}, elementType); auto paddedChunkRowType = RankedTensorType::get({1, paddedK}, elementType); auto patchType = RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elementType); auto collapsedPatchType = RankedTensorType::get({1, geo.k}, elementType); auto weightTileType = RankedTensorType::get({xbarDim, xbarDim}, state.wType.getElementType()); auto rowType = RankedTensorType::get({1, state.numChannelsOut}, elementType); SmallVector inputsStorage {preparedInput.value}; if (state.hasBias) inputsStorage.push_back(paddedBias); ValueRange inputs(inputsStorage); auto chunkCompute = spatial::SpatCompute::create(rewriter, loc, TypeRange {chunkRowsType}, ValueRange {paddedWeight}, inputs); auto* block = new Block(); block->addArgument(paddedWeight.getType(), loc); for (Value input : inputs) block->addArgument(input.getType(), loc); 
chunkCompute.getBody().push_back(block); rewriter.setInsertionPointToStart(block); auto buildChunk = [&]() -> LogicalResult { Value weightArg = block->getArgument(0); Value inputArg = block->getArgument(1); Value biasArg = state.hasBias ? block->getArgument(2) : Value(); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Value cBatchIndex = getOrCreateIndexConstant(rewriter, anchorOp, batchIndex); Value cZero = getOrCreateIndexConstant(rewriter, anchorOp, 0); Value cKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices); Value cOne = getOrCreateIndexConstant(rewriter, anchorOp, 1); Value cXbar = getOrCreateIndexConstant(rewriter, anchorOp, xbarDim); Value cInputHeightOffset = getOrCreateIndexConstant(rewriter, anchorOp, outHeightIndex * state.strideHeight); Value chunkRowsValue = tensor::EmptyOp::create(rewriter, loc, chunkRowsType.getShape(), elementType); auto widthLoop = buildNormalizedScfFor( rewriter, loc, cZero, getOrCreateIndexConstant(rewriter, anchorOp, chunkNumPatches), cOne, ValueRange {chunkRowsValue}, [&](OpBuilder&, Location nestedLoc, Value widthIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { Value laneWithChunkOffset = affineAddConst(rewriter, nestedLoc, widthIndex, outWidthChunkStart, anchorOp); Value inputWidthOffset = createOrFoldAffineApply(rewriter, nestedLoc, getAffineDimExpr(0, rewriter.getContext()) * state.strideWidth, ValueRange {laneWithChunkOffset}, anchorOp); Value patch = createConvInputPatch(inputArg, patchType, cBatchIndex, cZero, cInputHeightOffset, inputWidthOffset, state.dilationHeight, state.dilationWidth, rewriter, nestedLoc); Value patchRow = tensor::CollapseShapeOp::create(rewriter, nestedLoc, collapsedPatchType, patch, SmallVector { {0}, {1, 2, 3} }); Value paddedPatchRow = createZeroPaddedTensor( patchRow, paddedChunkRowType, {0, 0}, {0, paddedK - geo.k}, rewriter, nestedLoc); auto zeroAttr = DenseElementsAttr::get(paddedRowType, rewriter.getZeroAttr(elementType)); Value zeroRow = 
getOrCreateConstant(rewriter, anchorOp, zeroAttr, paddedRowType); auto kLoop = buildNormalizedScfFor( rewriter, nestedLoc, cZero, cKSlices, cOne, ValueRange {zeroRow}, [&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl& reduceYielded) { Value acc = reduceIterArgs.front(); Value kOffset = arith::MulIOp::create(rewriter, reduceLoc, kSlice, cXbar); SmallVector aOffsets {rewriter.getIndexAttr(0), kOffset}; SmallVector aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)}; SmallVector unitStrides = getUnitStrides(rewriter, 2); Value aTile = tensor::ExtractSliceOp::create( rewriter, reduceLoc, paddedRowType, paddedPatchRow, aOffsets, aSizes, unitStrides); SmallVector bOffsets {kOffset, rewriter.getIndexAttr(0)}; SmallVector bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)}; Value bTile = extractStaticSliceOrIdentity( rewriter, reduceLoc, weightArg, weightTileType, bOffsets, bSizes, unitStrides); Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult(); reduceYielded.push_back( spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, acc, piece).getResult()); return success(); }); if (failed(kLoop)) return failure(); Value reduced = kLoop->results.front(); if (state.hasBias) reduced = spatial::SpatVAddOp::create(rewriter, nestedLoc, paddedRowType, reduced, biasArg).getResult(); Value row = reduced; if (state.numChannelsOut != xbarDim) { SmallVector rowOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector rowSizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)}; row = tensor::ExtractSliceOp::create( rewriter, nestedLoc, rowType, reduced, rowOffsets, rowSizes, getUnitStrides(rewriter, 2)); } SmallVector outputOffsets {widthIndex, rewriter.getIndexAttr(0)}; SmallVector outputSizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)}; Value updatedRows = tensor::InsertSliceOp::create( 
rewriter, nestedLoc, row, iterArgs.front(), outputOffsets, outputSizes, getUnitStrides(rewriter, 2)); yielded.push_back(updatedRows); return success(); }); if (failed(widthLoop)) return failure(); spatial::SpatYieldOp::create(rewriter, loc, widthLoop->results.front()); return success(); }; if (failed(buildChunk())) { rewriter.setInsertionPointAfter(chunkCompute); rewriter.eraseOp(chunkCompute); return failure(); } rewriter.setInsertionPointAfter(chunkCompute); chunkRows.push_back(chunkCompute.getResult(0)); } } } auto nhwcType = RankedTensorType::get({state.batchSize, state.outHeight, state.outWidth, state.numChannelsOut}, elementType); return createCollectedConvOutput( chunkRows, state.outType, cast(chunkRows.front().getType()), nhwcType, state.outType, totalPatches, state.numChannelsOut, /*packFactor=*/1, distributedConsumers, rewriter, loc); } static Value buildPackedWeights(DenseElementsAttr wDenseAttr, Value wTrans, const ConvLoweringState& state, const ConvGemmPlan& plan, PatternRewriter& rewriter, Location loc) { if (plan.effectiveMaxParallelPixels == 1) return wTrans; auto packedWeightType = RankedTensorType::get( {plan.effectiveMaxParallelPixels * plan.patchSize, plan.effectiveMaxParallelPixels * state.numChannelsOut}, state.wType.getElementType()); SmallVector sourceValues(wDenseAttr.getValues()); SmallVector packedValues(packedWeightType.getNumElements(), cast(rewriter.getZeroAttr(state.wType.getElementType()))); for (int64_t copyId = 0; copyId < plan.effectiveMaxParallelPixels; ++copyId) { for (int64_t outChannel = 0; outChannel < state.numChannelsOut; ++outChannel) { for (int64_t inChannel = 0; inChannel < state.numChannelsIn; ++inChannel) { for (int64_t kernelH = 0; kernelH < state.wHeight; ++kernelH) { for (int64_t kernelW = 0; kernelW < state.wWidth; ++kernelW) { const int64_t sourceFlatIndex = (((outChannel * state.numChannelsIn) + inChannel) * state.wHeight + kernelH) * state.wWidth + kernelW; const int64_t patchIndex = ((inChannel * 
state.wHeight) + kernelH) * state.wWidth + kernelW; const int64_t targetRow = copyId * plan.patchSize + patchIndex; const int64_t targetCol = copyId * state.numChannelsOut + outChannel; packedValues[targetRow * packedWeightType.getDimSize(1) + targetCol] = sourceValues[sourceFlatIndex]; } } } } } auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType); } static Value buildPackedBias(Value gemmBias, Value biasMatrix, DenseElementsAttr biasDenseAttr, const ConvLoweringState& state, const ConvGemmPlan& plan, PatternRewriter& rewriter, Location loc) { if (!state.hasBias) return gemmBias; if (plan.effectiveMaxParallelPixels == 1) return biasMatrix; SmallVector sourceValues(biasDenseAttr.getValues()); SmallVector packedValues; packedValues.reserve(plan.effectiveMaxParallelPixels * state.numChannelsOut); for (int64_t copyId = 0; copyId < plan.effectiveMaxParallelPixels; ++copyId) packedValues.append(sourceValues.begin(), sourceValues.end()); auto packedBiasType = RankedTensorType::get({1, plan.effectiveMaxParallelPixels * state.numChannelsOut}, state.outType.getElementType()); auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedBiasAttr, packedBiasType); } static ConvGemmPlan buildConvGemmPlan(const ConvLoweringState& state, bool canPackWeightsAsConstants, bool canPackBiasAsConstants, int64_t chunkStart, int64_t chunkNumPatches, std::optional forcedPackFactor) { ConvGemmPlan plan; plan.patchSize = state.numChannelsIn * state.wHeight * state.wWidth; plan.numPatchesPerBatch = state.outHeight * state.outWidth; plan.globalNumPatches = state.batchSize * plan.numPatchesPerBatch; plan.chunkStart = chunkStart; plan.chunkNumPatches = chunkNumPatches; const int64_t wMaxDim = std::max(plan.patchSize, state.numChannelsOut); plan.maxParallelPixels 
= forcedPackFactor ? *forcedPackFactor : std::max(1, static_cast(crossbarSize.getValue()) / wMaxDim); plan.effectiveMaxParallelPixels = (canPackWeightsAsConstants && canPackBiasAsConstants) ? plan.maxParallelPixels : 1; plan.packedNumRows = ceilIntegerDivide(plan.chunkNumPatches, plan.effectiveMaxParallelPixels); auto elemType = state.xType.getElementType(); auto outElemType = state.outType.getElementType(); plan.im2colType = RankedTensorType::get({plan.chunkNumPatches, plan.patchSize}, elemType); plan.im2colRowType = RankedTensorType::get({1, plan.patchSize}, elemType); plan.gemmInputRowsType = RankedTensorType::get({plan.packedNumRows, plan.effectiveMaxParallelPixels * plan.patchSize}, elemType); plan.wFlatType = RankedTensorType::get({state.numChannelsOut, plan.patchSize}, state.wType.getElementType()); plan.wTransType = RankedTensorType::get({plan.patchSize, state.numChannelsOut}, state.wType.getElementType()); plan.gemmOutType = RankedTensorType::get({plan.chunkNumPatches, state.numChannelsOut}, outElemType); plan.gemmOutputRowsType = RankedTensorType::get({plan.packedNumRows, plan.effectiveMaxParallelPixels * state.numChannelsOut}, outElemType); plan.nhwcType = RankedTensorType::get({state.batchSize, state.outHeight, state.outWidth, state.numChannelsOut}, outElemType); return plan; } static Value createIm2colRows(const ConvLoweringState& state, const PreparedConvInput& preparedInput, const ConvGemmPlan& plan, PatternRewriter& rewriter, Location loc) { constexpr size_t numInputs = 1; auto im2colComputeOp = createSpatCompute(rewriter, loc, TypeRange {plan.gemmInputRowsType}, {}, preparedInput.value, [&](Value xArg) { auto elemType = preparedInput.type.getElementType(); // Keep the standard im2col view of convolution, flipped so filters sit in // B / crossbar columns: // A (im2col): [numPatches, patchSize] -- one row per output spatial position // B (weights): [patchSize, cOut] // Gemm output: [numPatches, cOut] Value im2colInit = 
tensor::EmptyOp::create(rewriter, loc, plan.im2colType.getShape(), elemType); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkNumPatches); auto im2colLoop = buildNormalizedScfFor( rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit}, [&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { Value im2colAcc = iterArgs.front(); Value batchIndex = affineAddFloorDivConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp); Value batchPatchIndex = affineAddModConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp); Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp); Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp); Value inputHeightOffset = affineMulConst(rewriter, nestedLoc, outHeightIndex, state.strideHeight, anchorOp); Value inputWidthOffset = affineMulConst(rewriter, nestedLoc, outWidthIndex, state.strideWidth, anchorOp); auto patchType = RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elemType); Value patch = createConvInputPatch(xArg, patchType, batchIndex, c0, inputHeightOffset, inputWidthOffset, state.dilationHeight, state.dilationWidth, rewriter, nestedLoc); Value row = tensor::CollapseShapeOp::create(rewriter, nestedLoc, plan.im2colRowType, patch, SmallVector { {0}, {1, 2, 3} }); SmallVector rowOffsets {patchIndex, rewriter.getIndexAttr(0)}; SmallVector rowSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(plan.patchSize)}; Value next = tensor::InsertSliceOp::create( rewriter, nestedLoc, row, im2colAcc, rowOffsets, rowSizes, getUnitStrides(rewriter, 2)); yielded.push_back(next); return 
success(); }); if (failed(im2colLoop)) return failure(); Value gemmInputRows = im2colLoop->results.front(); // Pack N old im2col rows into one longer row so one GEMM can cover N // pixels in parallel. The corresponding packed weight matrix contains N // block-diagonal copies of W^T, and the packed output must be unpacked // back to one row per spatial patch. if (plan.effectiveMaxParallelPixels != 1) gemmInputRows = packRowsForParallelGemm(gemmInputRows, plan.im2colType, plan.effectiveMaxParallelPixels, rewriter, loc); spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows); return success(); }); assert(succeeded(im2colComputeOp) && "Conv im2col compute construction must succeed"); return im2colComputeOp->getResult(0); } static Value maybeUnpackChunkRows(Value gemmRows, const ConvGemmPlan& plan, PatternRewriter& rewriter, Location loc) { if (plan.effectiveMaxParallelPixels == 1) return gemmRows; auto unpackedType = RankedTensorType::get( {plan.chunkNumPatches, plan.gemmOutType.getDimSize(1)}, plan.gemmOutType.getElementType(), plan.gemmOutType.getEncoding()); auto unpackCompute = createSpatCompute<1>(rewriter, loc, TypeRange {unpackedType}, {}, gemmRows, [&](Value rowsArg) { Value unpacked = unpackRowsFromParallelGemm(rowsArg, cast(rowsArg.getType()), plan.chunkNumPatches, plan.gemmOutType.getDimSize(1), plan.effectiveMaxParallelPixels, rewriter, loc); spatial::SpatYieldOp::create(rewriter, loc, unpacked); }); return unpackCompute.getResult(0); } static Value createChunkedConvRows(const ConvLoweringState& state, const PreparedConvInput& preparedInput, Value weightMatrix, Value biasMatrix, DenseElementsAttr wDenseAttr, DenseElementsAttr biasDenseAttr, int64_t forcedPackFactor, uint64_t chunkPositions, PatternRewriter& rewriter, Location loc) { SmallVector chunkRows; const int64_t totalPatches = state.batchSize * state.outHeight * state.outWidth; for (int64_t chunkStart = 0; chunkStart < totalPatches; chunkStart += static_cast(chunkPositions)) { const int64_t 
chunkNumPatches = std::min(static_cast(chunkPositions), totalPatches - chunkStart); ConvGemmPlan chunkPlan = buildConvGemmPlan(state, static_cast(wDenseAttr), !state.hasBias || static_cast(biasDenseAttr), chunkStart, chunkNumPatches, forcedPackFactor); Value chunkInputRows = createIm2colRows(state, preparedInput, chunkPlan, rewriter, loc); Value chunkB = buildPackedWeights(wDenseAttr, weightMatrix, state, chunkPlan, rewriter, loc); Value gemmBias = createZeroGemmBias(chunkPlan.gemmOutputRowsType, rewriter); if (state.hasBias) gemmBias = state.b; Value chunkC = buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, chunkPlan, rewriter, loc); Value chunkGemmRows = ONNXGemmOp::create(rewriter, loc, chunkPlan.gemmOutputRowsType, chunkInputRows, chunkB, chunkC, APFloat(1.0f), APFloat(1.0f), /*transA=*/0, /*transB=*/0) .getY(); chunkRows.push_back(maybeUnpackChunkRows(chunkGemmRows, chunkPlan, rewriter, loc)); } if (chunkRows.size() == 1) return chunkRows.front(); auto rowType = RankedTensorType::get({totalPatches, state.numChannelsOut}, state.outType.getElementType()); auto collectRows = createSpatCompute(rewriter, loc, TypeRange {rowType}, {}, chunkRows, [&](ValueRange rows) { spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, rows)); }); return collectRows.getResult(0); } static Value rewritePackedIm2ColConv(const ConvLoweringState& state, ArrayRef distributedConsumers, PatternRewriter& rewriter, Location loc) { auto wDenseAttr = getHostConstDenseElementsAttr(state.w); PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc); Value biasMatrix; DenseElementsAttr biasDenseAttr; if (state.hasBias) { biasDenseAttr = getHostConstDenseElementsAttr(state.b); biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc); } ConvGemmPlan plan = buildConvGemmPlan(state, static_cast(wDenseAttr), !state.hasBias || static_cast(biasDenseAttr), 0, state.batchSize * state.outHeight * state.outWidth); // Prepare weight matrix W 
for crossbar storage: // W: [Cout, Cin, KH, KW] -> [Cout, patchSize] -> [patchSize, Cout] Value weightMatrix = createWeightMatrix(state.w, plan, rewriter, loc); Value gemmInputRows = createIm2colRows(state, preparedInput, plan, rewriter, loc); Value gemmB = buildPackedWeights(wDenseAttr, weightMatrix, state, plan, rewriter, loc); Value gemmBias = createZeroGemmBias(plan.gemmOutputRowsType, rewriter); if (state.hasBias) gemmBias = state.b; Value gemmC = buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, plan, rewriter, loc); Value gemmRows = ONNXGemmOp::create(rewriter, loc, plan.gemmOutputRowsType, gemmInputRows, gemmB, gemmC, APFloat(1.0f), APFloat(1.0f), /*transA=*/0, /*transB=*/0) .getY(); return createCollectedConvOutput(ValueRange {gemmRows}, state.outType, plan.gemmOutType, plan.nhwcType, state.outType, plan.chunkNumPatches, state.numChannelsOut, plan.effectiveMaxParallelPixels, distributedConsumers, rewriter, loc); } static Value rewriteStreamedConv(const ConvLoweringState& state, ArrayRef distributedConsumers, PatternRewriter& rewriter, Location loc, int64_t forcedPackFactor) { auto wDenseAttr = getHostConstDenseElementsAttr(state.w); PreparedConvInput preparedInput = prepareInputForIm2Col(state, rewriter, loc); Value biasMatrix; DenseElementsAttr biasDenseAttr; if (state.hasBias) { biasDenseAttr = getHostConstDenseElementsAttr(state.b); biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc); } ConvGemmPlan seedPlan = buildConvGemmPlan( state, static_cast(wDenseAttr), !state.hasBias || static_cast(biasDenseAttr), 0, 1, forcedPackFactor); Value weightMatrix = createWeightMatrix(state.w, seedPlan, rewriter, loc); ConvGeometry geo = buildConvGeometry(state); uint64_t chunkPositions = chooseStreamChunkPositions(geo, forcedPackFactor); Value collectedRows = createChunkedConvRows(state, preparedInput, weightMatrix, biasMatrix, wDenseAttr, biasDenseAttr, forcedPackFactor, chunkPositions, rewriter, loc); auto gemmOutType = cast(collectedRows.getType()); 
auto nhwcType = RankedTensorType::get({state.batchSize, state.outHeight, state.outWidth, state.numChannelsOut}, state.outType.getElementType()); return createCollectedConvOutput( ValueRange {collectedRows}, state.outType, gemmOutType, nhwcType, state.outType, gemmOutType.getDimSize(0), state.numChannelsOut, /*packFactor=*/1, distributedConsumers, rewriter, loc); } } // namespace standard static RankedTensorType getRowStripFragmentType(RankedTensorType tensorType, int64_t width) { return RankedTensorType::get( {tensorType.getDimSize(0), tensorType.getDimSize(1), 1, width}, tensorType.getElementType(), tensorType.getEncoding()); } static SmallVector buildRowStripFragments(RankedTensorType tensorType) { SmallVector fragments; const int64_t height = tensorType.getDimSize(2); const int64_t width = tensorType.getDimSize(3); const int64_t channels = tensorType.getDimSize(1); fragments.reserve(height); for (int64_t row = 0; row < height; ++row) { fragments.push_back(DistributedFragmentInfo { {0, 0, row, 0}, {1, channels, 1, width}, {1, 1, 1, 1}, row, }); } return fragments; } static DistributedTensorInfo makeDistributedTensorInfo(Value storage, RankedTensorType logicalType) { DistributedTensorInfo info; info.storage = storage; info.logicalType = logicalType; info.fragments = buildRowStripFragments(logicalType); info.laneCount = logicalType.getDimSize(2); info.channels = logicalType.getDimSize(1); info.height = logicalType.getDimSize(2); info.width = logicalType.getDimSize(3); return info; } static Value createPerChannelConstantFragment(DenseElementsAttr denseAttr, RankedTensorType fragmentType, PatternRewriter& rewriter) { auto denseType = cast(denseAttr.getType()); SmallVector channelValues; channelValues.reserve(fragmentType.getDimSize(1)); SmallVector flattened(denseAttr.getValues()); if (denseType.getRank() == 1) { channelValues = flattened; } else if (denseType.getRank() == 2) { channelValues = flattened; } else { for (int64_t channel = 0; channel < 
denseType.getDimSize(1); ++channel) channelValues.push_back(flattened[channel]); } SmallVector values; values.reserve(fragmentType.getNumElements()); for (int64_t n = 0; n < fragmentType.getDimSize(0); ++n) for (int64_t channel = 0; channel < fragmentType.getDimSize(1); ++channel) for (int64_t h = 0; h < fragmentType.getDimSize(2); ++h) for (int64_t w = 0; w < fragmentType.getDimSize(3); ++w) values.push_back(channelValues[channel]); auto attr = DenseElementsAttr::get(fragmentType, values); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), attr, fragmentType); } static Value createZeroGemmBias(RankedTensorType gemmResultType, PatternRewriter& rewriter) { auto zeroAttr = DenseElementsAttr::get(gemmResultType, rewriter.getZeroAttr(gemmResultType.getElementType())); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), zeroAttr, gemmResultType); } static bool canDirectLowerRowStripConv(const ConvLoweringState& state, StringRef& failureReason) { if (!canConsumeRowStripHwcInput(state, failureReason)) return false; ConvGeometry geometry = buildConvGeometry(state); if (state.numChannelsOut > geometry.xbarSize) { failureReason = "unsupported_output_channels"; return false; } failureReason = ""; return true; } static FailureOr createRowStripPackedRows(Value rows, const ConvLoweringState& state, PatternRewriter& rewriter, Location loc) { auto rowsType = dyn_cast(rows.getType()); if (!rowsType || !rowsType.hasStaticShape() || rowsType.getRank() != 2) return failure(); if (state.batchSize != 1) return failure(); if (state.outType.getRank() != 4 || !state.outType.hasStaticShape()) return failure(); const int64_t outHeight = state.outType.getDimSize(2); const int64_t outWidth = state.outType.getDimSize(3); const int64_t outChannels = state.outType.getDimSize(1); if (rowsType.getDimSize(0) != outHeight * outWidth || rowsType.getDimSize(1) != outChannels) return failure(); auto packedType = 
RankedTensorType::get({outHeight, outWidth, outChannels}, rowsType.getElementType(), rowsType.getEncoding()); auto packedRows = createSpatCompute<1>(rewriter, loc, TypeRange {packedType}, {}, rows, [&](Value rowValues) { Value packed = tensor::ExpandShapeOp::create( rewriter, loc, packedType, rowValues, SmallVector {{0, 1}, {2}}); spatial::SpatYieldOp::create(rewriter, loc, packed); }); return packedRows.getResult(0); } static FailureOr createConvOutputFromRowStripHwc(Value inputHwc, const ConvLoweringState& state, PatternRewriter& rewriter, Location loc) { auto inputType = dyn_cast(inputHwc.getType()); if (!inputType || !inputType.hasStaticShape() || inputType.getRank() != 3) return failure(); if (inputType.getDimSize(0) != state.xHeight || inputType.getDimSize(1) != state.xWidth || inputType.getDimSize(2) != state.numChannelsIn) return failure(); StringRef failureReason; if (!canDirectLowerRowStripConv(state, failureReason)) return failure(); ConvRowDemand demand = buildConvRowDemand(RowInterval {0, state.outHeight}, state); if (!covers(demand.acquiredInputRows, demand.neededInputRows)) return failure(); ConvGeometry geometry = buildConvGeometry(state); const int64_t xbarDim = geometry.xbarSize; const int64_t patchSize = state.numChannelsIn * state.wHeight * state.wWidth; const int64_t numKSlices = ceilIntegerDivide(patchSize, xbarDim); const int64_t paddedK = numKSlices * xbarDim; auto elementType = inputType.getElementType(); auto paddedInputType = RankedTensorType::get({state.xHeight + state.padHeightBegin + state.padHeightEnd, state.xWidth + state.padWidthBegin + state.padWidthEnd, state.numChannelsIn}, elementType, inputType.getEncoding()); auto paddedPatchType = RankedTensorType::get({state.wHeight, state.wWidth, 1}, elementType, inputType.getEncoding()); auto flatPatchType = RankedTensorType::get({state.wHeight * state.wWidth}, elementType, inputType.getEncoding()); auto rowChunkType = RankedTensorType::get({1, state.wHeight * state.wWidth}, elementType, 
inputType.getEncoding()); auto rowType = RankedTensorType::get({1, state.numChannelsOut}, state.outType.getElementType()); auto packedOutputType = RankedTensorType::get({state.outHeight, state.outWidth, state.numChannelsOut}, state.outType.getElementType()); auto packedOutputSliceType = RankedTensorType::get({1, 1, state.numChannelsOut}, state.outType.getElementType()); auto paddedRowType = RankedTensorType::get({1, xbarDim}, state.outType.getElementType()); auto paddedPatchRowType = RankedTensorType::get({1, paddedK}, elementType, inputType.getEncoding()); auto paddedWeightTileType = RankedTensorType::get({xbarDim, xbarDim}, state.wType.getElementType()); auto weightDenseAttr = getHostConstDenseElementsAttr(state.w); if (!weightDenseAttr) return failure(); Value paddedWeights = standard::createPaddedInputKTiledWeightConstant(weightDenseAttr, state, paddedK, xbarDim, rewriter); Value paddedBias; if (state.hasBias) { Value biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc); auto biasMatrixType = cast(biasMatrix.getType()); auto paddedBiasType = RankedTensorType::get({1, xbarDim}, state.outType.getElementType()); if (auto biasDenseAttr = getHostConstDenseElementsAttr(state.b)) paddedBias = standard::createPaddedConstantMatrix(biasDenseAttr, biasMatrixType, paddedBiasType, rewriter); else paddedBias = materializeOrComputeUnary( biasMatrix, paddedBiasType, rewriter, loc, [&](Value biasValue) { return standard::createPaddedConvMatrix(biasValue, biasMatrixType, paddedBiasType, rewriter, loc); }); } auto paddedInputOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, inputHwc, [&](Value hwcInputArg) { Value paddedInput = createZeroPaddedTensor(hwcInputArg, paddedInputType, {state.padHeightBegin, state.padWidthBegin, 0}, {state.padHeightEnd, state.padWidthEnd, 0}, rewriter, loc); spatial::SpatYieldOp::create(rewriter, loc, paddedInput); }); SmallVector batchInputs {paddedInputOp.getResult(0)}; if (state.hasBias) batchInputs.push_back(paddedBias); 
auto batchOp = createSpatComputeBatch( rewriter, loc, TypeRange {packedOutputType}, state.outHeight, ValueRange {paddedWeights}, batchInputs, [&](detail::SpatComputeBatchBodyArgs args) { Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); Value cNumKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices); Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth); Value cNumChannels = getOrCreateIndexConstant(rewriter, anchorOp, state.numChannelsIn); Value localHeightOffset = args.lane; Value packedRowInit = tensor::EmptyOp::create(rewriter, loc, ArrayRef {1, state.outWidth, state.numChannelsOut}, elementType); auto widthLoop = buildNormalizedScfFor( rewriter, loc, c0, cOutWidth, c1, ValueRange {packedRowInit}, [&](OpBuilder&, Location widthLoc, Value widthIndex, ValueRange widthIterArgs, SmallVectorImpl& widthYielded) { Value localWidthOffset = widthIndex; Value rowInit = tensor::EmptyOp::create(rewriter, widthLoc, ArrayRef {1, patchSize}, elementType); auto rowLoop = buildNormalizedScfFor( rewriter, widthLoc, c0, cNumChannels, c1, ValueRange {rowInit}, [&](OpBuilder&, Location rowLoc, Value channel, ValueRange rowIterArgs, SmallVectorImpl& rowYielded) { SmallVector patchOffsets {localHeightOffset, localWidthOffset, channel}; SmallVector patchSizes { rewriter.getIndexAttr(state.wHeight), rewriter.getIndexAttr(state.wWidth), rewriter.getIndexAttr(1)}; Value channelPatch = tensor::ExtractSliceOp::create( rewriter, rowLoc, paddedPatchType, args.inputs.front(), patchOffsets, patchSizes, getUnitStrides(rewriter, 3)); Value flatPatch = tensor::CollapseShapeOp::create( rewriter, rowLoc, flatPatchType, channelPatch, SmallVector {{0, 1, 2}}); Value rowChunk = tensor::ExpandShapeOp::create( rewriter, rowLoc, rowChunkType, flatPatch, SmallVector {{0, 1}}); Value flatOffset = affineMulConst( rewriter, rowLoc, 
channel, state.wHeight * state.wWidth, anchorOp); SmallVector rowOffsets {rewriter.getIndexAttr(0), flatOffset}; SmallVector rowSizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.wHeight * state.wWidth)}; Value nextRow = tensor::InsertSliceOp::create( rewriter, rowLoc, rowChunk, rowIterArgs.front(), rowOffsets, rowSizes, getUnitStrides(rewriter, 2)); rowYielded.push_back(nextRow); return success(); }); if (failed(rowLoop)) return failure(); Value paddedRow = rowLoop->results.front(); if (patchSize != paddedK) paddedRow = createZeroPaddedTensor( paddedRow, paddedPatchRowType, {0, 0}, {0, paddedK - patchSize}, rewriter, widthLoc); auto zeroAttr = DenseElementsAttr::get(paddedRowType, rewriter.getZeroAttr(state.outType.getElementType())); Value zeroRow = getOrCreateConstant(rewriter, anchorOp, zeroAttr, paddedRowType); auto kLoop = buildNormalizedScfFor( rewriter, widthLoc, c0, cNumKSlices, c1, ValueRange {zeroRow}, [&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl& reduceYielded) { Value kOffset = affineMulConst(rewriter, reduceLoc, kSlice, xbarDim, anchorOp); SmallVector aOffsets {rewriter.getIndexAttr(0), kOffset}; SmallVector aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)}; Value aTile = tensor::ExtractSliceOp::create( rewriter, reduceLoc, paddedRowType, paddedRow, aOffsets, aSizes, getUnitStrides(rewriter, 2)); SmallVector bOffsets {kOffset, rewriter.getIndexAttr(0)}; SmallVector bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)}; Value bTile = extractStaticSliceOrIdentity(rewriter, reduceLoc, args.weights.front(), paddedWeightTileType, bOffsets, bSizes, getUnitStrides(rewriter, 2)); Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult(); reduceYielded.push_back( spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, reduceIterArgs.front(), piece).getResult()); return success(); }); if (failed(kLoop)) return 
failure(); Value rowResult = kLoop->results.front(); if (state.hasBias) rowResult = spatial::SpatVAddOp::create(rewriter, widthLoc, paddedRowType, rowResult, args.inputs[1]).getResult(); Value outputRow = rowResult; if (state.numChannelsOut != xbarDim) { SmallVector outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector outputSizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)}; outputRow = tensor::ExtractSliceOp::create( rewriter, widthLoc, rowType, rowResult, outputOffsets, outputSizes, getUnitStrides(rewriter, 2)); } Value outputFragment = tensor::ExpandShapeOp::create(rewriter, widthLoc, packedOutputSliceType, outputRow, SmallVector {{0}, {1, 2}}); SmallVector rowOffsets {rewriter.getIndexAttr(0), widthIndex, rewriter.getIndexAttr(0)}; SmallVector rowSizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.numChannelsOut)}; Value nextPackedRow = tensor::InsertSliceOp::create( rewriter, widthLoc, outputFragment, widthIterArgs.front(), rowOffsets, rowSizes, getUnitStrides(rewriter, 3)); widthYielded.push_back(nextPackedRow); return success(); }); if (failed(widthLoop)) return failure(); SmallVector batchOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector batchSizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.outWidth), rewriter.getIndexAttr(state.numChannelsOut)}; createParallelInsertSliceIntoBatchOutput( rewriter, loc, widthLoop->results.front(), args.outputs.front(), batchOffsets, batchSizes, getUnitStrides(rewriter, 3)); return success(); }); if (failed(batchOp)) return failure(); return batchOp->getResult(0); } static FailureOr createConvRowsFromRowStripInput(const ConvLoweringState& state, [[maybe_unused]] const ConvLoweringDecision& decision, Value rowStripInput, PatternRewriter& rewriter, Location loc) { return createConvOutputFromRowStripHwc(rowStripInput, state, rewriter, loc); } static Value createFragmentConstant(const 
DistributedTensorStep& step, RankedTensorType fragmentType, PatternRewriter& rewriter) { if (step.constantKind == DistributedTensorConstantKind::PerChannel) return createPerChannelConstantFragment(step.constantAttr, fragmentType, rewriter); Attribute splatValue = step.constantAttr.getSplatValue(); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), DenseElementsAttr::get(fragmentType, splatValue), fragmentType); } static Value createFragmentReciprocalConstant(const DistributedTensorStep& step, RankedTensorType fragmentType, PatternRewriter& rewriter) { SmallVector values; if (step.constantKind == DistributedTensorConstantKind::PerChannel) { auto denseType = cast(step.constantAttr.getType()); SmallVector channelValues; for (const APFloat& value : step.constantAttr.getValues()) channelValues.push_back(value); values.reserve(fragmentType.getNumElements()); for (int64_t n = 0; n < fragmentType.getDimSize(0); ++n) for (int64_t channel = 0; channel < fragmentType.getDimSize(1); ++channel) for (int64_t h = 0; h < fragmentType.getDimSize(2); ++h) for (int64_t w = 0; w < fragmentType.getDimSize(3); ++w) { APFloat reciprocal = channelValues[channel]; APFloat one(reciprocal.getSemantics(), 1); [[maybe_unused]] APFloat::opStatus status = one.divide(reciprocal, APFloat::rmNearestTiesToEven); assert(!(status & APFloat::opInvalidOp) && "distributed conv div requires finite non-zero constant"); values.push_back(one); } (void)denseType; } else { APFloat reciprocal = cast(step.constantAttr).getSplatValue(); APFloat one(reciprocal.getSemantics(), 1); [[maybe_unused]] APFloat::opStatus status = one.divide(reciprocal, APFloat::rmNearestTiesToEven); assert(!(status & APFloat::opInvalidOp) && "distributed conv div requires finite non-zero constant"); values.assign(fragmentType.getNumElements(), one); } return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), DenseFPElementsAttr::get(fragmentType, values), fragmentType); } 
[[maybe_unused]] static FailureOr createConvRowsForStrategy(const ConvLoweringState& state, const ConvLoweringDecision& decision, PatternRewriter& rewriter, Location loc) { auto wDenseAttr = getHostConstDenseElementsAttr(state.w); PreparedConvInput preparedInput = standard::prepareInputForIm2Col(state, rewriter, loc); Value biasMatrix; DenseElementsAttr biasDenseAttr; if (state.hasBias) { biasDenseAttr = getHostConstDenseElementsAttr(state.b); biasMatrix = expandBiasIfNeeded(state.b, rewriter, loc); } switch (decision.strategy) { case PimConvLoweringLegacy: case PimConvLoweringPackedIm2Col: { standard::ConvGemmPlan plan = standard::buildConvGemmPlan( state, static_cast(wDenseAttr), !state.hasBias || static_cast(biasDenseAttr), 0, state.batchSize * state.outHeight * state.outWidth); Value weightMatrix = standard::createWeightMatrix(state.w, plan, rewriter, loc); Value gemmInputRows = standard::createIm2colRows(state, preparedInput, plan, rewriter, loc); Value gemmB = standard::buildPackedWeights(wDenseAttr, weightMatrix, state, plan, rewriter, loc); Value gemmBias = createZeroGemmBias(plan.gemmOutputRowsType, rewriter); if (state.hasBias) gemmBias = state.b; Value gemmC = standard::buildPackedBias(gemmBias, biasMatrix, biasDenseAttr, state, plan, rewriter, loc); Value gemmRows = ONNXGemmOp::create(rewriter, loc, plan.gemmOutputRowsType, gemmInputRows, gemmB, gemmC, APFloat(1.0f), APFloat(1.0f), /*transA=*/0, /*transB=*/0) .getY(); return standard::maybeUnpackChunkRows(gemmRows, plan, rewriter, loc); } case PimConvLoweringStreamedPatch: case PimConvLoweringOutputChannelTiled: case PimConvLoweringTiled2D: case PimConvLoweringStreamedPacked: { standard::ConvGemmPlan seedPlan = standard::buildConvGemmPlan( state, static_cast(wDenseAttr), !state.hasBias || static_cast(biasDenseAttr), 0, 1, decision.strategy == PimConvLoweringStreamedPacked ? 
buildConvGeometry(state).pack : 1); Value weightMatrix = standard::createWeightMatrix(state.w, seedPlan, rewriter, loc); ConvGeometry geo = buildConvGeometry(state); int64_t packFactor = decision.strategy == PimConvLoweringStreamedPacked ? geo.pack : 1; uint64_t chunkPositions = chooseStreamChunkPositions(geo, packFactor); return standard::createChunkedConvRows(state, preparedInput, weightMatrix, biasMatrix, wDenseAttr, biasDenseAttr, packFactor, chunkPositions, rewriter, loc); } default: return failure(); } } [[maybe_unused]] static FailureOr createDistributedTensorFromRows(Value rows, RankedTensorType logicalType, PatternRewriter& rewriter, Location loc) { const int64_t width = logicalType.getDimSize(3); const int64_t height = logicalType.getDimSize(2); auto rowsType = cast(rows.getType()); auto rowSliceType = RankedTensorType::get({width, logicalType.getDimSize(1)}, logicalType.getElementType(), rowsType.getEncoding()); auto channelWidthType = RankedTensorType::get({logicalType.getDimSize(1), width}, logicalType.getElementType(), rowsType.getEncoding()); auto fragmentType = getRowStripFragmentType(logicalType, width); auto batchOp = createSpatComputeBatch( rewriter, loc, TypeRange {logicalType}, height, {}, ValueRange {rows}, [&](detail::SpatComputeBatchBodyArgs args) { Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); Value rowStart = affineMulConst(rewriter, loc, args.lane, width, anchorOp); SmallVector rowOffsets {rowStart, rewriter.getIndexAttr(0)}; SmallVector rowSizes {rewriter.getIndexAttr(width), rewriter.getIndexAttr(logicalType.getDimSize(1))}; Value rowSlice = tensor::ExtractSliceOp::create( rewriter, loc, rowSliceType, args.inputs.front(), rowOffsets, rowSizes, getUnitStrides(rewriter, 2)); Value channelWidth = ONNXTransposeOp::create( rewriter, loc, channelWidthType, rowSlice, rewriter.getI64ArrayAttr({1, 0})).getResult(); Value fragment = tensor::ExpandShapeOp::create(rewriter, loc, fragmentType, channelWidth, SmallVector {{0, 1}, 
{2, 3}}); SmallVector outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)}; SmallVector outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(logicalType.getDimSize(1)), rewriter.getIndexAttr(1), rewriter.getIndexAttr(width)}; createParallelInsertSliceIntoBatchOutput( rewriter, loc, fragment, args.outputs.front(), outputOffsets, outputSizes, getUnitStrides(rewriter, 4)); return success(); }); if (failed(batchOp)) return failure(); return makeDistributedTensorInfo(batchOp->getResult(0), logicalType); } [[maybe_unused]] static FailureOr applyDistributedPreservingStep(const DistributedTensorInfo& inputInfo, const DistributedTensorStep& step, PatternRewriter& rewriter, Location loc) { auto logicalType = inputInfo.logicalType; const int64_t width = logicalType.getDimSize(3); auto fragmentType = getRowStripFragmentType(logicalType, width); auto batchOp = createSpatComputeBatch(rewriter, loc, TypeRange {logicalType}, inputInfo.laneCount, {}, ValueRange {inputInfo.storage}, [&](detail::SpatComputeBatchBodyArgs args) { SmallVector offsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)}; SmallVector sizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(logicalType.getDimSize(1)), rewriter.getIndexAttr(1), rewriter.getIndexAttr(width)}; Value fragment = tensor::ExtractSliceOp::create( rewriter, loc, fragmentType, args.inputs.front(), offsets, sizes, getUnitStrides(rewriter, 4)); switch (step.kind) { case DistributedTensorOpKind::Relu: fragment = spatial::SpatReluOp::create(rewriter, loc, fragmentType, fragment).getResult(); break; case DistributedTensorOpKind::Sigmoid: fragment = spatial::SpatSigmoidOp::create(rewriter, loc, fragmentType, fragment).getResult(); break; case DistributedTensorOpKind::Add: { Value constant = createFragmentConstant(step, fragmentType, rewriter); fragment = spatial::SpatVAddOp::create(rewriter, loc, fragmentType, fragment, constant).getResult(); 
break; } case DistributedTensorOpKind::Sub: { Value constant = createFragmentConstant(step, fragmentType, rewriter); Value lhs = step.fragmentOnLhs ? fragment : constant; Value rhs = step.fragmentOnLhs ? constant : fragment; fragment = spatial::SpatVSubOp::create(rewriter, loc, fragmentType, lhs, rhs).getResult(); break; } case DistributedTensorOpKind::Mul: { Value constant = createFragmentConstant(step, fragmentType, rewriter); fragment = spatial::SpatVMulOp::create(rewriter, loc, fragmentType, fragment, constant).getResult(); break; } case DistributedTensorOpKind::Div: { Value constant = createFragmentReciprocalConstant(step, fragmentType, rewriter); fragment = spatial::SpatVMulOp::create(rewriter, loc, fragmentType, fragment, constant).getResult(); break; } case DistributedTensorOpKind::Conv: return failure(); } createParallelInsertSliceIntoBatchOutput( rewriter, loc, fragment, args.outputs.front(), offsets, sizes, getUnitStrides(rewriter, 4)); return success(); }); if (failed(batchOp)) return failure(); return makeDistributedTensorInfo(batchOp->getResult(0), logicalType); } static Value createCollectedConvOutput(ValueRange gemmRows, Type convType, RankedTensorType gemmOutType, RankedTensorType nhwcType, RankedTensorType outType, int64_t numPatches, int64_t numChannelsOut, int64_t packFactor, ArrayRef distributedConsumers, PatternRewriter& rewriter, Location loc) { auto materializeSplatTensor = [&](DenseElementsAttr denseAttr, RankedTensorType targetType) { Attribute splatValue = denseAttr.getSplatValue(); auto targetAttr = DenseElementsAttr::get(targetType, splatValue); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), targetAttr, targetType); }; auto materializeReciprocalSplatTensor = [&](DenseFPElementsAttr denseAttr, RankedTensorType targetType) { APFloat reciprocal = denseAttr.getSplatValue(); APFloat one(reciprocal.getSemantics(), 1); [[maybe_unused]] APFloat::opStatus status = one.divide(reciprocal, 
APFloat::rmNearestTiesToEven); assert(!(status & APFloat::opInvalidOp) && "distributed conv div consumer requires finite non-zero scalar"); return getOrCreateConstant( rewriter, rewriter.getInsertionBlock()->getParentOp(), DenseFPElementsAttr::get(targetType, one), targetType); }; auto applyDistributedConsumers = [&](Value fragment) { Value current = fragment; for (const DistributedTensorStep& step : distributedConsumers) { auto fragmentType = cast(current.getType()); switch (step.kind) { case DistributedTensorOpKind::Relu: current = spatial::SpatReluOp::create(rewriter, loc, fragmentType, current).getResult(); break; case DistributedTensorOpKind::Sigmoid: current = spatial::SpatSigmoidOp::create(rewriter, loc, fragmentType, current).getResult(); break; case DistributedTensorOpKind::Add: { Value splat = materializeSplatTensor(step.constantAttr, fragmentType); current = spatial::SpatVAddOp::create(rewriter, loc, fragmentType, current, splat).getResult(); break; } case DistributedTensorOpKind::Sub: { Value splat = materializeSplatTensor(step.constantAttr, fragmentType); Value lhs = step.fragmentOnLhs ? current : splat; Value rhs = step.fragmentOnLhs ? 
splat : current; current = spatial::SpatVSubOp::create(rewriter, loc, fragmentType, lhs, rhs).getResult(); break; } case DistributedTensorOpKind::Mul: { Value splat = materializeSplatTensor(step.constantAttr, fragmentType); current = spatial::SpatVMulOp::create(rewriter, loc, fragmentType, current, splat).getResult(); break; } case DistributedTensorOpKind::Div: { auto reciprocalAttr = cast(step.constantAttr); Value reciprocal = materializeReciprocalSplatTensor(reciprocalAttr, fragmentType); current = spatial::SpatVMulOp::create(rewriter, loc, fragmentType, current, reciprocal).getResult(); break; } case DistributedTensorOpKind::Conv: llvm_unreachable("conv-consuming distributed chains should not materialize through createCollectedConvOutput"); } } return current; }; auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) { SmallVector transformedRows; transformedRows.reserve(gemmRowArgs.size()); for (Value row : gemmRowArgs) transformedRows.push_back(applyDistributedConsumers(row)); Value gemmOut; if (packFactor == 1) { gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, transformedRows); } else { Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, transformedRows); gemmOut = standard::unpackRowsFromParallelGemm( packedOutput, cast(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc); } // Restore output layout: // [numPatches, numChannelsOut] // -> [N, Hout, Wout, Cout] // -> [N, Cout, Hout, Wout] Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, loc, nhwcType, gemmOut, SmallVector { {0, 1, 2}, {3} }); Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2})); spatial::SpatYieldOp::create(rewriter, loc, nchwOut); }); return collectComputeOp.getResult(0); } static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b) { ConvLoweringState state; state.x = x; state.w = w; state.b = b; 
state.xType = cast(state.x.getType()); state.wType = cast(state.w.getType()); state.outType = cast(convOp.getY().getType()); if (!state.xType.hasStaticShape()) { pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input"); return failure(); } if (!state.wType.hasStaticShape()) { pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight"); return failure(); } if (!state.outType.hasStaticShape()) { pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result"); return failure(); } if (state.xType.getRank() != 4) { pim::emitUnsupportedRankDiagnostic(convOp, "conv input", state.xType.getRank(), {4}); return failure(); } if (state.wType.getRank() != 4) { pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", state.wType.getRank(), {4}); return failure(); } if (state.outType.getRank() != 4) { pim::emitUnsupportedRankDiagnostic(convOp, "conv result", state.outType.getRank(), {4}); return failure(); } state.group = convOp.getGroup(); if (state.group < 1) { convOp.emitOpError("requires group >= 1 for Spatial lowering"); return failure(); } state.batchSize = state.xType.getDimSize(0); state.numChannelsIn = state.xType.getDimSize(1); state.xHeight = state.xType.getDimSize(2); state.xWidth = state.xType.getDimSize(3); state.numChannelsOut = state.wType.getDimSize(0); state.wHeight = state.wType.getDimSize(2); state.wWidth = state.wType.getDimSize(3); state.outHeight = state.outType.getDimSize(2); state.outWidth = state.outType.getDimSize(3); state.hasBias = state.b && !isa(state.b.getDefiningOp()); if (state.numChannelsIn % state.group != 0) { convOp.emitOpError() << "requires input channels " << state.numChannelsIn << " to be divisible by group " << state.group << " for Spatial lowering"; return failure(); } if (state.numChannelsOut % state.group != 0) { convOp.emitOpError() << "requires output channels " << state.numChannelsOut << " to be divisible by group " << state.group << " for Spatial lowering"; return failure(); } state.numChannelsInPerGroup = 
state.numChannelsIn / state.group; state.numChannelsOutPerGroup = state.numChannelsOut / state.group; if (state.wType.getDimSize(1) != state.numChannelsInPerGroup) { convOp.emitOpError() << "requires grouped conv weight input channels " << state.wType.getDimSize(1) << " to match input channels per group " << state.numChannelsInPerGroup << " for Spatial lowering"; return failure(); } if (state.wType.getDimSize(0) != state.numChannelsOut) { convOp.emitOpError() << "requires weight output channels " << state.wType.getDimSize(0) << " to match result channels " << state.numChannelsOut << " for Spatial lowering"; return failure(); } const auto stridesAttr = convOp.getStrides(); const auto dilationsAttr = convOp.getDilations(); const auto padsAttr = convOp.getPads(); if (stridesAttr && stridesAttr->size() != 2) { convOp.emitOpError("requires exactly two stride values for Spatial lowering"); return failure(); } if (dilationsAttr && dilationsAttr->size() != 2) { convOp.emitOpError("requires exactly two dilation values for Spatial lowering"); return failure(); } if (padsAttr && padsAttr->size() != 4) { convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering"); return failure(); } state.strideHeight = getOptionalI64Attr(stridesAttr, 0, 1); state.strideWidth = getOptionalI64Attr(stridesAttr, 1, 1); state.dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1); state.dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1); state.padHeightBegin = 0; state.padHeightEnd = 0; state.padWidthBegin = 0; state.padWidthEnd = 0; if (padsAttr) { state.padHeightBegin = getI64Attr(*padsAttr, 0); state.padWidthBegin = getI64Attr(*padsAttr, 1); state.padHeightEnd = getI64Attr(*padsAttr, 2); state.padWidthEnd = getI64Attr(*padsAttr, 3); return state; } const auto autoPad = convOp.getAutoPad(); if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { const int64_t effectiveKernelH = (state.wHeight - 1) * state.dilationHeight + 1; const int64_t effectiveKernelW = 
(state.wWidth - 1) * state.dilationWidth + 1; const int64_t totalPadH = std::max(static_cast(0), (state.outHeight - 1) * state.strideHeight + effectiveKernelH - state.xHeight); const int64_t totalPadW = std::max(static_cast(0), (state.outWidth - 1) * state.strideWidth + effectiveKernelW - state.xWidth); if (autoPad == "SAME_UPPER") { state.padHeightBegin = totalPadH / 2; state.padHeightEnd = totalPadH - state.padHeightBegin; state.padWidthBegin = totalPadW / 2; state.padWidthEnd = totalPadW - state.padWidthBegin; } else { state.padHeightEnd = totalPadH / 2; state.padHeightBegin = totalPadH - state.padHeightEnd; state.padWidthEnd = totalPadW / 2; state.padWidthBegin = totalPadW - state.padWidthEnd; } return state; } if (autoPad != "NOTSET" && autoPad != "VALID") { convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering"; return failure(); } return state; } static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, ONNXConvOpAdaptor convOpAdaptor) { return analyzeConvLoweringState(convOp, convOpAdaptor.getX(), convOpAdaptor.getW(), convOpAdaptor.getB()); } static FailureOr analyzeConvLoweringState(spatial::SpatConv2DPlanOp planOp) { ConvLoweringState state; state.x = planOp.getInput(); state.w = planOp.getWeight(); state.b = planOp.getBias() ? 
planOp.getBias() : Value(); state.xType = dyn_cast(state.x.getType()); state.wType = dyn_cast(state.w.getType()); state.outType = dyn_cast(planOp.getOutput().getType()); if (!state.xType || !state.wType || !state.outType) return planOp.emitOpError("requires ranked tensor input, weight, and output"), failure(); if (!state.xType.hasStaticShape() || !state.wType.hasStaticShape() || !state.outType.hasStaticShape()) return planOp.emitOpError("requires static input, weight, and output shapes"), failure(); if (state.xType.getRank() != 4 || state.wType.getRank() != 4 || state.outType.getRank() != 4) return planOp.emitOpError("requires rank-4 input, weight, and output tensors"), failure(); state.group = planOp.getGroup(); if (state.group < 1) return planOp.emitOpError("requires group >= 1"), failure(); state.batchSize = state.xType.getDimSize(0); state.numChannelsIn = state.xType.getDimSize(1); state.xHeight = state.xType.getDimSize(2); state.xWidth = state.xType.getDimSize(3); state.numChannelsOut = state.wType.getDimSize(0); state.wHeight = state.wType.getDimSize(2); state.wWidth = state.wType.getDimSize(3); state.outHeight = state.outType.getDimSize(2); state.outWidth = state.outType.getDimSize(3); state.hasBias = static_cast(planOp.getBias()); if (state.numChannelsIn % state.group != 0 || state.numChannelsOut % state.group != 0) return planOp.emitOpError("requires input and output channels divisible by group"), failure(); state.numChannelsInPerGroup = state.numChannelsIn / state.group; state.numChannelsOutPerGroup = state.numChannelsOut / state.group; if (state.wType.getDimSize(1) != state.numChannelsInPerGroup) return planOp.emitOpError("requires grouped conv weight channels to match input channels per group"), failure(); auto pads = planOp.getPads(); auto strides = planOp.getStrides(); auto dilations = planOp.getDilations(); if (pads.size() != 4 || strides.size() != 2 || dilations.size() != 2) return planOp.emitOpError("requires 4 pads, 2 strides, and 2 dilations"), 
failure(); state.padHeightBegin = pads[0]; state.padWidthBegin = pads[1]; state.padHeightEnd = pads[2]; state.padWidthEnd = pads[3]; state.strideHeight = strides[0]; state.strideWidth = strides[1]; state.dilationHeight = dilations[0]; state.dilationWidth = dilations[1]; return state; } static FailureOr resolveRequestedConvLoweringStrategy(Operation* op) { if (!useExperimentalConvImpl) return pimConvLowering.getValue(); if (pimConvLowering != PimConvLoweringAuto && pimConvLowering != PimConvLoweringPackedIm2Col) { op->emitOpError() << "--use-experimental-conv-impl conflicts with --pim-conv-lowering=" << stringifyConvLoweringStrategy(pimConvLowering); return failure(); } return PimConvLoweringPackedIm2Col; } static LogicalResult verifyForcedConvLoweringStrategy(Operation* op, const ConvGeometry& geo, PimConvLoweringType strategy) { switch (strategy) { case PimConvLoweringAuto: case PimConvLoweringLegacy: return success(); case PimConvLoweringDepthwise: if (geo.isDepthwise) return success(); return op->emitOpError("forced depthwise Conv lowering requires a depthwise convolution"); case PimConvLoweringPackedIm2Col: if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2 && geo.im2colElements <= pimConvIm2colMaxElements) return success(); return op->emitOpError("forced packed-im2col Conv lowering requires K/C to fit, pack >= 2, and im2col within budget"); case PimConvLoweringStreamedPatch: if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize) return success(); return op->emitOpError("forced streamed-patch Conv lowering requires K and C to each fit one crossbar"); case PimConvLoweringStreamedPacked: if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2) return success(); return op->emitOpError("forced streamed-packed Conv lowering requires K/C to fit and pack >= 2"); case PimConvLoweringOutputChannelTiled: if (geo.k <= geo.xbarSize && geo.c > geo.xbarSize) return success(); return op->emitOpError("forced output-channel-tiled Conv lowering requires 
K <= X and C > X"); case PimConvLoweringInputKTiled: if (geo.k > geo.xbarSize && geo.c <= geo.xbarSize) return success(); return op->emitOpError("forced input-k-tiled Conv lowering requires K > X and C <= X"); case PimConvLoweringTiled2D: if (geo.k > geo.xbarSize && geo.c > geo.xbarSize) return success(); return op->emitOpError("forced tiled-2d Conv lowering requires K > X and C > X"); } llvm_unreachable("unknown conv lowering strategy"); } static FailureOr lowerDenseSelectedConvPlan(Operation* op, const ConvLoweringState& state, PimConvLoweringType strategy, PatternRewriter& rewriter, Location loc); static ConvLoweringState makeGroupedConvLoweringState(const ConvLoweringState& parent, Value groupX, Value groupW, Value groupB, RankedTensorType groupOutType); static FailureOr buildConvValueForStrategy(Operation* op, Location loc, const ConvLoweringState& state, const ConvLoweringDecision& decision, const DistributedConvAnalysis& analysis, ArrayRef distributedConsumers, PatternRewriter& rewriter); static FailureOr buildGroupedConvValue(Operation* op, Location loc, const ConvLoweringState& state, const ConvLoweringDecision& decision, PatternRewriter& rewriter); static FailureOr lowerGroupedSelectedConvPlan(Operation* op, const ConvLoweringState& state, PimConvLoweringType strategy, PatternRewriter& rewriter, Location loc) { ConvLoweringDecision decision {strategy, "", false, "", ""}; return buildGroupedConvValue(op, loc, state, decision, rewriter); } static FailureOr lowerDenseSelectedConvPlan(Operation* op, const ConvLoweringState& state, PimConvLoweringType strategy, PatternRewriter& rewriter, Location loc) { DistributedConvAnalysis analysis; analysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer; analysis.barrierDetail = "selected dense layout"; ConvLoweringDecision decision {strategy, "", false, "", ""}; return buildConvValueForStrategy(op, loc, state, decision, analysis, {}, rewriter); } static FailureOr buildConvValueForStrategy(Operation* op, 
Location loc, const ConvLoweringState& state, const ConvLoweringDecision& decision, const DistributedConvAnalysis& analysis, ArrayRef distributedConsumers, PatternRewriter& rewriter) { (void)analysis; const ConvGeometry geo = buildConvGeometry(state); switch (decision.strategy) { case PimConvLoweringDepthwise: { return depthwise::rewriteConv(op, state, rewriter, loc); } case PimConvLoweringLegacy: case PimConvLoweringPackedIm2Col: { return standard::rewritePackedIm2ColConv(state, distributedConsumers, rewriter, loc); } case PimConvLoweringStreamedPatch: case PimConvLoweringOutputChannelTiled: case PimConvLoweringTiled2D: { return standard::rewriteStreamedConv(state, distributedConsumers, rewriter, loc, /*forcedPackFactor=*/1); } case PimConvLoweringInputKTiled: { return standard::rewriteInputKTiledConv(state, distributedConsumers, rewriter, loc); } case PimConvLoweringStreamedPacked: { return standard::rewriteStreamedConv(state, distributedConsumers, rewriter, loc, geo.pack); } case PimConvLoweringAuto: break; } op->emitOpError("unexpected auto strategy at Conv lowering dispatch"); return failure(); } static LogicalResult createConvValueForStrategy(ONNXConvOp convOp, const ConvLoweringState& state, const ConvLoweringDecision& decision, const DistributedConvAnalysis& analysis, ArrayRef distributedConsumers, PatternRewriter& rewriter, FailureOr& result) { result = buildConvValueForStrategy(convOp, convOp.getLoc(), state, decision, analysis, distributedConsumers, rewriter); if (failed(result)) return failure(); const ConvGeometry geo = buildConvGeometry(state); const ConvStrategyEstimate estimate = estimateConvStrategy(geo, decision.strategy, analysis); switch (decision.strategy) { case PimConvLoweringDepthwise: reportConvLoweringDecision( convOp, geo, decision, estimate, /*batchSize=*/geo.p, /*numberOfBatches=*/1, /*usesComputeBatch=*/true, /*usesBatchedInstructionEmission=*/true, std::nullopt); return success(); case PimConvLoweringLegacy: case 
PimConvLoweringPackedIm2Col: reportConvLoweringDecision( convOp, geo, decision, estimate, /*batchSize=*/geo.pack, /*numberOfBatches=*/1, /*usesComputeBatch=*/true, /*usesBatchedInstructionEmission=*/true, std::nullopt); return success(); case PimConvLoweringStreamedPatch: case PimConvLoweringOutputChannelTiled: case PimConvLoweringTiled2D: { uint64_t chunkPositions = chooseStreamChunkPositions(geo, /*packFactor=*/1); const int64_t batches = ceilIntegerDivide(geo.p, static_cast(chunkPositions)); reportConvLoweringDecision(convOp, geo, decision, estimate, /*batchSize=*/1, batches, /*usesComputeBatch=*/true, /*usesBatchedInstructionEmission=*/true, chunkPositions); return success(); } case PimConvLoweringInputKTiled: { const int64_t numKSlices = ceilIntegerDivide(geo.k, geo.xbarSize); const uint64_t maxLanesPerBatch = std::max(1, static_cast(crossbarCountInCore.getValue()) / static_cast(std::max(1, numKSlices * 4))); const uint64_t rowChunkWidth = std::max( 1, std::min({chooseStreamChunkPositions(geo, /*packFactor=*/1), maxLanesPerBatch, static_cast(state.outWidth)})); const int64_t batches = state.batchSize * state.outHeight * ceilIntegerDivide(state.outWidth, static_cast(rowChunkWidth)); reportConvLoweringDecision(convOp, geo, decision, estimate, /*batchSize=*/1, batches, /*usesComputeBatch=*/false, /*usesBatchedInstructionEmission=*/false, rowChunkWidth); return success(); } case PimConvLoweringStreamedPacked: { uint64_t chunkPositions = chooseStreamChunkPositions(geo, geo.pack); const int64_t batches = ceilIntegerDivide(geo.p, static_cast(chunkPositions)); reportConvLoweringDecision(convOp, geo, decision, estimate, /*batchSize=*/geo.pack, batches, /*usesComputeBatch=*/true, /*usesBatchedInstructionEmission=*/true, chunkPositions); return success(); } case PimConvLoweringAuto: break; } return convOp.emitOpError("unexpected auto strategy at Conv lowering dispatch"); } static LogicalResult rewriteSelectedConv(ONNXConvOp convOp, const ConvLoweringState& state, const 
ConvLoweringDecision& decision, const DistributedConvAnalysis& analysis, PatternRewriter& rewriter) { FailureOr result = failure(); if (failed(createConvValueForStrategy(convOp, state, decision, analysis, analysis.steps, rewriter, result))) return failure(); if (!analysis.hasLocalConsumers()) { rewriter.replaceOp(convOp, *result); return success(); } assert(analysis.replacementOp && "conv rewrite expects a replacement op"); rewriter.replaceOp(analysis.replacementOp, *result); for (auto it = analysis.steps.rbegin(); it != analysis.steps.rend(); ++it) if (it->op != analysis.replacementOp) rewriter.eraseOp(it->op); rewriter.eraseOp(convOp); return success(); } [[maybe_unused]] static LogicalResult rewriteUngroupedConv(ONNXConvOp convOp, const ConvLoweringState& state, const ConvLoweringDecision& decision, const DistributedConvAnalysis& analysis, PatternRewriter& rewriter) { return rewriteSelectedConv(convOp, state, decision, analysis, rewriter); } static LogicalResult rewriteGroupedConv(ONNXConvOp convOp, const ConvLoweringState& state, const ConvLoweringDecision& decision, PatternRewriter& rewriter); static ConvLoweringState makeGroupedConvLoweringState(const ConvLoweringState& parent, Value groupX, Value groupW, Value groupB, RankedTensorType groupOutType); static ConvLoweringState makeGroupedConvLoweringState( const ConvLoweringState& parent, Value groupX, Value groupW, Value groupB, RankedTensorType groupOutType) { ConvLoweringState state = parent; state.x = groupX; state.w = groupW; state.b = groupB; state.xType = cast(groupX.getType()); state.wType = cast(groupW.getType()); state.outType = groupOutType; state.batchSize = state.xType.getDimSize(0); state.numChannelsIn = state.xType.getDimSize(1); state.xHeight = state.xType.getDimSize(2); state.xWidth = state.xType.getDimSize(3); state.numChannelsOut = state.wType.getDimSize(0); state.wHeight = state.wType.getDimSize(2); state.wWidth = state.wType.getDimSize(3); state.outHeight = state.outType.getDimSize(2); 
state.outWidth = state.outType.getDimSize(3); state.group = 1; state.numChannelsInPerGroup = state.numChannelsIn; state.numChannelsOutPerGroup = state.numChannelsOut; state.hasBias = static_cast(groupB); return state; } static FailureOr buildGroupedConvValue(Operation* op, Location loc, const ConvLoweringState& state, const ConvLoweringDecision& decision, PatternRewriter& rewriter) { SmallVector xSlices = sliceTensor(state.x, /*axis=*/1, state.numChannelsInPerGroup, rewriter, loc); SmallVector wSlices = sliceTensor(state.w, /*axis=*/0, state.numChannelsOutPerGroup, rewriter, loc); SmallVector bSlices; if (state.hasBias) { auto biasType = cast(state.b.getType()); int64_t biasAxis = -1; if (biasType.getRank() == 1) biasAxis = 0; else if (biasType.getRank() == 2) biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1; else { op->emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank " << biasType.getRank(); return failure(); } bSlices = sliceTensor(state.b, biasAxis, state.numChannelsOutPerGroup, rewriter, loc); } if (xSlices.size() != static_cast(state.group) || wSlices.size() != static_cast(state.group) || (state.hasBias && bSlices.size() != static_cast(state.group))) { op->emitOpError("failed to partition grouped convolution operands for Spatial lowering"); return failure(); } SmallVector groupResults; groupResults.reserve(state.group); auto groupOutType = RankedTensorType::get( {state.batchSize, state.numChannelsOutPerGroup, state.outHeight, state.outWidth}, state.outType.getElementType()); for (int64_t groupId = 0; groupId < state.group; groupId++) { Value groupX = xSlices[groupId]; Value groupW = wSlices[groupId]; Value groupB = state.hasBias ? 
bSlices[groupId] : Value(); ConvLoweringState groupState = makeGroupedConvLoweringState(state, groupX, groupW, groupB, groupOutType); DistributedConvAnalysis groupAnalysis; groupAnalysis.barrierKind = DistributedConvBarrierKind::GroupedConv; groupAnalysis.barrierDetail = "grouped convolution still materializes densely"; FailureOr groupResult = buildConvValueForStrategy(op, loc, groupState, decision, groupAnalysis, {}, rewriter); if (failed(groupResult)) return failure(); groupResults.push_back(*groupResult); } if (llvm::all_of(groupResults, isCompileTimeComputable)) return createSpatConcat(rewriter, loc, /*axis=*/1, groupResults); auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {state.outType}, {}, groupResults, [&](ValueRange args) { spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args)); }); return concatCompute.getResult(0); } [[maybe_unused]] static LogicalResult rewriteGroupedConv(ONNXConvOp convOp, const ConvLoweringState& state, const ConvLoweringDecision& decision, PatternRewriter& rewriter) { FailureOr result = buildGroupedConvValue(convOp.getOperation(), convOp.getLoc(), state, decision, rewriter); if (failed(result)) return failure(); rewriter.replaceOp(convOp, *result); return success(); } } // namespace LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, ONNXConvOpAdaptor convOpAdaptor, ConversionPatternRewriter& rewriter) const { FailureOr state = analyzeConvLoweringState(convOp, convOpAdaptor); if (failed(state)) return failure(); SmallVector pads { state->padHeightBegin, state->padWidthBegin, state->padHeightEnd, state->padWidthEnd}; SmallVector strides {state->strideHeight, state->strideWidth}; SmallVector dilations {state->dilationHeight, state->dilationWidth}; Value bias = state->hasBias ? 
convOpAdaptor.getB() : Value(); auto convPlan = spatial::SpatConv2DPlanOp::create(rewriter, convOp.getLoc(), convOp.getY().getType(), convOpAdaptor.getX(), convOpAdaptor.getW(), bias, rewriter.getDenseI64ArrayAttr(pads), rewriter.getDenseI64ArrayAttr(strides), rewriter.getDenseI64ArrayAttr(dilations), rewriter.getI64IntegerAttr(state->group), rewriter.getStringAttr("nchw")); rewriter.replaceOp(convOp, convPlan.getResult()); return success(); } void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } LogicalResult canLowerConvPlanToRowStrip(spatial::SpatConv2DPlanOp planOp) { FailureOr state = analyzeConvLoweringState(planOp); if (failed(state)) return failure(); if (state->group != 1 || state->batchSize != 1) return failure(); if (state->outType.getRank() != 4 || !state->outType.hasStaticShape()) return failure(); FailureOr requestedStrategy = resolveRequestedConvLoweringStrategy(planOp.getOperation()); if (failed(requestedStrategy)) return failure(); DistributedConvAnalysis analysis; analysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer; analysis.barrierDetail = "selected row-strip layout"; ConvGeometry geometry = buildConvGeometry(*state); ConvLoweringDecision decision = chooseConvLoweringStrategy(geometry, *requestedStrategy, analysis); if (decision.strategy == PimConvLoweringDepthwise && !depthwise::canUseStructuredRewrite(*state) && *requestedStrategy == PimConvLoweringAuto) { decision = {PimConvLoweringLegacy, "depthwise auto fallback when structured depthwise lowering is not representable", /*isAuto=*/true, "", ""}; } if (failed(verifyForcedConvLoweringStrategy(planOp.getOperation(), geometry, decision.strategy))) return failure(); switch (decision.strategy) { case PimConvLoweringLegacy: case PimConvLoweringPackedIm2Col: case PimConvLoweringStreamedPatch: case PimConvLoweringOutputChannelTiled: case PimConvLoweringTiled2D: case PimConvLoweringStreamedPacked: return success(); case 
PimConvLoweringAuto: case PimConvLoweringDepthwise: case PimConvLoweringInputKTiled: return failure(); } llvm_unreachable("unknown conv lowering strategy"); } LogicalResult canConsumeAndProduceRowStrip(spatial::SpatConv2DPlanOp planOp) { FailureOr state = analyzeConvLoweringState(planOp); if (failed(state)) return failure(); StringRef failureReason; return canDirectLowerRowStripConv(*state, failureReason) ? success() : failure(); } FailureOr lowerSelectedConv2DPlan(spatial::SpatConv2DPlanOp planOp, std::optional rowStripInput, bool emitRowStripLayout, PatternRewriter& rewriter) { FailureOr state = analyzeConvLoweringState(planOp); if (failed(state)) return failure(); FailureOr requestedStrategy = resolveRequestedConvLoweringStrategy(planOp.getOperation()); if (failed(requestedStrategy)) return failure(); DistributedConvAnalysis analysis; analysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer; analysis.barrierDetail = emitRowStripLayout ? "selected row-strip layout" : "selected dense layout"; ConvGeometry geometry = buildConvGeometry(*state); ConvLoweringDecision decision = chooseConvLoweringStrategy(geometry, *requestedStrategy, analysis); if (decision.strategy == PimConvLoweringDepthwise && !depthwise::canUseStructuredRewrite(*state) && *requestedStrategy == PimConvLoweringAuto) { decision = {PimConvLoweringLegacy, "depthwise auto fallback when structured depthwise lowering is not representable", /*isAuto=*/true, "", ""}; } if (failed(verifyForcedConvLoweringStrategy(planOp.getOperation(), geometry, decision.strategy))) return failure(); if (emitRowStripLayout) { if (rowStripInput) { if (failed(canConsumeAndProduceRowStrip(planOp))) return planOp.emitOpError("selected row-strip input/output layout is not supported for this Conv plan"), failure(); return createConvRowsFromRowStripInput(*state, decision, *rowStripInput, rewriter, planOp.getLoc()); } if (failed(canLowerConvPlanToRowStrip(planOp))) return planOp.emitOpError("selected row-strip layout 
is not supported for this Conv plan"), failure(); FailureOr rows = createConvRowsForStrategy(*state, decision, rewriter, planOp.getLoc()); if (failed(rows)) return failure(); FailureOr packedRows = createRowStripPackedRows(*rows, *state, rewriter, planOp.getLoc()); if (failed(packedRows)) return planOp.emitOpError("failed to pack Conv rows into the selected row-strip physical layout"), failure(); return *packedRows; } if (decision.strategy == PimConvLoweringDepthwise) return lowerDenseSelectedConvPlan(planOp.getOperation(), *state, decision.strategy, rewriter, planOp.getLoc()); if (state->group != 1) return lowerGroupedSelectedConvPlan(planOp.getOperation(), *state, decision.strategy, rewriter, planOp.getLoc()); return lowerDenseSelectedConvPlan(planOp.getOperation(), *state, decision.strategy, rewriter, planOp.getLoc()); } } // namespace onnx_mlir