Fix wrong send/receive reordering in post-DCP-merge instruction compaction
This commit is contained in:
@@ -1003,6 +1003,23 @@ public:
|
|||||||
DenseMap<Value, Value> externalInputMap;
|
DenseMap<Value, Value> externalInputMap;
|
||||||
DenseMap<Value, size_t> weightToIndex;
|
DenseMap<Value, size_t> weightToIndex;
|
||||||
};
|
};
|
||||||
|
struct RemoteSendInfo {
|
||||||
|
ChannelInfo channelInfo;
|
||||||
|
ComputeInstance consumer;
|
||||||
|
size_t inputIndex = 0;
|
||||||
|
size_t consumerOrder = 0;
|
||||||
|
size_t sourceOrder = 0;
|
||||||
|
};
|
||||||
|
struct RemoteReceiveEntry {
|
||||||
|
ChannelInfo channelInfo;
|
||||||
|
ComputeInstance consumer;
|
||||||
|
size_t inputIndex = 0;
|
||||||
|
size_t sourceOrder = 0;
|
||||||
|
};
|
||||||
|
// Packs the two endpoints of a channel into a single 64-bit key: the source
// core id occupies the upper 32 bits and the target core id the lower 32 bits.
// Both ids go through uint32_t first so that no sign extension leaks into the
// other half of the key.
auto getRemoteSendPairKey = [](const ChannelInfo& channelInfo) {
  const uint64_t sourceBits = static_cast<uint32_t>(channelInfo.sourceCoreId);
  const uint64_t targetBits = static_cast<uint32_t>(channelInfo.targetCoreId);
  return (sourceBits << 32) | targetBits;
};
|
||||||
|
|
||||||
auto getTaskInputs = [&](const ScheduledTask& task) {
|
auto getTaskInputs = [&](const ScheduledTask& task) {
|
||||||
SmallVector<Value> inputs;
|
SmallVector<Value> inputs;
|
||||||
@@ -1143,7 +1160,7 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
DenseMap<ComputeInstance, SmallVector<SmallVector<ChannelInfo>>> remoteSendsByTask;
|
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
||||||
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
||||||
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
||||||
DenseMap<size_t, SmallVector<Value>> cpuWeights;
|
DenseMap<size_t, SmallVector<Value>> cpuWeights;
|
||||||
@@ -1176,7 +1193,7 @@ public:
|
|||||||
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||||
if (perResultChannels.empty())
|
if (perResultChannels.empty())
|
||||||
perResultChannels.resize(getTaskOutputTypes(producerIt->second).size());
|
perResultChannels.resize(getTaskOutputTypes(producerIt->second).size());
|
||||||
perResultChannels[producerRef->resultIndex].push_back(info);
|
perResultChannels[producerRef->resultIndex].push_back({info, task.key, inputIndex, task.order, 0});
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -1201,6 +1218,79 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DenseSet<uint64_t> pairsNeedingReceiveReorder;
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
DenseMap<uint64_t, size_t> nextSourceOrderByPair;
|
||||||
|
DenseMap<uint64_t, size_t> lastConsumerOrderByPair;
|
||||||
|
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||||
|
auto sendsIt = remoteSendsByTask.find(task.key);
|
||||||
|
if (sendsIt == remoteSendsByTask.end())
|
||||||
|
continue;
|
||||||
|
for (auto& sendInfos : sendsIt->second) {
|
||||||
|
for (RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||||
|
sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++;
|
||||||
|
auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder);
|
||||||
|
if (!inserted) {
|
||||||
|
if (sendInfo.consumerOrder < it->second)
|
||||||
|
pairsNeedingReceiveReorder.insert(pairKey);
|
||||||
|
it->second = sendInfo.consumerOrder;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseMap<uint64_t, SmallVector<RemoteSendInfo*>> reorderedSendsByPair;
|
||||||
|
for (auto& taskSends : remoteSendsByTask) {
|
||||||
|
for (auto& sendInfos : taskSends.second) {
|
||||||
|
for (RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||||
|
if (pairsNeedingReceiveReorder.contains(pairKey))
|
||||||
|
reorderedSendsByPair[pairKey].push_back(&sendInfo);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto& pairSends : reorderedSendsByPair) {
|
||||||
|
llvm::stable_sort(pairSends.second, [](const RemoteSendInfo* lhs, const RemoteSendInfo* rhs) {
|
||||||
|
if (lhs->sourceOrder != rhs->sourceOrder)
|
||||||
|
return lhs->sourceOrder < rhs->sourceOrder;
|
||||||
|
return lhs->channelInfo.channelId < rhs->channelInfo.channelId;
|
||||||
|
});
|
||||||
|
for (RemoteSendInfo* sendInfo : pairSends.second) {
|
||||||
|
int64_t channelId = nextChannelId++;
|
||||||
|
sendInfo->channelInfo.channelId = channelId;
|
||||||
|
auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer);
|
||||||
|
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for reordered send");
|
||||||
|
assert(sendInfo->inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
||||||
|
assert(remoteInputsIt->second[sendInfo->inputIndex] && "missing reordered remote input channel");
|
||||||
|
remoteInputsIt->second[sendInfo->inputIndex]->channelId = channelId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseMap<size_t, DenseMap<uint64_t, SmallVector<RemoteReceiveEntry>>> receiveQueuesByCpu;
|
||||||
|
for (auto& taskSends : remoteSendsByTask) {
|
||||||
|
for (const auto& sendInfos : taskSends.second) {
|
||||||
|
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||||
|
if (!pairsNeedingReceiveReorder.contains(pairKey))
|
||||||
|
continue;
|
||||||
|
size_t targetCpu = static_cast<size_t>(sendInfo.channelInfo.targetCoreId - 1);
|
||||||
|
receiveQueuesByCpu[targetCpu][pairKey].push_back(
|
||||||
|
{sendInfo.channelInfo, sendInfo.consumer, sendInfo.inputIndex, sendInfo.sourceOrder});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto& cpuQueues : receiveQueuesByCpu) {
|
||||||
|
for (auto& pairQueue : cpuQueues.second) {
|
||||||
|
llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry& lhs, const RemoteReceiveEntry& rhs) {
|
||||||
|
if (lhs.sourceOrder != rhs.sourceOrder)
|
||||||
|
return lhs.sourceOrder < rhs.sourceOrder;
|
||||||
|
return lhs.channelInfo.channelId < rhs.channelInfo.channelId;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto returnOp = cast<func::ReturnOp>(func.getBody().front().getTerminator());
|
auto returnOp = cast<func::ReturnOp>(func.getBody().front().getTerminator());
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
DenseMap<size_t, CpuProgram> cpuPrograms;
|
DenseMap<size_t, CpuProgram> cpuPrograms;
|
||||||
@@ -1255,6 +1345,59 @@ public:
|
|||||||
CpuProgram& program = cpuPrograms[cpu];
|
CpuProgram& program = cpuPrograms[cpu];
|
||||||
IRRewriter cpuRewriter(&getContext());
|
IRRewriter cpuRewriter(&getContext());
|
||||||
cpuRewriter.setInsertionPointToEnd(program.block);
|
cpuRewriter.setInsertionPointToEnd(program.block);
|
||||||
|
DenseMap<uint64_t, size_t> receiveQueueIndices;
|
||||||
|
DenseMap<ComputeInstance, SmallVector<Value>> preReceivedInputsByTask;
|
||||||
|
|
||||||
|
auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional<Value> {
|
||||||
|
auto inputsIt = preReceivedInputsByTask.find(consumer);
|
||||||
|
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
|
||||||
|
return std::nullopt;
|
||||||
|
Value value = inputsIt->second[inputIndex];
|
||||||
|
if (!value)
|
||||||
|
return std::nullopt;
|
||||||
|
return value;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Emits the queued receives of the requested channel's (source core, target
// core) pair, in queue order, up to and including the one that feeds
// (requestedConsumer, requestedInputIndex). Every receive created on the way is
// stored in preReceivedInputsByTask under its own consumer and input index.
// Returns std::nullopt when this CPU has no queue for the pair, when a queued
// entry's consumer task or input slot cannot be resolved, or when the queue is
// exhausted before the requested input is reached.
auto receiveThroughInput = [&](const ChannelInfo& requestedChannelInfo,
                               ComputeInstance requestedConsumer,
                               size_t requestedInputIndex) -> std::optional<Value> {
  const uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo);

  auto cpuQueuesIt = receiveQueuesByCpu.find(cpu);
  if (cpuQueuesIt == receiveQueuesByCpu.end())
    return std::nullopt;
  auto& queuesByPair = cpuQueuesIt->second;

  auto pendingIt = queuesByPair.find(pairKey);
  if (pendingIt == queuesByPair.end())
    return std::nullopt;
  auto& pending = pendingIt->second;

  // The cursor is stored in receiveQueueIndices, so a later call for the same
  // pair resumes where this one stopped.
  size_t& cursor = receiveQueueIndices[pairKey];
  while (cursor < pending.size()) {
    const RemoteReceiveEntry& entry = pending[cursor];
    ++cursor; // The entry counts as consumed even if resolving it fails below.

    auto consumerTaskIt = taskByKey.find(entry.consumer);
    if (consumerTaskIt == taskByKey.end())
      return std::nullopt;

    SmallVector<Value> consumerInputs = getTaskInputs(consumerTaskIt->second);
    if (entry.inputIndex >= consumerInputs.size())
      return std::nullopt;

    // The receive yields the same type as the input it stands in for.
    Type receivedType = consumerInputs[entry.inputIndex].getType();
    auto receiveOp =
        spatial::SpatChannelReceiveOp::create(cpuRewriter,
                                              loc,
                                              receivedType,
                                              cpuRewriter.getI64IntegerAttr(entry.channelInfo.channelId),
                                              cpuRewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId),
                                              cpuRewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId));

    auto& consumerSlots = preReceivedInputsByTask[entry.consumer];
    if (consumerSlots.size() <= entry.inputIndex)
      consumerSlots.resize(entry.inputIndex + 1);
    consumerSlots[entry.inputIndex] = receiveOp.getResult();

    const bool isRequested = entry.consumer == requestedConsumer && entry.inputIndex == requestedInputIndex;
    if (isRequested)
      return receiveOp.getResult();
  }
  return std::nullopt;
};
|
||||||
|
|
||||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||||
SmallVector<Value> taskInputs = getTaskInputs(task);
|
SmallVector<Value> taskInputs = getTaskInputs(task);
|
||||||
@@ -1284,6 +1427,24 @@ public:
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
||||||
|
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
||||||
|
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.key, inputIndex)) {
|
||||||
|
resolvedInputs.push_back(*preReceived);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::optional<Value> received = receiveThroughInput(channelInfo, task.key, inputIndex);
|
||||||
|
if (!received) {
|
||||||
|
task.sourceOp->emitOpError("failed to materialize reordered remote receive")
|
||||||
|
<< " consumerCpu=" << cpu << " consumerSlot=" << task.slot
|
||||||
|
<< " sourceCoreId=" << channelInfo.sourceCoreId << " targetCoreId=" << channelInfo.targetCoreId
|
||||||
|
<< " channelId=" << channelInfo.channelId;
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
resolvedInputs.push_back(*received);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
auto receive =
|
auto receive =
|
||||||
spatial::SpatChannelReceiveOp::create(cpuRewriter,
|
spatial::SpatChannelReceiveOp::create(cpuRewriter,
|
||||||
loc,
|
loc,
|
||||||
@@ -1367,13 +1528,14 @@ public:
|
|||||||
if (sendInfos.empty())
|
if (sendInfos.empty())
|
||||||
continue;
|
continue;
|
||||||
Value producedValue = taskYieldValues[resultIndex];
|
Value producedValue = taskYieldValues[resultIndex];
|
||||||
for (const ChannelInfo& sendInfo : sendInfos)
|
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
spatial::SpatChannelSendOp::create(cpuRewriter,
|
spatial::SpatChannelSendOp::create(cpuRewriter,
|
||||||
loc,
|
loc,
|
||||||
cpuRewriter.getI64IntegerAttr(sendInfo.channelId),
|
cpuRewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId),
|
||||||
cpuRewriter.getI32IntegerAttr(sendInfo.sourceCoreId),
|
cpuRewriter.getI32IntegerAttr(sendInfo.channelInfo.sourceCoreId),
|
||||||
cpuRewriter.getI32IntegerAttr(sendInfo.targetCoreId),
|
cpuRewriter.getI32IntegerAttr(sendInfo.channelInfo.targetCoreId),
|
||||||
producedValue);
|
producedValue);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1666,23 +1828,21 @@ private:
|
|||||||
IRRewriter rewriter(context);
|
IRRewriter rewriter(context);
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(producerOp);
|
rewriter.setInsertionPointAfter(producerOp);
|
||||||
auto savedSendInsertPoint = rewriter.saveInsertionPoint();
|
auto insertNew = [this, context, loc, computeValueResults, producerCpu](size_t resultIndex, size_t targetCpu) {
|
||||||
auto insertNew = [this, savedSendInsertPoint, context, loc, computeValueResults, producerCpu](size_t resultIndex,
|
|
||||||
size_t targetCpu) {
|
|
||||||
auto channelId = nextChannelId++;
|
auto channelId = nextChannelId++;
|
||||||
LazyInsertComputeResult::ChannelInfo channelInfo {
|
LazyInsertComputeResult::ChannelInfo channelInfo {
|
||||||
channelId, getPhysicalCoreId(producerCpu), getPhysicalCoreId(targetCpu)};
|
channelId, getPhysicalCoreId(producerCpu), getPhysicalCoreId(targetCpu)};
|
||||||
auto insertVal = [&context, loc, computeValueResults, channelInfo, resultIndex, savedSendInsertPoint](
|
auto insertVal =
|
||||||
mlir::IRRewriter::InsertPoint) {
|
[&context, loc, computeValueResults, channelInfo, resultIndex](mlir::IRRewriter::InsertPoint insertPoint) {
|
||||||
IRRewriter rewriter(context);
|
IRRewriter rewriter(context);
|
||||||
rewriter.restoreInsertionPoint(savedSendInsertPoint);
|
rewriter.restoreInsertionPoint(insertPoint);
|
||||||
spatial::SpatChannelSendOp::create(rewriter,
|
spatial::SpatChannelSendOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
rewriter.getI64IntegerAttr(channelInfo.channelId),
|
rewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||||
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||||
rewriter.getI32IntegerAttr(channelInfo.targetCoreId),
|
rewriter.getI32IntegerAttr(channelInfo.targetCoreId),
|
||||||
computeValueResults.getOuter(resultIndex));
|
computeValueResults.getOuter(resultIndex));
|
||||||
};
|
};
|
||||||
std::pair<LazyInsertComputeResult::ChannelInfo, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {
|
std::pair<LazyInsertComputeResult::ChannelInfo, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {
|
||||||
channelInfo, insertVal};
|
channelInfo, insertVal};
|
||||||
return ret;
|
return ret;
|
||||||
|
|||||||
@@ -10,8 +10,6 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <tuple>
|
|
||||||
|
|
||||||
#include "RegularOpCompaction.hpp"
|
#include "RegularOpCompaction.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
@@ -340,7 +338,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
++runIt;
|
++runIt;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (run.size() > 1) {
|
bool hasRepeatedEndpoint = false;
|
||||||
|
for (size_t lhs = 0; lhs < run.size() && !hasRepeatedEndpoint; ++lhs) {
|
||||||
|
for (size_t rhs = lhs + 1; rhs < run.size(); ++rhs) {
|
||||||
|
if (run[lhs].getSourceCoreId() == run[rhs].getSourceCoreId()
|
||||||
|
&& run[lhs].getTargetCoreId() == run[rhs].getTargetCoreId()) {
|
||||||
|
hasRepeatedEndpoint = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (run.size() > 1 && !hasRepeatedEndpoint) {
|
||||||
struct ReceiveEntry {
|
struct ReceiveEntry {
|
||||||
spatial::SpatChannelReceiveOp op;
|
spatial::SpatChannelReceiveOp op;
|
||||||
size_t originalIndex = 0;
|
size_t originalIndex = 0;
|
||||||
@@ -352,10 +361,6 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
sortedEntries.reserve(run.size());
|
sortedEntries.reserve(run.size());
|
||||||
for (auto [originalIndex, op] : llvm::enumerate(run))
|
for (auto [originalIndex, op] : llvm::enumerate(run))
|
||||||
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||||
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
|
|
||||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
|
||||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
|
||||||
});
|
|
||||||
|
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
@@ -436,10 +441,6 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
sortedEntries.reserve(run.size());
|
sortedEntries.reserve(run.size());
|
||||||
for (auto op : run)
|
for (auto op : run)
|
||||||
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||||
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
|
|
||||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
|
||||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
|
||||||
});
|
|
||||||
|
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
|||||||
Reference in New Issue
Block a user