From 061139aefb879aaeb2caa4c46aafc03fdd2ece3b Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Wed, 13 May 2026 21:48:49 +0200 Subject: [PATCH] fix wrong send/receive reordering in post dcp merge instructions compaction --- .../MergeComputeNodesPass.cpp | 200 ++++++++++++++++-- .../MergeComputeNodes/RegularOpCompaction.cpp | 23 +- 2 files changed, 192 insertions(+), 31 deletions(-) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 9386805..a60b64e 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -1003,6 +1003,23 @@ public: DenseMap externalInputMap; DenseMap weightToIndex; }; + struct RemoteSendInfo { + ChannelInfo channelInfo; + ComputeInstance consumer; + size_t inputIndex = 0; + size_t consumerOrder = 0; + size_t sourceOrder = 0; + }; + struct RemoteReceiveEntry { + ChannelInfo channelInfo; + ComputeInstance consumer; + size_t inputIndex = 0; + size_t sourceOrder = 0; + }; + auto getRemoteSendPairKey = [](const ChannelInfo& channelInfo) { + return (static_cast(static_cast(channelInfo.sourceCoreId)) << 32) + | static_cast(channelInfo.targetCoreId); + }; auto getTaskInputs = [&](const ScheduledTask& task) { SmallVector inputs; @@ -1143,7 +1160,7 @@ public: } }; - DenseMap>> remoteSendsByTask; + DenseMap>> remoteSendsByTask; DenseMap>> remoteInputsByTask; DenseMap> cpuExternalInputs; DenseMap> cpuWeights; @@ -1176,7 +1193,7 @@ public: auto& perResultChannels = remoteSendsByTask[producerRef->instance]; if (perResultChannels.empty()) perResultChannels.resize(getTaskOutputTypes(producerIt->second).size()); - perResultChannels[producerRef->resultIndex].push_back(info); + perResultChannels[producerRef->resultIndex].push_back({info, task.key, inputIndex, task.order, 0}); } continue; } @@ -1201,6 +1218,79 @@ public: } } + DenseSet pairsNeedingReceiveReorder; + for (size_t cpu : orderedCpus) { + DenseMap nextSourceOrderByPair; + DenseMap lastConsumerOrderByPair; + for (const ScheduledTask& task : tasksByCpu[cpu]) { + auto sendsIt = remoteSendsByTask.find(task.key); + if (sendsIt == remoteSendsByTask.end()) + continue; + for (auto& sendInfos : sendsIt->second) { + for (RemoteSendInfo& sendInfo : sendInfos) { + uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); + sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++; + auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder); + if (!inserted) { + if (sendInfo.consumerOrder < it->second) + pairsNeedingReceiveReorder.insert(pairKey); + it->second = sendInfo.consumerOrder; + } + } + } + } + } + + DenseMap> reorderedSendsByPair; + for (auto& taskSends : remoteSendsByTask) { + for (auto& sendInfos : taskSends.second) { + for (RemoteSendInfo& sendInfo : sendInfos) { + uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); + if (pairsNeedingReceiveReorder.contains(pairKey)) + reorderedSendsByPair[pairKey].push_back(&sendInfo); + } + } + } + for (auto& pairSends : reorderedSendsByPair) { + llvm::stable_sort(pairSends.second, [](const RemoteSendInfo* lhs, const RemoteSendInfo* rhs) { + if (lhs->sourceOrder != rhs->sourceOrder) + return lhs->sourceOrder < rhs->sourceOrder; + return lhs->channelInfo.channelId < rhs->channelInfo.channelId; + }); + for (RemoteSendInfo* sendInfo : pairSends.second) { + int64_t channelId = nextChannelId++; + sendInfo->channelInfo.channelId = channelId; + auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer); + assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for reordered send"); + assert(sendInfo->inputIndex < remoteInputsIt->second.size() && "remote input index out of range"); + assert(remoteInputsIt->second[sendInfo->inputIndex] && "missing reordered remote input channel"); + remoteInputsIt->second[sendInfo->inputIndex]->channelId = channelId; + } + } + + DenseMap>> receiveQueuesByCpu; + for (auto& taskSends : remoteSendsByTask) { + for (const auto& sendInfos : taskSends.second) { + for (const RemoteSendInfo& sendInfo : sendInfos) { + uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); + if (!pairsNeedingReceiveReorder.contains(pairKey)) + continue; + size_t targetCpu = static_cast(sendInfo.channelInfo.targetCoreId - 1); + receiveQueuesByCpu[targetCpu][pairKey].push_back( + {sendInfo.channelInfo, sendInfo.consumer, sendInfo.inputIndex, sendInfo.sourceOrder}); + } + } + } + for (auto& cpuQueues : receiveQueuesByCpu) { + for (auto& pairQueue : cpuQueues.second) { + llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry& lhs, const RemoteReceiveEntry& rhs) { + if (lhs.sourceOrder != rhs.sourceOrder) + return lhs.sourceOrder < rhs.sourceOrder; + return lhs.channelInfo.channelId < rhs.channelInfo.channelId; + }); + } + } + auto returnOp = cast(func.getBody().front().getTerminator()); IRRewriter rewriter(&getContext()); DenseMap cpuPrograms; @@ -1255,6 +1345,59 @@ public: CpuProgram& program = cpuPrograms[cpu]; IRRewriter cpuRewriter(&getContext()); cpuRewriter.setInsertionPointToEnd(program.block); + DenseMap receiveQueueIndices; + DenseMap> preReceivedInputsByTask; + + auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional { + auto inputsIt = preReceivedInputsByTask.find(consumer); + if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex) + return std::nullopt; + Value value = inputsIt->second[inputIndex]; + if (!value) + return std::nullopt; + return value; + }; + + auto receiveThroughInput = [&](const ChannelInfo& requestedChannelInfo, + ComputeInstance requestedConsumer, + size_t requestedInputIndex) -> std::optional { + uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo); + auto cpuQueuesIt = receiveQueuesByCpu.find(cpu); + if (cpuQueuesIt == receiveQueuesByCpu.end()) + return std::nullopt; + auto queueIt = cpuQueuesIt->second.find(pairKey); + if (queueIt == cpuQueuesIt->second.end()) + return std::nullopt; + + auto& queue = queueIt->second; + size_t& queueIndex = receiveQueueIndices[pairKey]; + while (queueIndex < queue.size()) { + const RemoteReceiveEntry& entry = queue[queueIndex++]; + auto consumerTaskIt = taskByKey.find(entry.consumer); + if (consumerTaskIt == taskByKey.end()) + return std::nullopt; + SmallVector consumerInputs = getTaskInputs(consumerTaskIt->second); + if (consumerInputs.size() <= entry.inputIndex) + return std::nullopt; + Type inputType = consumerInputs[entry.inputIndex].getType(); + auto receive = + spatial::SpatChannelReceiveOp::create(cpuRewriter, + loc, + inputType, + cpuRewriter.getI64IntegerAttr(entry.channelInfo.channelId), + cpuRewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId), + cpuRewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId)); + + auto& receivedInputs = preReceivedInputsByTask[entry.consumer]; + if (receivedInputs.size() <= entry.inputIndex) + receivedInputs.resize(entry.inputIndex + 1); + receivedInputs[entry.inputIndex] = receive.getResult(); + + if (entry.consumer == requestedConsumer && entry.inputIndex == requestedInputIndex) + return receive.getResult(); + } + return std::nullopt; + }; for (const ScheduledTask& task : tasksByCpu[cpu]) { SmallVector taskInputs = getTaskInputs(task); @@ -1284,6 +1427,24 @@ public: continue; } const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex]; + uint64_t pairKey = getRemoteSendPairKey(channelInfo); + if (pairsNeedingReceiveReorder.contains(pairKey)) { + if (std::optional preReceived = lookupPreReceivedInput(task.key, inputIndex)) { + resolvedInputs.push_back(*preReceived); + continue; + } + std::optional received = receiveThroughInput(channelInfo, task.key, inputIndex); + if (!received) { + task.sourceOp->emitOpError("failed to materialize reordered remote receive") + << " consumerCpu=" << cpu << " consumerSlot=" << task.slot + << " sourceCoreId=" << channelInfo.sourceCoreId << " targetCoreId=" << channelInfo.targetCoreId + << " channelId=" << channelInfo.channelId; + signalPassFailure(); + return; + } + resolvedInputs.push_back(*received); + continue; + } auto receive = spatial::SpatChannelReceiveOp::create(cpuRewriter, loc, @@ -1367,13 +1528,14 @@ public: if (sendInfos.empty()) continue; Value producedValue = taskYieldValues[resultIndex]; - for (const ChannelInfo& sendInfo : sendInfos) + for (const RemoteSendInfo& sendInfo : sendInfos) { spatial::SpatChannelSendOp::create(cpuRewriter, loc, - cpuRewriter.getI64IntegerAttr(sendInfo.channelId), - cpuRewriter.getI32IntegerAttr(sendInfo.sourceCoreId), - cpuRewriter.getI32IntegerAttr(sendInfo.targetCoreId), + cpuRewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId), + cpuRewriter.getI32IntegerAttr(sendInfo.channelInfo.sourceCoreId), + cpuRewriter.getI32IntegerAttr(sendInfo.channelInfo.targetCoreId), producedValue); + } } } } @@ -1666,23 +1828,21 @@ private: IRRewriter rewriter(context); rewriter.setInsertionPointAfter(producerOp); - auto savedSendInsertPoint = rewriter.saveInsertionPoint(); - auto insertNew = [this, savedSendInsertPoint, context, loc, computeValueResults, producerCpu](size_t resultIndex, - size_t targetCpu) { + auto insertNew = [this, context, loc, computeValueResults, producerCpu](size_t resultIndex, size_t targetCpu) { auto channelId = nextChannelId++; LazyInsertComputeResult::ChannelInfo channelInfo { channelId, getPhysicalCoreId(producerCpu), getPhysicalCoreId(targetCpu)}; - auto insertVal = [&context, loc, computeValueResults, channelInfo, resultIndex, savedSendInsertPoint]( - mlir::IRRewriter::InsertPoint) { - IRRewriter rewriter(context); - rewriter.restoreInsertionPoint(savedSendInsertPoint); - spatial::SpatChannelSendOp::create(rewriter, - loc, - rewriter.getI64IntegerAttr(channelInfo.channelId), - rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), - rewriter.getI32IntegerAttr(channelInfo.targetCoreId), - computeValueResults.getOuter(resultIndex)); - }; + auto insertVal = + [&context, loc, computeValueResults, channelInfo, resultIndex](mlir::IRRewriter::InsertPoint insertPoint) { + IRRewriter rewriter(context); + rewriter.restoreInsertionPoint(insertPoint); + spatial::SpatChannelSendOp::create(rewriter, + loc, + rewriter.getI64IntegerAttr(channelInfo.channelId), + rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), + rewriter.getI32IntegerAttr(channelInfo.targetCoreId), + computeValueResults.getOuter(resultIndex)); + }; std::pair> ret { channelInfo, insertVal}; return ret; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp index 8a92beb..64cdfc7 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp @@ -10,8 +10,6 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include - #include "RegularOpCompaction.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -340,7 +338,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { ++runIt; } - if (run.size() > 1) { + bool hasRepeatedEndpoint = false; + for (size_t lhs = 0; lhs < run.size() && !hasRepeatedEndpoint; ++lhs) { + for (size_t rhs = lhs + 1; rhs < run.size(); ++rhs) { + if (run[lhs].getSourceCoreId() == run[rhs].getSourceCoreId() + && run[lhs].getTargetCoreId() == run[rhs].getTargetCoreId()) { + hasRepeatedEndpoint = true; + break; + } + } + } + + if (run.size() > 1 && !hasRepeatedEndpoint) { struct ReceiveEntry { spatial::SpatChannelReceiveOp op; size_t originalIndex = 0; @@ -352,10 +361,6 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { sortedEntries.reserve(run.size()); for (auto [originalIndex, op] : llvm::enumerate(run)) sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); - llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) { - return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) - < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); - }); SmallVector channelIds; SmallVector sourceCoreIds; @@ -436,10 +441,6 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { sortedEntries.reserve(run.size()); for (auto op : run) sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); - llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) { - return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) - < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); - }); SmallVector channelIds; SmallVector sourceCoreIds;