fix wrong send/receive reordering in post dcp merge instructions compaction

This commit is contained in:
NiccoloN
2026-05-13 21:48:49 +02:00
parent ea61540e08
commit 061139aefb
2 changed files with 192 additions and 31 deletions
@@ -1003,6 +1003,23 @@ public:
DenseMap<Value, Value> externalInputMap;
DenseMap<Value, size_t> weightToIndex;
};
// Bookkeeping for one remote send: the channel it travels on plus the
// consumer-side slot (task and input index) that the sent value feeds.
struct RemoteSendInfo {
// Channel id together with the source and target core ids of the send.
ChannelInfo channelInfo;
// Key of the task that consumes the sent value.
ComputeInstance consumer;
// Index of the consumer input that this send feeds.
size_t inputIndex = 0;
// `order` of the consuming task, captured when the send is recorded.
size_t consumerOrder = 0;
// Position of this send among the sends sharing its (source core, target core)
// pair; recorded as 0 and assigned once all sends have been collected.
size_t sourceOrder = 0;
};
// One element of a per-(source core, target core) receive queue; its fields are
// copied from the RemoteSendInfo of the matching send.
struct RemoteReceiveEntry {
// Channel id together with the source and target core ids of the transfer.
ChannelInfo channelInfo;
// Key of the task that consumes the received value.
ComputeInstance consumer;
// Index of the consumer input that the received value feeds.
size_t inputIndex = 0;
// Send-side position within the pair; the queues are sorted by this field.
size_t sourceOrder = 0;
};
// Packs both endpoints of a channel into a single 64-bit key: the source core
// id occupies the upper 32 bits and the target core id the lower 32 bits.
auto getRemoteSendPairKey = [](const ChannelInfo& channelInfo) {
  // Go through uint32_t first so a negative id cannot sign-extend into the
  // other half of the key.
  const uint64_t sourceBits = static_cast<uint32_t>(channelInfo.sourceCoreId);
  const uint64_t targetBits = static_cast<uint32_t>(channelInfo.targetCoreId);
  return (sourceBits << 32) | targetBits;
};
auto getTaskInputs = [&](const ScheduledTask& task) {
SmallVector<Value> inputs;
@@ -1143,7 +1160,7 @@ public:
}
};
DenseMap<ComputeInstance, SmallVector<SmallVector<ChannelInfo>>> remoteSendsByTask;
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
DenseMap<size_t, SmallVector<Value>> cpuWeights;
@@ -1176,7 +1193,7 @@ public:
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
if (perResultChannels.empty())
perResultChannels.resize(getTaskOutputTypes(producerIt->second).size());
perResultChannels[producerRef->resultIndex].push_back(info);
perResultChannels[producerRef->resultIndex].push_back({info, task.key, inputIndex, task.order, 0});
}
continue;
}
@@ -1201,6 +1218,79 @@ public:
}
}
// Walk every CPU's tasks in schedule order and, for each (source core, target
// core) pair, number the sends in the order they are encountered. A pair is
// flagged for receive reordering as soon as one of its sends targets a consumer
// whose order is lower than that of the previous send on the same pair.
DenseSet<uint64_t> pairsNeedingReceiveReorder;
for (size_t cpu : orderedCpus) {
  // Per-pair progress while scanning the sends produced on this CPU.
  struct PairProgress {
    size_t nextSourceOrder = 0;
    std::optional<size_t> lastConsumerOrder;
  };
  DenseMap<uint64_t, PairProgress> progressByPair;

  for (const ScheduledTask& task : tasksByCpu[cpu]) {
    auto taskSendsIt = remoteSendsByTask.find(task.key);
    if (taskSendsIt == remoteSendsByTask.end())
      continue;

    for (auto& resultSends : taskSendsIt->second) {
      for (RemoteSendInfo& send : resultSends) {
        const uint64_t key = getRemoteSendPairKey(send.channelInfo);
        PairProgress& progress = progressByPair[key];
        send.sourceOrder = progress.nextSourceOrder++;

        // Only a strict decrease relative to the previous send counts as a
        // mismatch; the first send of a pair has nothing to compare against.
        if (progress.lastConsumerOrder && send.consumerOrder < *progress.lastConsumerOrder)
          pairsNeedingReceiveReorder.insert(key);
        progress.lastConsumerOrder = send.consumerOrder;
      }
    }
  }
}
// Collect, per flagged pair, pointers to every send travelling on that pair.
DenseMap<uint64_t, SmallVector<RemoteSendInfo*>> reorderedSendsByPair;
for (auto& taskSends : remoteSendsByTask) {
  for (auto& resultSends : taskSends.second) {
    for (RemoteSendInfo& send : resultSends) {
      const uint64_t key = getRemoteSendPairKey(send.channelInfo);
      if (!pairsNeedingReceiveReorder.contains(key))
        continue;
      reorderedSendsByPair[key].push_back(&send);
    }
  }
}

