From 3a985b3675589039ab7f8346332880d4a8249129 Mon Sep 17 00:00:00 2001 From: ilgeco Date: Thu, 18 Jun 2026 10:59:02 +0200 Subject: [PATCH] Different type of convolution --- .../ONNXToSpatial/Patterns/Math/Conv.cpp | 1148 +++++++++++++++++ .../ONNXToSpatial/Patterns/Math/MatMul.cpp | 19 +- 2 files changed, 1165 insertions(+), 2 deletions(-) diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index d4dad33..8a3670f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -1,17 +1,26 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" #include +#include +#include +#include #include +#include +#include #include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" +#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" @@ -61,6 +70,1145 @@ struct ConvLoweringState { bool hasBias; }; +struct ConvGeometry { + int64_t batchSize; + int64_t numChannelsIn; + int64_t xHeight; + int64_t xWidth; + int64_t numChannelsOut; + int64_t wHeight; + int64_t wWidth; + int64_t outHeight; + int64_t outWidth; + int64_t group; + int64_t numChannelsInPerGroup; + int64_t numChannelsOutPerGroup; + int64_t k; + int64_t c; + int64_t p; + int64_t xbarSize; + int64_t pack; + uint64_t im2colElements; + bool hasBias; + bool 
isDepthwise; +}; + +struct ConvLoweringDecision { + PimConvLoweringType strategy; + std::string reason; + bool isAuto = false; + std::string fallbackReason; + std::string rejectedAutoStrategy; +}; + +struct PreparedConvInput { + Value value; + RankedTensorType type; +}; + +struct ConvStrategyEstimate { + uint64_t estimatedMvmCount = 0; + uint64_t estimatedReductionVAddCount = 0; + uint64_t estimatedOutputFragments = 1; + bool perOutputPositionReduction = false; + bool requiresFuncReturnMaterialization = false; + bool giantCollectorConcatExpected = false; + bool fullInputBroadcastExpected = false; + uint64_t concatOperandCount = 0; + std::string materializationKind = "none"; + std::string collectorCore = "-"; +}; + +struct ConvReportEntry { + uint64_t id; + std::string where; + std::string strategy; + std::string mode; + std::string inputShape; + std::string weightShape; + std::string outputShape; + int64_t groups; + int64_t k; + int64_t c; + int64_t p; + int64_t xbarSize; + int64_t pack; + uint64_t im2colElements; + uint64_t im2colBudget; + std::string chunkText; + int64_t batchSize; + int64_t numberOfBatches; + std::string spatialComputeBatch; + std::string batchedInstructionEmission; + std::string reason; + std::string fallbackReason; + std::string rejectedAutoStrategy; + uint64_t estimatedMvmCount; + uint64_t estimatedReductionVAddCount; + uint64_t estimatedOutputFragments; + std::string materializationRequiredAtReturn; + std::string materializationKind; + uint64_t concatOperandCount; + std::string collectorCore; + std::string giantCollectorConcatExpected; + std::string fullInputBroadcastExpected; +}; + +enum class DistributedTensorOpKind { + Relu, + Sigmoid, + Add, + Sub, + Mul, + Div, + Conv, +}; + +enum class DistributedConvBarrierKind { + Return, + UnsupportedConsumer, + Fanout, + DeadValue, + GroupedConv, + Depthwise, +}; + +enum class DistributedTensorConstantKind { + None, + Splat, + PerChannel, +}; + +enum class DistributedTensorLayoutKind { + 
NchwRowStrip, +}; + +struct DistributedFragmentInfo { + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + int64_t producerLane = 0; +}; + +struct DistributedTensorInfo { + Value storage; + RankedTensorType logicalType; + DistributedTensorLayoutKind layoutKind = DistributedTensorLayoutKind::NchwRowStrip; + SmallVector fragments; + int64_t laneCount = 0; + int64_t fragmentHeight = 1; + int64_t channels = 0; + int64_t height = 0; + int64_t width = 0; + + bool isRowStripNchw() const { return layoutKind == DistributedTensorLayoutKind::NchwRowStrip; } +}; + +struct DistributedTensorRegistry { + llvm::DenseMap infos; + + void bind(Value value, const DistributedTensorInfo& info) { infos[value] = info; } + + const DistributedTensorInfo* lookup(Value value) const { + auto it = infos.find(value); + if (it == infos.end()) + return nullptr; + return &it->second; + } +}; + +struct DistributedTensorStep { + Operation* op = nullptr; + DistributedTensorOpKind kind; + DenseElementsAttr constantAttr; + DistributedTensorConstantKind constantKind = DistributedTensorConstantKind::None; + bool fragmentOnLhs = true; + std::optional convState; +}; + +struct DistributedConvAnalysis { + SmallVector steps; + Operation* replacementOp = nullptr; + DistributedConvBarrierKind barrierKind = DistributedConvBarrierKind::UnsupportedConsumer; + std::string barrierDetail; + + bool hasLocalConsumers() const { return !steps.empty(); } + bool hasDistributedConvConsumer() const { + return llvm::any_of(steps, [](const DistributedTensorStep& step) { return step.kind == DistributedTensorOpKind::Conv; }); + } +}; + +struct DistributedChainReportEntry { + uint64_t chainId = 0; + uint64_t chainLength = 0; + std::string producerKind; + std::string distributedOps; + std::string materializationPoints; + std::string firstMaterializationReason; + uint64_t maxLiveFragments = 0; + uint64_t maxFragmentFanout = 0; + uint64_t patchBuilderCoreCount = 0; + std::string centralJunctionDetected = "no"; + 
std::string convInputMaterializationKind = "dense_materialization_fallback"; + uint64_t localPatchFragments = 0; + uint64_t remotePatchFragments = 0; + uint64_t haloTransferCount = 0; + uint64_t groupedTransferCount = 0; + std::string fallbackReason; +}; + +struct DistributedConvReportTotals { + uint64_t totalConvs = 0; + uint64_t distributedTensorsCreated = 0; + uint64_t distributedValuesPropagated = 0; + uint64_t distributedConsumersHandled = 0; + uint64_t distributedConvInputsSeen = 0; + uint64_t distributedConvInputsConsumed = 0; + uint64_t materializationBarriersInserted = 0; + std::map fallbackReasons; + std::map barrierReasons; + SmallVector chains; +}; + +static bool +isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup); +static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor); +static FailureOr analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b); + +static StringRef stringifyDistributedConvBarrierKind(DistributedConvBarrierKind kind) { + switch (kind) { + case DistributedConvBarrierKind::Return: return "func.return"; + case DistributedConvBarrierKind::UnsupportedConsumer: return "unsupported_consumer"; + case DistributedConvBarrierKind::Fanout: return "fanout"; + case DistributedConvBarrierKind::DeadValue: return "dead_value"; + case DistributedConvBarrierKind::GroupedConv: return "grouped_conv"; + case DistributedConvBarrierKind::Depthwise: return "depthwise_conv"; + } + llvm_unreachable("unknown distributed conv barrier kind"); +} + +static StringRef stringifyConvLoweringStrategy(PimConvLoweringType strategy) { + switch (strategy) { + case PimConvLoweringAuto: return "auto"; + case PimConvLoweringLegacy: return "legacy"; + case PimConvLoweringDepthwise: return "depthwise"; + case PimConvLoweringPackedIm2Col: return "packed-im2col"; + case PimConvLoweringStreamedPatch: return "streamed-patch"; + case PimConvLoweringStreamedPacked: return 
"streamed-packed"; + case PimConvLoweringOutputChannelTiled: return "output-channel-tiled"; + case PimConvLoweringInputKTiled: return "input-k-tiled"; + case PimConvLoweringTiled2D: return "tiled-2d"; + } + llvm_unreachable("unknown conv lowering strategy"); +} + +static bool requiresFuncReturnMaterialization(const DistributedConvAnalysis& analysis) { + return !analysis.hasLocalConsumers() && analysis.barrierKind == DistributedConvBarrierKind::Return; +} + +static ConvStrategyEstimate estimateConvStrategy(const ConvGeometry& geo, + PimConvLoweringType strategy, + const DistributedConvAnalysis& analysis) { + ConvStrategyEstimate estimate; + estimate.requiresFuncReturnMaterialization = requiresFuncReturnMaterialization(analysis); + estimate.materializationKind = estimate.requiresFuncReturnMaterialization ? "func.return" : "none"; + + switch (strategy) { + case PimConvLoweringLegacy: + case PimConvLoweringPackedIm2Col: + estimate.estimatedMvmCount = static_cast(std::max(1, geo.p)); + break; + case PimConvLoweringStreamedPatch: + case PimConvLoweringStreamedPacked: + case PimConvLoweringOutputChannelTiled: { + uint64_t chunkPositions = chooseStreamChunkPositions(geo, /*packFactor=*/1); + estimate.estimatedMvmCount = static_cast(std::max(1, geo.p)); + estimate.estimatedOutputFragments = + std::max(1, static_cast(ceilIntegerDivide(geo.p, static_cast(chunkPositions)))); + break; + } + case PimConvLoweringInputKTiled: { + const int64_t numKSlices = ceilIntegerDivide(geo.k, geo.xbarSize); + const uint64_t maxLanesPerBatch = + std::max(1, + static_cast(crossbarCountInCore.getValue()) + / static_cast(std::max(1, numKSlices * 4))); + const uint64_t rowChunkWidth = std::max( + 1, + std::min({chooseStreamChunkPositions(geo, /*packFactor=*/1), + maxLanesPerBatch, + static_cast(std::max(1, geo.outWidth))})); + estimate.estimatedMvmCount = + static_cast(std::max(1, geo.p)) * static_cast(std::max(1, numKSlices)); + estimate.estimatedReductionVAddCount = + static_cast(std::max(1, 
geo.p)) + * static_cast(std::max(0, numKSlices - 1) + (geo.hasBias ? 1 : 0)); + estimate.estimatedOutputFragments = static_cast(std::max(1, geo.batchSize)) + * static_cast(std::max(1, geo.outHeight)) + * static_cast( + ceilIntegerDivide(geo.outWidth, static_cast(rowChunkWidth))); + estimate.perOutputPositionReduction = numKSlices > 1; + estimate.fullInputBroadcastExpected = estimate.estimatedOutputFragments > 1; + if (estimate.requiresFuncReturnMaterialization && estimate.estimatedOutputFragments >= 128) { + estimate.giantCollectorConcatExpected = true; + estimate.materializationKind = "single_collector_concat"; + estimate.concatOperandCount = estimate.estimatedOutputFragments; + estimate.collectorCore = "scheduled_post_merge"; + } + break; + } + case PimConvLoweringTiled2D: + estimate.estimatedMvmCount = static_cast(std::max(1, geo.p)); + break; + case PimConvLoweringDepthwise: + case PimConvLoweringAuto: + break; + } + + if (estimate.requiresFuncReturnMaterialization && estimate.materializationKind == "none") + estimate.materializationKind = "func.return"; + return estimate; +} + +static ConvGeometry buildConvGeometry(const ConvLoweringState& state) { + ConvGeometry geo { + state.batchSize, + state.numChannelsIn, + state.xHeight, + state.xWidth, + state.numChannelsOut, + state.wHeight, + state.wWidth, + state.outHeight, + state.outWidth, + state.group, + state.numChannelsInPerGroup, + state.numChannelsOutPerGroup, + state.numChannelsInPerGroup * state.wHeight * state.wWidth, + state.numChannelsOutPerGroup, + state.batchSize * state.outHeight * state.outWidth, + static_cast(crossbarSize.getValue()), + 1, + 0, + state.hasBias, + isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup), + }; + geo.pack = std::max(1, geo.xbarSize / std::max(geo.k, geo.c)); + geo.im2colElements = static_cast(std::max(0, geo.p)) * static_cast(std::max(0, geo.k)); + return geo; +} + +static std::string formatShape(ArrayRef dims) { + 
std::string text; + llvm::raw_string_ostream os(text); + os << "["; + for (size_t i = 0; i < dims.size(); ++i) { + if (i != 0) + os << "x"; + os << dims[i]; + } + os << "]"; + return text; +} + +static std::string collapseWhitespace(StringRef text) { + std::string out; + out.reserve(text.size()); + bool lastWasSpace = false; + for (char c : text) { + bool isSpace = std::isspace(static_cast(c)); + if (isSpace) { + if (!lastWasSpace && !out.empty()) + out.push_back(' '); + lastWasSpace = true; + continue; + } + out.push_back(c); + lastWasSpace = false; + } + return out; +} + +static std::string abbreviate(StringRef text, size_t maxLen) { + if (text.size() <= maxLen) + return text.str(); + return (text.take_front(maxLen - 3) + "...").str(); +} + +static std::string abbreviateFromEnd(StringRef text, size_t maxLen) { + if (text.size() <= maxLen) + return text.str(); + return ("..." + text.take_back(maxLen - 3)).str(); +} + +static std::string summarizeLocation(Location loc, size_t maxLen = 44) { + std::string text; + llvm::raw_string_ostream os(text); + loc.print(os); + os.flush(); + std::string collapsed = collapseWhitespace(text); + if (collapsed.size() <= maxLen) + return collapsed; + if (collapsed.find('/') != std::string::npos || collapsed.find('#') != std::string::npos) + return abbreviateFromEnd(collapsed, maxLen); + return abbreviate(collapsed, maxLen); +} + +static std::string alignCell(StringRef text, size_t width, bool rightAlign = false) { + std::string cell = text.str(); + if (cell.size() < width) { + size_t padding = width - cell.size(); + if (rightAlign) + cell.insert(cell.begin(), padding, ' '); + else + cell.append(padding, ' '); + } + return cell; +} + +static bool hasSameStaticTensorType(Value value, Type expectedType) { + auto valueType = dyn_cast(value.getType()); + auto expectedTensorType = dyn_cast(expectedType); + return valueType && expectedTensorType && valueType.hasStaticShape() && valueType == expectedTensorType; +} + +static bool 
isSplatConstantValue(Value value, DenseElementsAttr& denseAttr) { + denseAttr = getHostConstDenseElementsAttr(value); + return static_cast(denseAttr) && denseAttr.isSplat(); +} + +static bool isPerChannelConstantValue(Value value, RankedTensorType currentType, DenseElementsAttr& denseAttr) { + denseAttr = getHostConstDenseElementsAttr(value); + if (!denseAttr || denseAttr.isSplat()) + return false; + + auto constantType = dyn_cast(denseAttr.getType()); + if (!constantType || !constantType.hasStaticShape()) + return false; + + const int64_t channels = currentType.getDimSize(1); + if (constantType.getRank() == 1) + return constantType.getDimSize(0) == channels; + if (constantType.getRank() == 2) + return constantType.getDimSize(0) == 1 && constantType.getDimSize(1) == channels; + if (constantType.getRank() != 4) + return false; + return constantType.getDimSize(0) == 1 && constantType.getDimSize(1) == channels + && constantType.getDimSize(2) == 1 && constantType.getDimSize(3) == 1; +} + +static std::optional +classifyDistributedBinaryConsumer(Operation* user, + Value currentValue, + Value lhs, + Value rhs, + DistributedTensorOpKind kind, + bool allowFragmentOnRhs, + std::string& failureDetail) { + if (user->getNumResults() != 1 || !hasSameStaticTensorType(user->getResult(0), currentValue.getType())) { + failureDetail = "result type mismatch"; + return std::nullopt; + } + + auto currentType = cast(currentValue.getType()); + DenseElementsAttr constantAttr; + if (lhs == currentValue) { + DistributedTensorConstantKind constantKind = DistributedTensorConstantKind::None; + if (isSplatConstantValue(rhs, constantAttr)) + constantKind = DistributedTensorConstantKind::Splat; + else if (isPerChannelConstantValue(rhs, currentType, constantAttr)) + constantKind = DistributedTensorConstantKind::PerChannel; + else + failureDetail = "unsupported rhs broadcast"; + if (constantKind == DistributedTensorConstantKind::None) + return std::nullopt; + return DistributedTensorStep {user, 
kind, constantAttr, constantKind, /*fragmentOnLhs=*/true, std::nullopt}; + } + + if (rhs == currentValue && allowFragmentOnRhs) { + DistributedTensorConstantKind constantKind = DistributedTensorConstantKind::None; + if (isSplatConstantValue(lhs, constantAttr)) + constantKind = DistributedTensorConstantKind::Splat; + else if (isPerChannelConstantValue(lhs, currentType, constantAttr)) + constantKind = DistributedTensorConstantKind::PerChannel; + else + failureDetail = "unsupported lhs broadcast"; + if (constantKind == DistributedTensorConstantKind::None) + return std::nullopt; + return DistributedTensorStep {user, kind, constantAttr, constantKind, /*fragmentOnLhs=*/false, std::nullopt}; + } + + failureDetail = "conv result is not the supported binary operand"; + return std::nullopt; +} + +[[maybe_unused]] static bool canConsumeDistributedConvInput(const ConvLoweringState& state, StringRef& failureReason) { + if (state.batchSize != 1) { + failureReason = "unsupported_batch"; + return false; + } + if (state.group != 1) { + failureReason = "unsupported_groups"; + return false; + } + if (state.strideHeight != 1 || state.strideWidth != 1) { + failureReason = "unsupported_stride"; + return false; + } + if (state.dilationHeight != 1 || state.dilationWidth != 1) { + failureReason = "unsupported_dilation"; + return false; + } + if (state.padHeightBegin != state.padHeightEnd || state.padWidthBegin != state.padWidthEnd) { + failureReason = "unsupported_padding"; + return false; + } + if (state.padHeightBegin != 1 || state.padWidthBegin != 1) { + failureReason = "unsupported_padding"; + return false; + } + if (state.wHeight != 3 || state.wWidth != 3) { + failureReason = "unsupported_kernel"; + return false; + } + if (state.outHeight != state.xHeight || state.outWidth != state.xWidth) { + failureReason = "unsupported_output_shape"; + return false; + } + if (!getHostConstDenseElementsAttr(state.w)) { + failureReason = "non_constant_weights"; + return false; + } + if (state.hasBias 
&& !getHostConstDenseElementsAttr(state.b)) { + failureReason = "non_constant_bias"; + return false; + } + failureReason = ""; + return true; +} + +static std::string stringifyDistributedTensorOpKind(DistributedTensorOpKind kind) { + switch (kind) { + case DistributedTensorOpKind::Relu: return "Relu"; + case DistributedTensorOpKind::Sigmoid: return "Sigmoid"; + case DistributedTensorOpKind::Add: return "Add"; + case DistributedTensorOpKind::Sub: return "Sub"; + case DistributedTensorOpKind::Mul: return "Mul"; + case DistributedTensorOpKind::Div: return "Div"; + case DistributedTensorOpKind::Conv: return "Conv"; + } + llvm_unreachable("unknown distributed tensor op kind"); +} + +static DistributedConvAnalysis analyzeDistributedConvConsumers(ONNXConvOp convOp) { + DistributedConvAnalysis analysis; + analysis.replacementOp = convOp; + + Value currentValue = convOp.getResult(); + while (true) { + if (currentValue.use_empty()) { + analysis.barrierKind = DistributedConvBarrierKind::DeadValue; + analysis.barrierDetail = "result has no users"; + return analysis; + } + + if (!currentValue.hasOneUse()) { + analysis.barrierKind = DistributedConvBarrierKind::Fanout; + analysis.barrierDetail = "value has multiple users"; + return analysis; + } + + Operation* user = *currentValue.getUsers().begin(); + if (isa(user)) { + analysis.barrierKind = DistributedConvBarrierKind::Return; + analysis.barrierDetail = "materialize at func.return"; + return analysis; + } + + std::optional step; + std::string failureDetail; + if (auto reluOp = dyn_cast(user)) { + if (hasSameStaticTensorType(reluOp.getResult(), currentValue.getType())) + step = DistributedTensorStep { + user, DistributedTensorOpKind::Relu, {}, DistributedTensorConstantKind::None, true, std::nullopt}; + else + failureDetail = "relu result type mismatch"; + } + else if (auto sigmoidOp = dyn_cast(user)) { + if (hasSameStaticTensorType(sigmoidOp.getResult(), currentValue.getType())) + step = DistributedTensorStep { + user, 
DistributedTensorOpKind::Sigmoid, {}, DistributedTensorConstantKind::None, true, std::nullopt}; + else + failureDetail = "sigmoid result type mismatch"; + } + else if (auto addOp = dyn_cast(user)) { + step = classifyDistributedBinaryConsumer( + user, currentValue, addOp.getA(), addOp.getB(), DistributedTensorOpKind::Add, /*allowFragmentOnRhs=*/true, + failureDetail); + } + else if (auto subOp = dyn_cast(user)) { + step = classifyDistributedBinaryConsumer( + user, currentValue, subOp.getA(), subOp.getB(), DistributedTensorOpKind::Sub, /*allowFragmentOnRhs=*/true, + failureDetail); + } + else if (auto mulOp = dyn_cast(user)) { + step = classifyDistributedBinaryConsumer( + user, currentValue, mulOp.getA(), mulOp.getB(), DistributedTensorOpKind::Mul, /*allowFragmentOnRhs=*/true, + failureDetail); + } + else if (auto divOp = dyn_cast(user)) { + step = classifyDistributedBinaryConsumer( + user, currentValue, divOp.getA(), divOp.getB(), DistributedTensorOpKind::Div, /*allowFragmentOnRhs=*/false, + failureDetail); + if (step) { + auto denseAttr = dyn_cast(step->constantAttr); + if (!denseAttr) { + failureDetail = "div requires floating-point splat constant"; + step.reset(); + } + } + } + else if (auto nextConv = dyn_cast(user)) { + failureDetail = "onnx.Conv distributed consumer blocked by dim0-only whole-batch materialization in MergeComputeNodes"; + } + else { + failureDetail = (user->getName().getStringRef() + " is not distributed-aware yet").str(); + } + + if (!step) { + analysis.barrierKind = DistributedConvBarrierKind::UnsupportedConsumer; + analysis.barrierDetail = failureDetail; + return analysis; + } + + analysis.replacementOp = user; + analysis.steps.push_back(*step); + currentValue = user->getResult(0); + } +} + +static void rewriteDistributedConvReport(const DistributedConvReportTotals& totals) { + std::fstream reportFile = openReportFile("conv_distributed_consumption_report"); + if (!reportFile.is_open()) + return; + + reportFile << "# PIM Conv Distributed 
Consumption Report\n\n"; + reportFile << "Totals:\n"; + reportFile << "- convs_seen: " << totals.totalConvs << "\n"; + reportFile << "- distributed_tensors_created: " << totals.distributedTensorsCreated << "\n"; + reportFile << "- distributed_values_propagated: " << totals.distributedValuesPropagated << "\n"; + reportFile << "- distributed_consumers_handled_locally: " << totals.distributedConsumersHandled << "\n"; + reportFile << "- distributed_conv_inputs_seen: " << totals.distributedConvInputsSeen << "\n"; + reportFile << "- distributed_conv_inputs_consumed: " << totals.distributedConvInputsConsumed << "\n"; + reportFile << "- materialization_barriers_inserted: " << totals.materializationBarriersInserted << "\n\n"; + + if (!totals.barrierReasons.empty()) { + reportFile << "Materialization barriers:\n"; + for (const auto& [reason, count] : totals.barrierReasons) + reportFile << "- " << reason << ": " << count << "\n"; + reportFile << "\n"; + } + + if (!totals.fallbackReasons.empty()) { + reportFile << "Fallback / no-distribution reasons:\n"; + for (const auto& [reason, count] : totals.fallbackReasons) + reportFile << "- " << reason << ": " << count << "\n"; + reportFile << "\n"; + } + + if (!totals.chains.empty()) { + reportFile << "Chains:\n"; + for (const DistributedChainReportEntry& chain : totals.chains) { + reportFile << "- chain_id: " << chain.chainId << "\n"; + reportFile << " chain_length: " << chain.chainLength << "\n"; + reportFile << " producer_kind: " << chain.producerKind << "\n"; + reportFile << " distributed_ops: " << chain.distributedOps << "\n"; + reportFile << " materialization_points: " << chain.materializationPoints << "\n"; + reportFile << " first_materialization_reason: " << chain.firstMaterializationReason << "\n"; + reportFile << " max_live_fragments: " << chain.maxLiveFragments << "\n"; + reportFile << " max_fragment_fanout: " << chain.maxFragmentFanout << "\n"; + reportFile << " patch_builder_core_count: " << chain.patchBuilderCoreCount 
<< "\n"; + reportFile << " central_junction_detected: " << chain.centralJunctionDetected << "\n"; + reportFile << " conv_input_materialization_kind: " << chain.convInputMaterializationKind << "\n"; + reportFile << " local_patch_fragments: " << chain.localPatchFragments << "\n"; + reportFile << " remote_patch_fragments: " << chain.remotePatchFragments << "\n"; + reportFile << " halo_transfer_count: " << chain.haloTransferCount << "\n"; + reportFile << " grouped_transfer_count: " << chain.groupedTransferCount << "\n"; + if (!chain.fallbackReason.empty()) + reportFile << " fallback_reason: " << chain.fallbackReason << "\n"; + } + } +} + +static void recordDistributedConvOutcome(const DistributedConvAnalysis& analysis) { + static std::mutex reportMutex; + static DistributedConvReportTotals totals; + + std::string barrierKey = stringifyDistributedConvBarrierKind(analysis.barrierKind).str(); + if (!analysis.barrierDetail.empty()) + barrierKey += ": " + analysis.barrierDetail; + + std::lock_guard guard(reportMutex); + totals.totalConvs++; + if (analysis.hasLocalConsumers()) { + totals.distributedTensorsCreated++; + totals.distributedValuesPropagated += analysis.steps.size(); + totals.distributedConsumersHandled += llvm::count_if(analysis.steps, [](const DistributedTensorStep& step) { + return step.kind != DistributedTensorOpKind::Conv; + }); + totals.distributedConvInputsSeen += llvm::count_if(analysis.steps, [](const DistributedTensorStep& step) { + return step.kind == DistributedTensorOpKind::Conv; + }); + totals.distributedConvInputsConsumed += llvm::count_if(analysis.steps, [](const DistributedTensorStep& step) { + return step.kind == DistributedTensorOpKind::Conv; + }); + totals.materializationBarriersInserted++; + totals.barrierReasons[barrierKey]++; + } + else { + totals.fallbackReasons[barrierKey]++; + } + DistributedChainReportEntry chain; + chain.chainId = totals.totalConvs; + chain.chainLength = analysis.steps.size() + 1; + chain.producerKind = "Conv"; + 
chain.materializationPoints = stringifyDistributedConvBarrierKind(analysis.barrierKind).str(); + chain.firstMaterializationReason = analysis.barrierDetail; + chain.maxFragmentFanout = analysis.barrierKind == DistributedConvBarrierKind::Fanout ? 2 : 1; + std::string ops; + for (size_t index = 0; index < analysis.steps.size(); ++index) { + if (!ops.empty()) + ops += ", "; + ops += stringifyDistributedTensorOpKind(analysis.steps[index].kind); + } + chain.distributedOps = ops; + if (analysis.hasDistributedConvConsumer()) { + chain.convInputMaterializationKind = "distributed_with_halo_exchange"; + chain.patchBuilderCoreCount = 1; + } + if (totals.chains.size() == 16) + totals.chains.erase(totals.chains.begin()); + totals.chains.push_back(std::move(chain)); + rewriteDistributedConvReport(totals); +} + +static std::string makeDivider(ArrayRef widths) { + std::string divider = "+"; + for (size_t width : widths) { + divider.append(width + 2, '-'); + divider.push_back('+'); + } + return divider; +} + +static void printConvReportLegend(std::fstream& reportFile) { + reportFile << "# PIM Conv Lowering Report\n\n"; + reportFile << "Legend:\n"; + reportFile << "- `id`: sequential Conv report entry index within this compiler invocation.\n"; + reportFile << "- `where`: summarized MLIR location of the Conv op.\n"; + reportFile << "- `mode`: whether the selected strategy came from `auto` policy or a forced compiler option.\n"; + reportFile << "- `strategy`: selected Conv lowering algorithm.\n"; + reportFile << "- `input`, `weight`, `output`: tensor shapes of the Conv operands/result.\n"; + reportFile << "- `groups`: ONNX Conv group count.\n"; + reportFile << "- `K`: logical reduction size, `CinPerGroup * Kh * Kw`.\n"; + reportFile << "- `C`: logical output-channel width handled by the selected strategy.\n"; + reportFile << "- `P`: total output positions, `N * Hout * Wout`.\n"; + reportFile << "- `X`: crossbar size.\n"; + reportFile << "- `pack`: packed spatial positions per MVM 
group, `floor(X / max(K, C))`.\n"; + reportFile << "- `im2col`: total explicit im2col element count, `P * K`.\n"; + reportFile << "- `im2col_budget`: maximum im2col element budget allowed by the compiler option.\n"; + reportFile << "- `stream_chunk`: output positions materialized per streamed chunk, or `-` when not applicable.\n"; + reportFile << "- `batch_size`: logical batch size passed into the compute lowering for this Conv form.\n"; + reportFile << "- `batches`: number of repeated compute batches emitted for the Conv.\n"; + reportFile << "- `spatial_compute_batch`: whether ONNX-to-Spatial used `spat.compute_batch` for this Conv.\n"; + reportFile << "- `batched_instruction_emission`: whether the lowering is expected to reach batched PIM emission.\n"; + reportFile << "- `reason`: strategy-selection reason.\n"; + reportFile << "- Per-conv details below the table include profitability estimates and materialization diagnostics.\n"; + reportFile << "- placeholders like `[7]`: value too long for the table cell; see the appendix at the end.\n\n"; +} + +struct ConvReportOverflow { + uint64_t placeholderId; + SmallVector rowIds; + std::string column; + std::string value; +}; + +static StringRef describeConvLoweringStrategy(PimConvLoweringType strategy) { + switch (strategy) { + case PimConvLoweringAuto: return "Automatic policy selection."; + case PimConvLoweringLegacy: return "Legacy Conv lowering path kept for compatibility."; + case PimConvLoweringDepthwise: return "Specialized depthwise lowering that avoids generic cross-channel GEMM mixing."; + case PimConvLoweringPackedIm2Col: return "Explicit im2col plus packed GEMM lowering for Conv shapes that fit well in one crossbar."; + case PimConvLoweringStreamedPatch: return "Chunked per-patch streaming Conv lowering without global im2col materialization."; + case PimConvLoweringStreamedPacked: return "Chunked streamed Conv lowering that still packs multiple output positions per MVM group."; + case 
PimConvLoweringOutputChannelTiled: return "Conv lowering that splits output channels across tiles when C exceeds one crossbar width."; + case PimConvLoweringInputKTiled: return "Conv lowering that splits the reduction dimension K across tiles and accumulates partial sums."; + case PimConvLoweringTiled2D: return "Conv lowering that tiles both K and output channels because neither dimension fits one crossbar."; + } + llvm_unreachable("unknown conv lowering strategy"); +} + +static std::string fitConvReportCell(StringRef text, + size_t width, + uint64_t rowId, + StringRef column, + std::vector& overflows, + uint64_t& nextPlaceholderId, + bool rightAlign = false) { + if (text.size() <= width) + return alignCell(text, width, rightAlign); + + for (ConvReportOverflow& overflow : overflows) { + if (overflow.column == column && overflow.value == text) { + if (llvm::find(overflow.rowIds, rowId) == overflow.rowIds.end()) + overflow.rowIds.push_back(rowId); + std::string placeholder = "[" + std::to_string(overflow.placeholderId) + "]"; + return alignCell(placeholder, width, rightAlign); + } + } + + std::string placeholder = "[" + std::to_string(nextPlaceholderId++) + "]"; + overflows.push_back({nextPlaceholderId - 1, {rowId}, column.str(), text.str()}); + return alignCell(placeholder, width, rightAlign); +} + +static void writeConvReportTable(std::fstream& reportFile, ArrayRef entries) { + static constexpr size_t kIdWidth = 4; + static constexpr size_t kWhereWidth = 24; + static constexpr size_t kModeWidth = 6; + static constexpr size_t kStrategyWidth = 20; + static constexpr size_t kShapeWidth = 14; + static constexpr size_t kGroupsWidth = 3; + static constexpr size_t kSmallWidth = 5; + static constexpr size_t kPWidth = 10; + static constexpr size_t kIm2colWidth = 10; + static constexpr size_t kChunkWidth = 8; + static constexpr size_t kFlagWidth = 3; + static constexpr size_t kReasonWidth = 24; + + const SmallVector widths = { + kIdWidth, kWhereWidth, kModeWidth, 
kStrategyWidth, kShapeWidth, kShapeWidth, kShapeWidth, kGroupsWidth, + kSmallWidth, kSmallWidth, kPWidth, kSmallWidth, kSmallWidth, kIm2colWidth, kIm2colWidth, + kChunkWidth, kSmallWidth, kSmallWidth, kFlagWidth, kFlagWidth, kReasonWidth, + }; + const std::string divider = makeDivider(widths); + std::vector overflows; + uint64_t nextPlaceholderId = 1; + + auto printRow = [&](ArrayRef cells) { + reportFile << "|"; + for (size_t i = 0; i < cells.size(); ++i) + reportFile << " " << cells[i] << " |"; + reportFile << "\n"; + }; + + reportFile << divider << "\n"; + printRow({ + alignCell("id", kIdWidth, true), + alignCell("where", kWhereWidth), + alignCell("mode", kModeWidth), + alignCell("strategy", kStrategyWidth), + alignCell("input", kShapeWidth), + alignCell("weight", kShapeWidth), + alignCell("output", kShapeWidth), + alignCell("grp", kGroupsWidth, true), + alignCell("K", kSmallWidth, true), + alignCell("C", kSmallWidth, true), + alignCell("P", kPWidth, true), + alignCell("X", kSmallWidth, true), + alignCell("pack", kSmallWidth, true), + alignCell("im2col", kIm2colWidth, true), + alignCell("budget", kIm2colWidth, true), + alignCell("chunk", kChunkWidth, true), + alignCell("batch", kSmallWidth, true), + alignCell("nbat", kSmallWidth, true), + alignCell("scb", kFlagWidth), + alignCell("bie", kFlagWidth), + alignCell("reason", kReasonWidth), + }); + reportFile << divider << "\n"; + + for (const ConvReportEntry& entry : entries) { + printRow({ + alignCell(std::to_string(entry.id), kIdWidth, true), + fitConvReportCell(entry.where, kWhereWidth, entry.id, "where", overflows, nextPlaceholderId), + alignCell(entry.mode, kModeWidth), + fitConvReportCell(entry.strategy, kStrategyWidth, entry.id, "strategy", overflows, nextPlaceholderId), + fitConvReportCell(entry.inputShape, kShapeWidth, entry.id, "input", overflows, nextPlaceholderId), + fitConvReportCell(entry.weightShape, kShapeWidth, entry.id, "weight", overflows, nextPlaceholderId), + fitConvReportCell(entry.outputShape, 
kShapeWidth, entry.id, "output", overflows, nextPlaceholderId), + alignCell(std::to_string(entry.groups), kGroupsWidth, true), + alignCell(std::to_string(entry.k), kSmallWidth, true), + alignCell(std::to_string(entry.c), kSmallWidth, true), + alignCell(std::to_string(entry.p), kPWidth, true), + alignCell(std::to_string(entry.xbarSize), kSmallWidth, true), + alignCell(std::to_string(entry.pack), kSmallWidth, true), + alignCell(std::to_string(entry.im2colElements), kIm2colWidth, true), + alignCell(std::to_string(entry.im2colBudget), kIm2colWidth, true), + fitConvReportCell(entry.chunkText, kChunkWidth, entry.id, "stream_chunk", overflows, nextPlaceholderId, true), + alignCell(std::to_string(entry.batchSize), kSmallWidth, true), + alignCell(std::to_string(entry.numberOfBatches), kSmallWidth, true), + alignCell(entry.spatialComputeBatch, kFlagWidth), + alignCell(entry.batchedInstructionEmission, kFlagWidth), + fitConvReportCell(entry.reason, kReasonWidth, entry.id, "reason", overflows, nextPlaceholderId), + }); + } + reportFile << divider << "\n"; + + if (overflows.empty()) + reportFile << "\n"; + else { + reportFile << "\nAppendix:\n"; + for (const ConvReportOverflow& overflow : overflows) { + reportFile << " [" << overflow.placeholderId << "] rows "; + for (size_t i = 0; i < overflow.rowIds.size(); ++i) { + if (i != 0) + reportFile << ", "; + reportFile << overflow.rowIds[i]; + } + reportFile << ", " << overflow.column << ": " << overflow.value << "\n"; + } + reportFile << "\n"; + } + + reportFile << "Per-Conv Details:\n"; + for (const ConvReportEntry& entry : entries) { + reportFile << "- Conv " << entry.id << ": mode=" << entry.mode << ", strategy=" << entry.strategy + << ", reason=" << entry.reason << "\n"; + reportFile << " K=" << entry.k << ", Cout=" << entry.c << ", output_positions=" << entry.p + << ", estimated_mvm_count=" << entry.estimatedMvmCount + << ", estimated_reduction_vadd_count=" << entry.estimatedReductionVAddCount + << ", 
estimated_output_fragments=" << entry.estimatedOutputFragments << "\n";
+    reportFile << " materialization_required_at_func_return=" << entry.materializationRequiredAtReturn
+               << ", materialization_kind=" << entry.materializationKind
+               << ", giant_collector_concat_expected=" << entry.giantCollectorConcatExpected
+               << ", concat_operand_count=" << entry.concatOperandCount
+               << ", collector_core=" << entry.collectorCore
+               << ", full_input_broadcast_expected=" << entry.fullInputBroadcastExpected << "\n";
+    if (!entry.rejectedAutoStrategy.empty())
+      reportFile << " rejected_auto_strategy=" << entry.rejectedAutoStrategy << "\n";
+    if (!entry.fallbackReason.empty())
+      reportFile << " fallback_reason=" << entry.fallbackReason << "\n";
+  }
+  reportFile << "\n";
+
+  llvm::SmallVector usedStrategies;
+  for (const ConvReportEntry& entry : entries) {
+    PimConvLoweringType strategy = PimConvLoweringAuto;
+    for (PimConvLoweringType candidate : {
+             PimConvLoweringAuto,
+             PimConvLoweringLegacy,
+             PimConvLoweringDepthwise,
+             PimConvLoweringPackedIm2Col,
+             PimConvLoweringStreamedPatch,
+             PimConvLoweringStreamedPacked,
+             PimConvLoweringOutputChannelTiled,
+             PimConvLoweringInputKTiled,
+             PimConvLoweringTiled2D,
+         }) {
+      if (entry.strategy == stringifyConvLoweringStrategy(candidate)) {
+        strategy = candidate;
+        break;
+      }
+    }
+    if (llvm::find(usedStrategies, strategy) == usedStrategies.end())
+      usedStrategies.push_back(strategy);
+  }
+
+  reportFile << "Strategies used in this report:\n";
+  for (PimConvLoweringType strategy : usedStrategies)
+    reportFile << "- `" << stringifyConvLoweringStrategy(strategy).str() << "`: "
+               << describeConvLoweringStrategy(strategy).str() << "\n";
+}
+
+// Regenerates the "conv_lowering_report" file from scratch: opens it through
+// openReportFile, then writes the legend followed by the table for `entries`.
+// Does nothing (and reports nothing) when the file cannot be opened.
+static void rewriteConvLoweringReport(ArrayRef entries) {
+  std::fstream reportFile = openReportFile("conv_lowering_report");
+  if (!reportFile.is_open())
+    return;
+  printConvReportLegend(reportFile);
+  writeConvReportTable(reportFile, entries);
+}
+
+// Resolves the strategy requested through compiler options.
+//
+// Without useExperimentalConvImpl the value of pimConvLowering is returned
+// unchanged. With it, pimConvLowering must be Auto or PackedIm2Col and the
+// result is always PackedIm2Col; any other value emits an op error on `convOp`
+// and yields failure().
+static FailureOr resolveRequestedConvLoweringStrategy(ONNXConvOp convOp) {
+  if (!useExperimentalConvImpl)
+    return pimConvLowering.getValue();
+
+  if (pimConvLowering != PimConvLoweringAuto && pimConvLowering != PimConvLoweringPackedIm2Col) {
+    convOp.emitOpError() << "--use-experimental-conv-impl conflicts with --pim-conv-lowering="
+                         << stringifyConvLoweringStrategy(pimConvLowering);
+    return failure();
+  }
+  return PimConvLoweringPackedIm2Col;
+}
+
+// Chooses the lowering strategy for one convolution.
+//
+// A non-Auto `requested` value is returned as-is and marked as forced
+// (isAuto=false). In Auto mode the checks run in this order, first match wins:
+//   1. geo.isDepthwise                                   -> Depthwise
+//   2. k <= X, c <= X, pack >= 2, im2col within budget   -> PackedIm2Col
+//   3. k <= X, c <= X, pack >= 2, im2col over budget     -> StreamedPacked
+//   4. k <= X, c <= X                                    -> StreamedPatch
+//   5. k <= X, c >  X                                    -> OutputChannelTiled
+//   6. k >  X, c <= X                                    -> Legacy (InputKTiled is
+//      recorded as the rejected auto strategy, with a fallback reason built
+//      from estimateConvStrategy)
+//   7. otherwise                                         -> Tiled2D
+// where X is geo.xbarSize and the budget is pimConvIm2colMaxElements.
+static ConvLoweringDecision chooseConvLoweringStrategy(const ConvGeometry& geo,
+                                                       PimConvLoweringType requested,
+                                                       const DistributedConvAnalysis& analysis) {
+  if (requested != PimConvLoweringAuto)
+    return {requested, "forced by compiler option", /*isAuto=*/false, "", ""};
+
+  // Transform-based convolution is intentionally not selected for this ISA:
+  // it would require explicit transform sequences and staging traffic on top of
+  // the same crossbar MVM primitive, which is not attractive here.
+  if (geo.isDepthwise)
+    return {PimConvLoweringDepthwise, "depthwise convolution", /*isAuto=*/true, "", ""};
+  if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2 && geo.im2colElements <= pimConvIm2colMaxElements)
+    return {PimConvLoweringPackedIm2Col,
+            "fits crossbar, packing useful, and global im2col fits budget",
+            /*isAuto=*/true,
+            "",
+            ""};
+  if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2 && geo.im2colElements > pimConvIm2colMaxElements)
+    return {PimConvLoweringStreamedPacked,
+            "fits crossbar and packing useful, but global im2col exceeds budget",
+            /*isAuto=*/true,
+            "",
+            ""};
+  if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize)
+    return {PimConvLoweringStreamedPatch, "fits crossbar but packing is not useful", /*isAuto=*/true, "", ""};
+  if (geo.k <= geo.xbarSize && geo.c > geo.xbarSize)
+    return {PimConvLoweringOutputChannelTiled,
+            "output channels exceed one crossbar width",
+            /*isAuto=*/true,
+            "",
+            ""};
+  if (geo.k > geo.xbarSize && geo.c <= geo.xbarSize) {
+    // The estimate is only used to enrich the human-readable fallback reason;
+    // the decision itself (fall back to Legacy) does not depend on it.
+    ConvStrategyEstimate estimate = estimateConvStrategy(geo, PimConvLoweringInputKTiled, analysis);
+    std::string fallbackReason = "auto rejects input-k-tiled because the reduction-heavy path is force-only for now";
+    if (estimate.requiresFuncReturnMaterialization && estimate.perOutputPositionReduction
+        && estimate.giantCollectorConcatExpected) {
+      fallbackReason += "; func.return would materialize " + std::to_string(estimate.concatOperandCount)
+                      + " output fragments through a single collector concat after per-position reductions";
+    }
+    if (estimate.fullInputBroadcastExpected)
+      fallbackReason += "; the current lowering also broadcasts the padded input to many workers";
+    return {PimConvLoweringLegacy,
+            "fall back to legacy explicit-im2col for the current auto policy",
+            /*isAuto=*/true,
+            fallbackReason,
+            stringifyConvLoweringStrategy(PimConvLoweringInputKTiled).str()};
+  }
+  return {PimConvLoweringTiled2D, "both reduction K and output channels exceed one crossbar", /*isAuto=*/true, "", ""};
+}
+
+// Checks that `geo` satisfies the preconditions of `strategy`.
+//
+// Auto and Legacy are always accepted. Every other strategy returns success()
+// only when its geometry condition holds; otherwise an op error describing the
+// requirement is emitted on `convOp` and returned.
+static LogicalResult verifyForcedConvLoweringStrategy(ONNXConvOp convOp,
+                                                      const ConvGeometry& geo,
+                                                      PimConvLoweringType strategy) {
+  switch (strategy) {
+  case PimConvLoweringAuto:
+  case PimConvLoweringLegacy:
+    return success();
+  case PimConvLoweringDepthwise:
+    if (geo.isDepthwise)
+      return success();
+    return convOp.emitOpError("forced depthwise Conv lowering requires a depthwise convolution");
+  case PimConvLoweringPackedIm2Col:
+    if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize && geo.pack >= 2 && geo.im2colElements <= pimConvIm2colMaxElements)
+      return success();
+    return convOp.emitOpError("forced packed-im2col Conv lowering requires K/C to fit, pack >= 2, and im2col within budget");
+  case PimConvLoweringStreamedPatch:
+    if (geo.k <= geo.xbarSize && geo.c <= geo.xbarSize)
+      return success();
+    return convOp.emitOpError("forced streamed-patch Conv lowering requires K and C to each fit one crossbar");
+  case PimConvLoweringStreamedPacked:
+    if (geo.k <= geo.xbarSize && geo.c
<= geo.xbarSize && geo.pack >= 2)
+      return success();
+    return convOp.emitOpError("forced streamed-packed Conv lowering requires K/C to fit and pack >= 2");
+  case PimConvLoweringOutputChannelTiled:
+    if (geo.k <= geo.xbarSize && geo.c > geo.xbarSize)
+      return success();
+    return convOp.emitOpError("forced output-channel-tiled Conv lowering requires K <= X and C > X");
+  case PimConvLoweringInputKTiled:
+    if (geo.k > geo.xbarSize && geo.c <= geo.xbarSize)
+      return success();
+    return convOp.emitOpError("forced input-k-tiled Conv lowering requires K > X and C <= X");
+  case PimConvLoweringTiled2D:
+    if (geo.k > geo.xbarSize && geo.c > geo.xbarSize)
+      return success();
+    return convOp.emitOpError("forced tiled-2d Conv lowering requires K > X and C > X");
+  }
+  llvm_unreachable("unknown conv lowering strategy");
+}
+
+// Appends one row describing how `convOp` was lowered to the accumulated Conv
+// lowering report and rewrites the whole report file.
+//
+// Does nothing unless pimReportConvLowering is set. `streamChunkPositions` is
+// printed as "-" when absent. Rows are kept in a function-local static list so
+// that every call regenerates a complete report through
+// rewriteConvLoweringReport.
+static void reportConvLoweringDecision(ONNXConvOp convOp,
+                                       const ConvGeometry& geo,
+                                       const ConvLoweringDecision& decision,
+                                       const ConvStrategyEstimate& estimate,
+                                       int64_t batchSize,
+                                       int64_t numberOfBatches,
+                                       bool usesComputeBatch,
+                                       bool usesBatchedInstructionEmission,
+                                       std::optional streamChunkPositions = std::nullopt) {
+  if (!pimReportConvLowering)
+    return;
+
+  // Format everything that only depends on the arguments before taking the lock.
+  const std::string location = summarizeLocation(convOp.getLoc());
+  const std::string strategy = stringifyConvLoweringStrategy(decision.strategy).str();
+  const std::string mode = decision.isAuto ? "auto" : "forced";
+  const std::string inputShape = formatShape(cast(convOp.getX().getType()).getShape());
+  const std::string weightShape = formatShape(cast(convOp.getW().getType()).getShape());
+  const std::string outputShape = formatShape(cast(convOp.getY().getType()).getShape());
+  const std::string chunkText = streamChunkPositions ? std::to_string(*streamChunkPositions) : "-";
+  const std::string scbText = usesComputeBatch ? "yes" : "no";
+  const std::string bieText = usesBatchedInstructionEmission ? "yes" : "no";
+
+  static std::mutex reportMutex;
+  static uint64_t reportIndex = 0;
+  static std::vector reportEntries;
+  // The id counter is shared state just like the entry list, so it must only be
+  // bumped while holding reportMutex. Incrementing it before locking would be an
+  // unsynchronised write and could also append rows out of id order.
+  std::lock_guard lock(reportMutex);
+  const uint64_t currentIndex = ++reportIndex;
+  reportEntries.push_back({
+      currentIndex,
+      location,
+      strategy,
+      mode,
+      inputShape,
+      weightShape,
+      outputShape,
+      geo.group,
+      geo.k,
+      geo.c,
+      geo.p,
+      geo.xbarSize,
+      geo.pack,
+      geo.im2colElements,
+      pimConvIm2colMaxElements,
+      chunkText,
+      batchSize,
+      numberOfBatches,
+      scbText,
+      bieText,
+      decision.reason,
+      decision.fallbackReason,
+      decision.rejectedAutoStrategy,
+      estimate.estimatedMvmCount,
+      estimate.estimatedReductionVAddCount,
+      estimate.estimatedOutputFragments,
+      estimate.requiresFuncReturnMaterialization ? "yes" : "no",
+      estimate.materializationKind,
+      estimate.concatOperandCount,
+      estimate.collectorCore,
+      estimate.giantCollectorConcatExpected ? "yes" : "no",
+      estimate.fullInputBroadcastExpected ? "yes" : "no",
+  });
+  rewriteConvLoweringReport(reportEntries);
+}
+
+// Returns how many output positions one streamed chunk should cover.
+//
+// An explicitly given pimConvStreamChunkPositions option wins (clamped to at
+// least 1). Otherwise the chunk is the im2col budget divided by the patch size
+// geo.k, clamped to [1, geo.p]; when `packFactor` > 1 and the chunk is larger
+// than it, the chunk is rounded down to a multiple of `packFactor`.
+static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) {
+  if (pimConvStreamChunkPositions.getNumOccurrences() != 0)
+    return std::max(1, pimConvStreamChunkPositions);
+
+  // Guard against a zero/negative K so the division below is always defined.
+  const uint64_t patchElements = static_cast(std::max(1, geo.k));
+  uint64_t chunkPositions = std::max(1, pimConvIm2colMaxElements / patchElements);
+  chunkPositions = std::min(chunkPositions, static_cast(std::max(1, geo.p)));
+  if (packFactor > 1 && chunkPositions > static_cast(packFactor)) {
+    chunkPositions -= chunkPositions % static_cast(packFactor);
+    chunkPositions = std::max(chunkPositions, static_cast(packFactor));
+  }
+  return std::max(1, chunkPositions);
+}
+
 static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
   auto biasType = cast(bias.getType());
   if (biasType.getRank() != 1)
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp
index
4b888b1..e74700d 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -950,7 +950,12 @@ struct MatMulToGemm : OpRewritePattern { LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override { auto shapeInfo = analyzeMatMulShape(matmulOp); - if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector || !shapeInfo->outputBatchShape.empty()) + if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector) + return failure(); + + const bool hasNonSingletonOutputBatch = + !shapeInfo->outputBatchShape.empty() && getStaticShapeElementCount(shapeInfo->outputBatchShape) != 1; + if (hasNonSingletonOutputBatch) return failure(); Location loc = matmulOp.getLoc(); @@ -991,7 +996,17 @@ struct MatMulToGemm : OpRewritePattern { gemmResult = ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0})) .getResult(); - rewriter.replaceOp(matmulOp, gemmResult); + + if (shapeInfo->outputBatchShape.empty()) { + rewriter.replaceOp(matmulOp, gemmResult); + return success(); + } + + auto directOutType = + RankedTensorType::get({1, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding()); + Value batchedResult = ensureBatchedTensor(gemmResult, /*batchSize=*/1, shapeInfo->m, shapeInfo->n, rewriter, loc); + Value finalResult = finalizeNormalizedMatMulResult(batchedResult, directOutType, *shapeInfo, rewriter, loc); + rewriter.replaceOp(matmulOp, finalResult); return success(); } };