fix wrong send/receive reordering in post dcp merge instructions compaction
This commit is contained in:
@@ -1003,6 +1003,23 @@ public:
|
||||
DenseMap<Value, Value> externalInputMap;
|
||||
DenseMap<Value, size_t> weightToIndex;
|
||||
};
|
||||
struct RemoteSendInfo {
|
||||
ChannelInfo channelInfo;
|
||||
ComputeInstance consumer;
|
||||
size_t inputIndex = 0;
|
||||
size_t consumerOrder = 0;
|
||||
size_t sourceOrder = 0;
|
||||
};
|
||||
struct RemoteReceiveEntry {
|
||||
ChannelInfo channelInfo;
|
||||
ComputeInstance consumer;
|
||||
size_t inputIndex = 0;
|
||||
size_t sourceOrder = 0;
|
||||
};
|
||||
// Packs a channel's endpoints into a single 64-bit key: the source core id
// fills the upper 32 bits and the target core id the lower 32 bits. Both ids
// go through uint32_t first so a negative id cannot sign-extend into the
// other half of the key.
auto getRemoteSendPairKey = [](const ChannelInfo& channelInfo) {
  const uint64_t sourceBits = static_cast<uint32_t>(channelInfo.sourceCoreId);
  const uint64_t targetBits = static_cast<uint32_t>(channelInfo.targetCoreId);
  return (sourceBits << 32) | targetBits;
};
|
||||
|
||||
auto getTaskInputs = [&](const ScheduledTask& task) {
|
||||
SmallVector<Value> inputs;
|
||||
@@ -1143,7 +1160,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
DenseMap<ComputeInstance, SmallVector<SmallVector<ChannelInfo>>> remoteSendsByTask;
|
||||
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
||||
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
||||
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
||||
DenseMap<size_t, SmallVector<Value>> cpuWeights;
|
||||
@@ -1176,7 +1193,7 @@ public:
|
||||
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||
if (perResultChannels.empty())
|
||||
perResultChannels.resize(getTaskOutputTypes(producerIt->second).size());
|
||||
perResultChannels[producerRef->resultIndex].push_back(info);
|
||||
perResultChannels[producerRef->resultIndex].push_back({info, task.key, inputIndex, task.order, 0});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -1201,6 +1218,79 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
DenseSet<uint64_t> pairsNeedingReceiveReorder;
|
||||
for (size_t cpu : orderedCpus) {
|
||||
DenseMap<uint64_t, size_t> nextSourceOrderByPair;
|
||||
DenseMap<uint64_t, size_t> lastConsumerOrderByPair;
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
auto sendsIt = remoteSendsByTask.find(task.key);
|
||||
if (sendsIt == remoteSendsByTask.end())
|
||||
continue;
|
||||
for (auto& sendInfos : sendsIt->second) {
|
||||
for (RemoteSendInfo& sendInfo : sendInfos) {
|
||||
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||
sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++;
|
||||
auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder);
|
||||
if (!inserted) {
|
||||
if (sendInfo.consumerOrder < it->second)
|
||||
pairsNeedingReceiveReorder.insert(pairKey);
|
||||
it->second = sendInfo.consumerOrder;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DenseMap<uint64_t, SmallVector<RemoteSendInfo*>> reorderedSendsByPair;
|
||||
for (auto& taskSends : remoteSendsByTask) {
|
||||
for (auto& sendInfos : taskSends.second) {
|
||||
for (RemoteSendInfo& sendInfo : sendInfos) {
|
||||
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||
if (pairsNeedingReceiveReorder.contains(pairKey))
|
||||
reorderedSendsByPair[pairKey].push_back(&sendInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto& pairSends : reorderedSendsByPair) {
|
||||
llvm::stable_sort(pairSends.second, [](const RemoteSendInfo* lhs, const RemoteSendInfo* rhs) {
|
||||
if (lhs->sourceOrder != rhs->sourceOrder)
|
||||
return lhs->sourceOrder < rhs->sourceOrder;
|
||||
return lhs->channelInfo.channelId < rhs->channelInfo.channelId;
|
||||
});
|
||||
for (RemoteSendInfo* sendInfo : pairSends.second) {
|
||||
int64_t channelId = nextChannelId++;
|
||||
sendInfo->channelInfo.channelId = channelId;
|
||||
auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer);
|
||||
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for reordered send");
|
||||
assert(sendInfo->inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
||||
assert(remoteInputsIt->second[sendInfo->inputIndex] && "missing reordered remote input channel");
|
||||
remoteInputsIt->second[sendInfo->inputIndex]->channelId = channelId;
|
||||
}
|
||||
}
|
||||
|
||||
DenseMap<size_t, DenseMap<uint64_t, SmallVector<RemoteReceiveEntry>>> receiveQueuesByCpu;
|
||||
for (auto& taskSends : remoteSendsByTask) {
|
||||
for (const auto& sendInfos : taskSends.second) {
|
||||
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||
if (!pairsNeedingReceiveReorder.contains(pairKey))
|
||||
continue;
|
||||
size_t targetCpu = static_cast<size_t>(sendInfo.channelInfo.targetCoreId - 1);
|
||||
receiveQueuesByCpu[targetCpu][pairKey].push_back(
|
||||
{sendInfo.channelInfo, sendInfo.consumer, sendInfo.inputIndex, sendInfo.sourceOrder});
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto& cpuQueues : receiveQueuesByCpu) {
|
||||
for (auto& pairQueue : cpuQueues.second) {
|
||||
llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry& lhs, const RemoteReceiveEntry& rhs) {
|
||||
if (lhs.sourceOrder != rhs.sourceOrder)
|
||||
return lhs.sourceOrder < rhs.sourceOrder;
|
||||
return lhs.channelInfo.channelId < rhs.channelInfo.channelId;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
auto returnOp = cast<func::ReturnOp>(func.getBody().front().getTerminator());
|
||||
IRRewriter rewriter(&getContext());
|
||||
DenseMap<size_t, CpuProgram> cpuPrograms;
|
||||
@@ -1255,6 +1345,59 @@ public:
|
||||
CpuProgram& program = cpuPrograms[cpu];
|
||||
IRRewriter cpuRewriter(&getContext());
|
||||
cpuRewriter.setInsertionPointToEnd(program.block);
|
||||
DenseMap<uint64_t, size_t> receiveQueueIndices;
|
||||
DenseMap<ComputeInstance, SmallVector<Value>> preReceivedInputsByTask;
|
||||
|
||||
// Returns the value already received for input `inputIndex` of `consumer`,
// or std::nullopt when nothing has been recorded for that task, the index is
// beyond the recorded slots, or the slot holds a null Value.
auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional<Value> {
  auto recorded = preReceivedInputsByTask.find(consumer);
  if (recorded != preReceivedInputsByTask.end() && inputIndex < recorded->second.size()) {
    if (Value cached = recorded->second[inputIndex]) {
      return cached;
    }
  }
  return std::nullopt;
};
|
||||
|
||||
// Emits receives for the requested channel's core pair, in queue order, until
// the one feeding (requestedConsumer, requestedInputIndex) has been created.
//
// Every receive created along the way is stored in preReceivedInputsByTask so
// the task it belongs to can pick it up later. The per-pair cursor in
// receiveQueueIndices persists across calls, so each queue entry is emitted at
// most once.
//
// Returns the requested receive's result, or std::nullopt when the current
// cpu has no queue for the pair, a queued consumer is unknown to taskByKey, a
// queued input index is out of range, or the queue runs out first.
auto receiveThroughInput = [&](const ChannelInfo& requestedChannelInfo,
                               ComputeInstance requestedConsumer,
                               size_t requestedInputIndex) -> std::optional<Value> {
  const uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo);

  auto cpuQueuesIt = receiveQueuesByCpu.find(cpu);
  if (cpuQueuesIt == receiveQueuesByCpu.end())
    return std::nullopt;

  auto queueIt = cpuQueuesIt->second.find(pairKey);
  if (queueIt == cpuQueuesIt->second.end())
    return std::nullopt;

  auto& pendingEntries = queueIt->second;
  size_t& cursor = receiveQueueIndices[pairKey];
  for (; cursor < pendingEntries.size();) {
    // The cursor moves past the entry before any validation, so a failing
    // entry is not retried by a later call.
    const RemoteReceiveEntry& entry = pendingEntries[cursor];
    ++cursor;

    auto consumerTaskIt = taskByKey.find(entry.consumer);
    if (consumerTaskIt == taskByKey.end())
      return std::nullopt;

    // The receive's result type is taken from the consumer input it feeds.
    SmallVector<Value> consumerInputs = getTaskInputs(consumerTaskIt->second);
    if (entry.inputIndex >= consumerInputs.size())
      return std::nullopt;
    Type inputType = consumerInputs[entry.inputIndex].getType();

    auto channelIdAttr = cpuRewriter.getI64IntegerAttr(entry.channelInfo.channelId);
    auto sourceCoreAttr = cpuRewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId);
    auto targetCoreAttr = cpuRewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId);
    auto receive = spatial::SpatChannelReceiveOp::create(
        cpuRewriter, loc, inputType, channelIdAttr, sourceCoreAttr, targetCoreAttr);

    auto& receivedInputs = preReceivedInputsByTask[entry.consumer];
    if (receivedInputs.size() <= entry.inputIndex)
      receivedInputs.resize(entry.inputIndex + 1);
    receivedInputs[entry.inputIndex] = receive.getResult();

    const bool isRequested = entry.consumer == requestedConsumer && entry.inputIndex == requestedInputIndex;
    if (isRequested)
      return receive.getResult();
  }
  return std::nullopt;
};
|
||||
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
SmallVector<Value> taskInputs = getTaskInputs(task);
|
||||
@@ -1284,6 +1427,24 @@ public:
|
||||
continue;
|
||||
}
|
||||
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
||||
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
||||
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
||||
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.key, inputIndex)) {
|
||||
resolvedInputs.push_back(*preReceived);
|
||||
continue;
|
||||
}
|
||||
std::optional<Value> received = receiveThroughInput(channelInfo, task.key, inputIndex);
|
||||
if (!received) {
|
||||
task.sourceOp->emitOpError("failed to materialize reordered remote receive")
|
||||
<< " consumerCpu=" << cpu << " consumerSlot=" << task.slot
|
||||
<< " sourceCoreId=" << channelInfo.sourceCoreId << " targetCoreId=" << channelInfo.targetCoreId
|
||||
<< " channelId=" << channelInfo.channelId;
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
resolvedInputs.push_back(*received);
|
||||
continue;
|
||||
}
|
||||
auto receive =
|
||||
spatial::SpatChannelReceiveOp::create(cpuRewriter,
|
||||
loc,
|
||||
@@ -1367,13 +1528,14 @@ public:
|
||||
if (sendInfos.empty())
|
||||
continue;
|
||||
Value producedValue = taskYieldValues[resultIndex];
|
||||
for (const ChannelInfo& sendInfo : sendInfos)
|
||||
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||
spatial::SpatChannelSendOp::create(cpuRewriter,
|
||||
loc,
|
||||
cpuRewriter.getI64IntegerAttr(sendInfo.channelId),
|
||||
cpuRewriter.getI32IntegerAttr(sendInfo.sourceCoreId),
|
||||
cpuRewriter.getI32IntegerAttr(sendInfo.targetCoreId),
|
||||
cpuRewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId),
|
||||
cpuRewriter.getI32IntegerAttr(sendInfo.channelInfo.sourceCoreId),
|
||||
cpuRewriter.getI32IntegerAttr(sendInfo.channelInfo.targetCoreId),
|
||||
producedValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1666,23 +1828,21 @@ private:
|
||||
IRRewriter rewriter(context);
|
||||
|
||||
rewriter.setInsertionPointAfter(producerOp);
|
||||
auto savedSendInsertPoint = rewriter.saveInsertionPoint();
|
||||
auto insertNew = [this, savedSendInsertPoint, context, loc, computeValueResults, producerCpu](size_t resultIndex,
|
||||
size_t targetCpu) {
|
||||
auto insertNew = [this, context, loc, computeValueResults, producerCpu](size_t resultIndex, size_t targetCpu) {
|
||||
auto channelId = nextChannelId++;
|
||||
LazyInsertComputeResult::ChannelInfo channelInfo {
|
||||
channelId, getPhysicalCoreId(producerCpu), getPhysicalCoreId(targetCpu)};
|
||||
auto insertVal = [&context, loc, computeValueResults, channelInfo, resultIndex, savedSendInsertPoint](
|
||||
mlir::IRRewriter::InsertPoint) {
|
||||
IRRewriter rewriter(context);
|
||||
rewriter.restoreInsertionPoint(savedSendInsertPoint);
|
||||
spatial::SpatChannelSendOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.targetCoreId),
|
||||
computeValueResults.getOuter(resultIndex));
|
||||
};
|
||||
auto insertVal =
|
||||
[&context, loc, computeValueResults, channelInfo, resultIndex](mlir::IRRewriter::InsertPoint insertPoint) {
|
||||
IRRewriter rewriter(context);
|
||||
rewriter.restoreInsertionPoint(insertPoint);
|
||||
spatial::SpatChannelSendOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.targetCoreId),
|
||||
computeValueResults.getOuter(resultIndex));
|
||||
};
|
||||
std::pair<LazyInsertComputeResult::ChannelInfo, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {
|
||||
channelInfo, insertVal};
|
||||
return ret;
|
||||
|
||||
@@ -10,8 +10,6 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "RegularOpCompaction.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -340,7 +338,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
++runIt;
|
||||
}
|
||||
|
||||
if (run.size() > 1) {
|
||||
bool hasRepeatedEndpoint = false;
|
||||
for (size_t lhs = 0; lhs < run.size() && !hasRepeatedEndpoint; ++lhs) {
|
||||
for (size_t rhs = lhs + 1; rhs < run.size(); ++rhs) {
|
||||
if (run[lhs].getSourceCoreId() == run[rhs].getSourceCoreId()
|
||||
&& run[lhs].getTargetCoreId() == run[rhs].getTargetCoreId()) {
|
||||
hasRepeatedEndpoint = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (run.size() > 1 && !hasRepeatedEndpoint) {
|
||||
struct ReceiveEntry {
|
||||
spatial::SpatChannelReceiveOp op;
|
||||
size_t originalIndex = 0;
|
||||
@@ -352,10 +361,6 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
sortedEntries.reserve(run.size());
|
||||
for (auto [originalIndex, op] : llvm::enumerate(run))
|
||||
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
@@ -436,10 +441,6 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
sortedEntries.reserve(run.size());
|
||||
for (auto op : run)
|
||||
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
|
||||
Reference in New Issue
Block a user