// Hand out fresh channel ids in send order and mirror each new id into the
// consumer-side channel record so both ends keep referring to the same channel.
for (auto& pairSends : reorderedSendsByPair) {
  auto& sendsOfPair = pairSends.second;
  llvm::stable_sort(sendsOfPair, [](const RemoteSendInfo* lhs, const RemoteSendInfo* rhs) {
    // Primary key: send position; the old channel id only breaks ties.
    if (lhs->sourceOrder == rhs->sourceOrder)
      return lhs->channelInfo.channelId < rhs->channelInfo.channelId;
    return lhs->sourceOrder < rhs->sourceOrder;
  });

  for (RemoteSendInfo* send : sendsOfPair) {
    const int64_t freshChannelId = nextChannelId++;
    send->channelInfo.channelId = freshChannelId;

    auto consumerInputsIt = remoteInputsByTask.find(send->consumer);
    assert(consumerInputsIt != remoteInputsByTask.end() && "missing remote input for reordered send");
    auto& consumerInputs = consumerInputsIt->second;
    assert(send->inputIndex < consumerInputs.size() && "remote input index out of range");
    auto& consumerChannel = consumerInputs[send->inputIndex];
    assert(consumerChannel && "missing reordered remote input channel");
    consumerChannel->channelId = freshChannelId;
  }
}
DenseMap<size_t, DenseMap<uint64_t, SmallVector<RemoteReceiveEntry>>> receiveQueuesByCpu;
for (auto& taskSends : remoteSendsByTask) {
for (const auto& sendInfos : taskSends.second) {
for (const RemoteSendInfo& sendInfo : sendInfos) {
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
if (!pairsNeedingReceiveReorder.contains(pairKey))
continue;
size_t targetCpu = static_cast<size_t>(sendInfo.channelInfo.targetCoreId - 1);
receiveQueuesByCpu[targetCpu][pairKey].push_back(
{sendInfo.channelInfo, sendInfo.consumer, sendInfo.inputIndex, sendInfo.sourceOrder});
}
}
}
for (auto& cpuQueues : receiveQueuesByCpu) {
for (auto& pairQueue : cpuQueues.second) {
llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry& lhs, const RemoteReceiveEntry& rhs) {
if (lhs.sourceOrder != rhs.sourceOrder)
return lhs.sourceOrder < rhs.sourceOrder;
return lhs.channelInfo.channelId < rhs.channelInfo.channelId;
});
}
}
auto returnOp = cast<func::ReturnOp>(func.getBody().front().getTerminator());
IRRewriter rewriter(&getContext());
DenseMap<size_t, CpuProgram> cpuPrograms;
@@ -1255,6 +1345,59 @@ public:
CpuProgram& program = cpuPrograms[cpu];
IRRewriter cpuRewriter(&getContext());
cpuRewriter.setInsertionPointToEnd(program.block);
// Per-CPU state for reordered receives: how far each pair's receive queue has
// been consumed, and the values already received on behalf of each task.
DenseMap<uint64_t, size_t> receiveQueueIndices;
DenseMap<ComputeInstance, SmallVector<Value>> preReceivedInputsByTask;

// Returns the value already received for `consumer`'s input `inputIndex`, or
// std::nullopt when no receive has been materialized for that slot yet.
auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional<Value> {
  auto cachedIt = preReceivedInputsByTask.find(consumer);
  if (cachedIt != preReceivedInputsByTask.end() && inputIndex < cachedIt->second.size()) {
    // Slots created by a resize but never filled hold a null Value.
    if (Value cached = cachedIt->second[inputIndex])
      return cached;
  }
  return std::nullopt;
};
// Emits receives from the requested channel's pair queue, in queue order, until
// the receive feeding (`requestedConsumer`, `requestedInputIndex`) has been
// created. Every receive emitted along the way is cached in
// preReceivedInputsByTask for its own consumer. Returns the requested value, or
// std::nullopt when the queue, a consumer task, or the consumer input cannot be
// found, or when the queue is exhausted first.
auto receiveThroughInput = [&](const ChannelInfo& requestedChannelInfo,
                               ComputeInstance requestedConsumer,
                               size_t requestedInputIndex) -> std::optional<Value> {
  const uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo);

  auto cpuQueuesIt = receiveQueuesByCpu.find(cpu);
  if (cpuQueuesIt == receiveQueuesByCpu.end())
    return std::nullopt;
  auto pairQueueIt = cpuQueuesIt->second.find(pairKey);
  if (pairQueueIt == cpuQueuesIt->second.end())
    return std::nullopt;

  auto& pending = pairQueueIt->second;
  // The cursor persists across calls so each queue entry is received once.
  size_t& cursor = receiveQueueIndices[pairKey];
  while (cursor < pending.size()) {
    const RemoteReceiveEntry& entry = pending[cursor];
    ++cursor;

    // The receive's result type is the type of the consumer input it feeds.
    auto consumerTaskIt = taskByKey.find(entry.consumer);
    if (consumerTaskIt == taskByKey.end())
      return std::nullopt;
    SmallVector<Value> consumerInputs = getTaskInputs(consumerTaskIt->second);
    if (entry.inputIndex >= consumerInputs.size())
      return std::nullopt;
    Type receivedType = consumerInputs[entry.inputIndex].getType();

    auto receive =
        spatial::SpatChannelReceiveOp::create(cpuRewriter,
                                              loc,
                                              receivedType,
                                              cpuRewriter.getI64IntegerAttr(entry.channelInfo.channelId),
                                              cpuRewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId),
                                              cpuRewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId));
    Value received = receive.getResult();

    auto& cachedSlots = preReceivedInputsByTask[entry.consumer];
    if (cachedSlots.size() <= entry.inputIndex)
      cachedSlots.resize(entry.inputIndex + 1);
    cachedSlots[entry.inputIndex] = received;

    const bool isRequested = entry.consumer == requestedConsumer && entry.inputIndex == requestedInputIndex;
    if (isRequested)
      return received;
  }
  return std::nullopt;
};
for (const ScheduledTask& task : tasksByCpu[cpu]) {
SmallVector<Value> taskInputs = getTaskInputs(task);
@@ -1284,6 +1427,24 @@ public:
continue;
}
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
if (pairsNeedingReceiveReorder.contains(pairKey)) {
  // Inputs on a reordered pair are served from the cache when an earlier
  // request already emitted their receive; otherwise drain the pair queue up
  // to this input.
  std::optional<Value> resolved = lookupPreReceivedInput(task.key, inputIndex);
  if (!resolved)
    resolved = receiveThroughInput(channelInfo, task.key, inputIndex);

  if (!resolved) {
    task.sourceOp->emitOpError("failed to materialize reordered remote receive")
        << " consumerCpu=" << cpu << " consumerSlot=" << task.slot
        << " sourceCoreId=" << channelInfo.sourceCoreId << " targetCoreId=" << channelInfo.targetCoreId
        << " channelId=" << channelInfo.channelId;
    signalPassFailure();
    return;
  }

  resolvedInputs.push_back(*resolved);
  continue;
}
auto receive =
spatial::SpatChannelReceiveOp::create(cpuRewriter,
loc,
@@ -1367,13 +1528,14 @@ public:
if (sendInfos.empty())
continue;
Value producedValue = taskYieldValues[resultIndex];
for (const ChannelInfo& sendInfo : sendInfos)
for (const RemoteSendInfo& sendInfo : sendInfos) {
spatial::SpatChannelSendOp::create(cpuRewriter,
loc,
cpuRewriter.getI64IntegerAttr(sendInfo.channelId),
cpuRewriter.getI32IntegerAttr(sendInfo.sourceCoreId),
cpuRewriter.getI32IntegerAttr(sendInfo.targetCoreId),
cpuRewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId),
cpuRewriter.getI32IntegerAttr(sendInfo.channelInfo.sourceCoreId),
cpuRewriter.getI32IntegerAttr(sendInfo.channelInfo.targetCoreId),
producedValue);
}
}
}
}
@@ -1666,23 +1828,21 @@ private:
IRRewriter rewriter(context);
rewriter.setInsertionPointAfter(producerOp);
auto savedSendInsertPoint = rewriter.saveInsertionPoint();
auto insertNew = [this, savedSendInsertPoint, context, loc, computeValueResults, producerCpu](size_t resultIndex,
size_t targetCpu) {
auto insertNew = [this, context, loc, computeValueResults, producerCpu](size_t resultIndex, size_t targetCpu) {
auto channelId = nextChannelId++;
LazyInsertComputeResult::ChannelInfo channelInfo {
channelId, getPhysicalCoreId(producerCpu), getPhysicalCoreId(targetCpu)};
auto insertVal = [&context, loc, computeValueResults, channelInfo, resultIndex, savedSendInsertPoint](
mlir::IRRewriter::InsertPoint) {
IRRewriter rewriter(context);
rewriter.restoreInsertionPoint(savedSendInsertPoint);
spatial::SpatChannelSendOp::create(rewriter,
loc,
rewriter.getI64IntegerAttr(channelInfo.channelId),
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
rewriter.getI32IntegerAttr(channelInfo.targetCoreId),
computeValueResults.getOuter(resultIndex));
};
auto insertVal =
[&context, loc, computeValueResults, channelInfo, resultIndex](mlir::IRRewriter::InsertPoint insertPoint) {
IRRewriter rewriter(context);
rewriter.restoreInsertionPoint(insertPoint);
spatial::SpatChannelSendOp::create(rewriter,
loc,
rewriter.getI64IntegerAttr(channelInfo.channelId),
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
rewriter.getI32IntegerAttr(channelInfo.targetCoreId),
computeValueResults.getOuter(resultIndex));
};
std::pair<LazyInsertComputeResult::ChannelInfo, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {
channelInfo, insertVal};
return ret;
@@ -10,8 +10,6 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <tuple>
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -340,7 +338,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
++runIt;
}
if (run.size() > 1) {
bool hasRepeatedEndpoint = false;
for (size_t lhs = 0; lhs < run.size() && !hasRepeatedEndpoint; ++lhs) {
for (size_t rhs = lhs + 1; rhs < run.size(); ++rhs) {
if (run[lhs].getSourceCoreId() == run[rhs].getSourceCoreId()
&& run[lhs].getTargetCoreId() == run[rhs].getTargetCoreId()) {
hasRepeatedEndpoint = true;
break;
}
}
}
if (run.size() > 1 && !hasRepeatedEndpoint) {
struct ReceiveEntry {
spatial::SpatChannelReceiveOp op;
size_t originalIndex = 0;
@@ -352,10 +361,6 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
sortedEntries.reserve(run.size());
for (auto [originalIndex, op] : llvm::enumerate(run))
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
@@ -436,10 +441,6 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
sortedEntries.reserve(run.size());
for (auto op : run)
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;