This commit is contained in:
+144
-144
@@ -1,5 +1,3 @@
|
|||||||
#include "MaterializeMergeSchedule.hpp"
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
@@ -16,6 +14,7 @@
|
|||||||
#include <optional>
|
#include <optional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "MaterializeMergeSchedule.hpp"
|
||||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
@@ -41,11 +40,9 @@ static int32_t getPhysicalCoreId(size_t schedulerCpu) { return static_cast<int32
|
|||||||
class MergeScheduleMaterializerImpl {
|
class MergeScheduleMaterializerImpl {
|
||||||
public:
|
public:
|
||||||
explicit MergeScheduleMaterializerImpl(func::FuncOp funcOp)
|
explicit MergeScheduleMaterializerImpl(func::FuncOp funcOp)
|
||||||
: func(funcOp),
|
: func(funcOp), loc(funcOp.getLoc()), returnOp(cast<func::ReturnOp>(funcOp.getBody().front().getTerminator())) {}
|
||||||
loc(funcOp.getLoc()),
|
|
||||||
returnOp(cast<func::ReturnOp>(funcOp.getBody().front().getTerminator())) {}
|
|
||||||
|
|
||||||
LogicalResult run(const MergeScheduleResult &scheduleResult, int64_t &nextChannelIdRef) {
|
LogicalResult run(const MergeScheduleResult& scheduleResult, int64_t& nextChannelIdRef) {
|
||||||
schedule = &scheduleResult;
|
schedule = &scheduleResult;
|
||||||
nextChannelId = &nextChannelIdRef;
|
nextChannelId = &nextChannelIdRef;
|
||||||
|
|
||||||
@@ -66,10 +63,9 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
struct ScheduledTask {
|
struct ScheduledTask {
|
||||||
ComputeInstance key;
|
ComputeInstance computeInstance;
|
||||||
Operation *sourceOp = nullptr;
|
Operation* sourceOp = nullptr;
|
||||||
size_t cpu = 0;
|
size_t cpu = 0;
|
||||||
size_t slot = 0;
|
|
||||||
size_t order = 0;
|
size_t order = 0;
|
||||||
size_t executionOrder = 0;
|
size_t executionOrder = 0;
|
||||||
};
|
};
|
||||||
@@ -82,7 +78,7 @@ private:
|
|||||||
|
|
||||||
struct CpuProgram {
|
struct CpuProgram {
|
||||||
SpatCompute op;
|
SpatCompute op;
|
||||||
Block *block = nullptr;
|
Block* block = nullptr;
|
||||||
DenseMap<Value, Value> externalInputMap;
|
DenseMap<Value, Value> externalInputMap;
|
||||||
DenseMap<Value, size_t> weightToIndex;
|
DenseMap<Value, size_t> weightToIndex;
|
||||||
};
|
};
|
||||||
@@ -102,17 +98,17 @@ private:
|
|||||||
size_t sourceOrder = 0;
|
size_t sourceOrder = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
static uint64_t getRemoteSendPairKey(const ChannelInfo &channelInfo) {
|
static uint64_t getRemoteSendPairKey(const ChannelInfo& channelInfo) {
|
||||||
return (static_cast<uint64_t>(static_cast<uint32_t>(channelInfo.sourceCoreId)) << 32)
|
return (static_cast<uint64_t>(static_cast<uint32_t>(channelInfo.sourceCoreId)) << 32)
|
||||||
| static_cast<uint32_t>(channelInfo.targetCoreId);
|
| static_cast<uint32_t>(channelInfo.targetCoreId);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void appendUniqueValue(SmallVectorImpl<Value> &values, DenseSet<Value> &seen, Value value) {
|
static void appendUniqueValue(SmallVectorImpl<Value>& values, DenseSet<Value>& seen, Value value) {
|
||||||
if (seen.insert(value).second)
|
if (seen.insert(value).second)
|
||||||
values.push_back(value);
|
values.push_back(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isInternalInputOp(Operation *op) {
|
bool isOldComputeResult(Operation* op) {
|
||||||
auto it = isInternalInputOpCache.find(op);
|
auto it = isInternalInputOpCache.find(op);
|
||||||
if (it != isInternalInputOpCache.end())
|
if (it != isInternalInputOpCache.end())
|
||||||
return it->second;
|
return it->second;
|
||||||
@@ -122,10 +118,10 @@ private:
|
|||||||
return isInternalInputOpCache[op] = false;
|
return isInternalInputOpCache[op] = false;
|
||||||
|
|
||||||
for (Value result : extract->getResults()) {
|
for (Value result : extract->getResults()) {
|
||||||
for (Operation *user : result.getUsers()) {
|
for (Operation* user : result.getUsers()) {
|
||||||
if (toEraseSet.contains(user))
|
if (oldComputeOps.contains(user))
|
||||||
continue;
|
continue;
|
||||||
if (isInternalInputOp(user))
|
if (isOldComputeResult(user))
|
||||||
continue;
|
continue;
|
||||||
return isInternalInputOpCache[op] = false;
|
return isInternalInputOpCache[op] = false;
|
||||||
}
|
}
|
||||||
@@ -134,21 +130,22 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void collectInternalInputOps(Value value) {
|
void collectInternalInputOps(Value value) {
|
||||||
Operation *op = value.getDefiningOp();
|
Operation* op = value.getDefiningOp();
|
||||||
|
//TODO ExtractSliceOp is not the only legal host op to traverse! dio
|
||||||
while (auto extract = dyn_cast_if_present<tensor::ExtractSliceOp>(op)) {
|
while (auto extract = dyn_cast_if_present<tensor::ExtractSliceOp>(op)) {
|
||||||
if (isInternalInputOp(extract.getOperation()))
|
if (isOldComputeResult(extract.getOperation()))
|
||||||
internalInputOpsToErase.insert(extract.getOperation());
|
internalInputOpsToErase.insert(extract.getOperation());
|
||||||
value = extract.getSource();
|
value = extract.getSource();
|
||||||
op = value.getDefiningOp();
|
op = value.getDefiningOp();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void collectExternalUsers(Operation *op) {
|
void collectExternalUsers(Operation* op) {
|
||||||
if (!externalUsersToMove.insert(op).second)
|
if (!externalUsersToMove.insert(op).second)
|
||||||
return;
|
return;
|
||||||
for (Value result : op->getResults()) {
|
for (Value result : op->getResults()) {
|
||||||
for (Operation *user : result.getUsers()) {
|
for (Operation* user : result.getUsers()) {
|
||||||
if (toEraseSet.contains(user) || isa<func::ReturnOp>(user))
|
if (oldComputeOps.contains(user) || isa<func::ReturnOp>(user))
|
||||||
continue;
|
continue;
|
||||||
collectExternalUsers(user);
|
collectExternalUsers(user);
|
||||||
}
|
}
|
||||||
@@ -158,10 +155,12 @@ private:
|
|||||||
void collectScheduledTasks() {
|
void collectScheduledTasks() {
|
||||||
size_t nextOrder = 0;
|
size_t nextOrder = 0;
|
||||||
for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) {
|
for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) {
|
||||||
toEraseSet.insert(scheduledInstance.op);
|
oldComputeOps.insert(scheduledInstance.op);
|
||||||
scheduledTasks.push_back(
|
scheduledTasks.push_back({scheduledInstance,
|
||||||
{scheduledInstance, scheduledInstance.op, schedule->computeToCpuMap.lookup(scheduledInstance),
|
scheduledInstance.op,
|
||||||
schedule->computeToCpuSlotMap.lookup(scheduledInstance), nextOrder++});
|
schedule->computeToCpuMap.lookup(scheduledInstance),
|
||||||
|
schedule->computeToCpuSlotMap.lookup(scheduledInstance),
|
||||||
|
nextOrder++});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,42 +170,42 @@ private:
|
|||||||
orderedCpus.push_back(cpu);
|
orderedCpus.push_back(cpu);
|
||||||
};
|
};
|
||||||
|
|
||||||
for (const ScheduledTask &task : scheduledTasks) {
|
for (const ScheduledTask& task : scheduledTasks) {
|
||||||
taskByKey[task.key] = task;
|
taskByComputeInstance[task.computeInstance] = task;
|
||||||
tasksByCpu[task.cpu].push_back(task);
|
tasksByCpu[task.cpu].push_back(task);
|
||||||
markCpuSeen(task.cpu);
|
markCpuSeen(task.cpu);
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::sort(orderedCpus);
|
llvm::sort(orderedCpus);
|
||||||
for (size_t cpu : orderedCpus) {
|
for (size_t cpu : orderedCpus) {
|
||||||
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask &lhs, const ScheduledTask &rhs) {
|
llvm::stable_sort(tasksByCpu[cpu],
|
||||||
if (lhs.slot != rhs.slot)
|
[&](const ScheduledTask& lhs, const ScheduledTask& rhs) { return lhs.order < rhs.order; });
|
||||||
return lhs.slot < rhs.slot;
|
|
||||||
return lhs.order < rhs.order;
|
|
||||||
});
|
|
||||||
for (auto [executionOrder, task] : llvm::enumerate(tasksByCpu[cpu])) {
|
for (auto [executionOrder, task] : llvm::enumerate(tasksByCpu[cpu])) {
|
||||||
task.executionOrder = executionOrder;
|
task.executionOrder = executionOrder;
|
||||||
taskByKey[task.key].executionOrder = executionOrder;
|
taskByComputeInstance[task.computeInstance].executionOrder = executionOrder;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void collectExternalInputsAndWeights() {
|
void collectExternalInputsAndWeights() {
|
||||||
for (size_t cpu : orderedCpus) {
|
for (size_t cpu : orderedCpus) {
|
||||||
for (const ScheduledTask &task : tasksByCpu[cpu]) {
|
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||||
auto taskWeights = getComputeInstanceWeights(task.key);
|
auto& thisCpuWeights = cpuWeights[cpu];
|
||||||
|
auto& thisSeenWeights = seenWeightsByCpu[cpu];
|
||||||
|
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||||
for (Value weight : taskWeights)
|
for (Value weight : taskWeights)
|
||||||
appendUniqueValue(cpuWeights[cpu], seenWeightsByCpu[cpu], weight);
|
if (thisSeenWeights.insert(weight).second)
|
||||||
|
thisCpuWeights.push_back(weight);
|
||||||
|
|
||||||
auto taskInputs = getComputeInstanceInputs(task.key);
|
auto taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||||
auto &remoteInputs = remoteInputsByTask[task.key];
|
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
|
||||||
remoteInputs.resize(taskInputs.size());
|
remoteInputs.resize(taskInputs.size());
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||||
auto producerRef = getProducerValueRef(input);
|
auto producerRef = getProducerValueRef(input);
|
||||||
if (producerRef) {
|
if (producerRef) {
|
||||||
collectInternalInputOps(input);
|
collectInternalInputOps(input);
|
||||||
auto producerIt = taskByKey.find(producerRef->instance);
|
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||||
if (producerIt != taskByKey.end()) {
|
if (producerIt != taskByComputeInstance.end()) {
|
||||||
if (producerIt->second.cpu != cpu) {
|
if (producerIt->second.cpu != cpu) {
|
||||||
ChannelInfo info {
|
ChannelInfo info {
|
||||||
(*nextChannelId)++,
|
(*nextChannelId)++,
|
||||||
@@ -214,11 +213,11 @@ private:
|
|||||||
getPhysicalCoreId(cpu),
|
getPhysicalCoreId(cpu),
|
||||||
};
|
};
|
||||||
remoteInputs[inputIndex] = info;
|
remoteInputs[inputIndex] = info;
|
||||||
auto &perResultChannels = remoteSendsByTask[producerRef->instance];
|
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||||
if (perResultChannels.empty())
|
if (perResultChannels.empty())
|
||||||
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.key).size());
|
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size());
|
||||||
perResultChannels[producerRef->resultIndex].push_back(
|
perResultChannels[producerRef->resultIndex].push_back(
|
||||||
{info, task.key, inputIndex, task.executionOrder, 0});
|
{info, task.computeInstance, inputIndex, task.executionOrder, 0});
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -226,19 +225,19 @@ private:
|
|||||||
appendUniqueValue(cpuExternalInputs[cpu], seenExternalInputsByCpu[cpu], input);
|
appendUniqueValue(cpuExternalInputs[cpu], seenExternalInputsByCpu[cpu], input);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto taskOutputs = getComputeInstanceOutputValues(task.key);
|
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
|
||||||
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
||||||
bool hasExternalUser = false;
|
bool hasExternalUser = false;
|
||||||
for (auto &use : output.getUses()) {
|
for (auto& use : output.getUses()) {
|
||||||
Operation *useOwner = use.getOwner();
|
Operation* useOwner = use.getOwner();
|
||||||
if (toEraseSet.contains(useOwner))
|
if (oldComputeOps.contains(useOwner))
|
||||||
continue;
|
continue;
|
||||||
hasExternalUser = true;
|
hasExternalUser = true;
|
||||||
if (!isa<func::ReturnOp>(useOwner))
|
if (!isa<func::ReturnOp>(useOwner))
|
||||||
collectExternalUsers(useOwner);
|
collectExternalUsers(useOwner);
|
||||||
}
|
}
|
||||||
if (hasExternalUser)
|
if (hasExternalUser)
|
||||||
cpuExternalOutputs[cpu].push_back({task.key, resultIndex});
|
cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -248,12 +247,12 @@ private:
|
|||||||
for (size_t cpu : orderedCpus) {
|
for (size_t cpu : orderedCpus) {
|
||||||
DenseMap<uint64_t, size_t> nextSourceOrderByPair;
|
DenseMap<uint64_t, size_t> nextSourceOrderByPair;
|
||||||
DenseMap<uint64_t, size_t> lastConsumerOrderByPair;
|
DenseMap<uint64_t, size_t> lastConsumerOrderByPair;
|
||||||
for (const ScheduledTask &task : tasksByCpu[cpu]) {
|
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||||
auto sendsIt = remoteSendsByTask.find(task.key);
|
auto sendsIt = remoteSendsByTask.find(task.computeInstance);
|
||||||
if (sendsIt == remoteSendsByTask.end())
|
if (sendsIt == remoteSendsByTask.end())
|
||||||
continue;
|
continue;
|
||||||
for (auto &sendInfos : sendsIt->second) {
|
for (auto& sendInfos : sendsIt->second) {
|
||||||
for (RemoteSendInfo &sendInfo : sendInfos) {
|
for (RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||||
sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++;
|
sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++;
|
||||||
auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder);
|
auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder);
|
||||||
@@ -269,10 +268,10 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void planReceiveReordering() {
|
void planReceiveReordering() {
|
||||||
DenseMap<uint64_t, SmallVector<RemoteSendInfo *>> reorderedSendsByPair;
|
DenseMap<uint64_t, SmallVector<RemoteSendInfo*>> reorderedSendsByPair;
|
||||||
for (auto &taskSends : remoteSendsByTask) {
|
for (auto& taskSends : remoteSendsByTask) {
|
||||||
for (auto &sendInfos : taskSends.second) {
|
for (auto& sendInfos : taskSends.second) {
|
||||||
for (RemoteSendInfo &sendInfo : sendInfos) {
|
for (RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||||
if (pairsNeedingReceiveReorder.contains(pairKey))
|
if (pairsNeedingReceiveReorder.contains(pairKey))
|
||||||
reorderedSendsByPair[pairKey].push_back(&sendInfo);
|
reorderedSendsByPair[pairKey].push_back(&sendInfo);
|
||||||
@@ -280,13 +279,13 @@ private:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &pairSends : reorderedSendsByPair) {
|
for (auto& pairSends : reorderedSendsByPair) {
|
||||||
llvm::stable_sort(pairSends.second, [](const RemoteSendInfo *lhs, const RemoteSendInfo *rhs) {
|
llvm::stable_sort(pairSends.second, [](const RemoteSendInfo* lhs, const RemoteSendInfo* rhs) {
|
||||||
if (lhs->sourceOrder != rhs->sourceOrder)
|
if (lhs->sourceOrder != rhs->sourceOrder)
|
||||||
return lhs->sourceOrder < rhs->sourceOrder;
|
return lhs->sourceOrder < rhs->sourceOrder;
|
||||||
return lhs->channelInfo.channelId < rhs->channelInfo.channelId;
|
return lhs->channelInfo.channelId < rhs->channelInfo.channelId;
|
||||||
});
|
});
|
||||||
for (RemoteSendInfo *sendInfo : pairSends.second) {
|
for (RemoteSendInfo* sendInfo : pairSends.second) {
|
||||||
int64_t channelId = (*nextChannelId)++;
|
int64_t channelId = (*nextChannelId)++;
|
||||||
sendInfo->channelInfo.channelId = channelId;
|
sendInfo->channelInfo.channelId = channelId;
|
||||||
auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer);
|
auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer);
|
||||||
@@ -297,9 +296,9 @@ private:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto &taskSends : remoteSendsByTask) {
|
for (const auto& taskSends : remoteSendsByTask) {
|
||||||
for (const auto &sendInfos : taskSends.second) {
|
for (const auto& sendInfos : taskSends.second) {
|
||||||
for (const RemoteSendInfo &sendInfo : sendInfos) {
|
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer);
|
auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer);
|
||||||
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send");
|
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send");
|
||||||
assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
||||||
@@ -309,9 +308,9 @@ private:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &taskSends : remoteSendsByTask) {
|
for (auto& taskSends : remoteSendsByTask) {
|
||||||
for (const auto &sendInfos : taskSends.second) {
|
for (const auto& sendInfos : taskSends.second) {
|
||||||
for (const RemoteSendInfo &sendInfo : sendInfos) {
|
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||||
if (!pairsNeedingReceiveReorder.contains(pairKey))
|
if (!pairsNeedingReceiveReorder.contains(pairKey))
|
||||||
continue;
|
continue;
|
||||||
@@ -322,9 +321,9 @@ private:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &cpuQueues : receiveQueuesByCpu) {
|
for (auto& cpuQueues : receiveQueuesByCpu) {
|
||||||
for (auto &pairQueue : cpuQueues.second) {
|
for (auto& pairQueue : cpuQueues.second) {
|
||||||
llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry &lhs, const RemoteReceiveEntry &rhs) {
|
llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry& lhs, const RemoteReceiveEntry& rhs) {
|
||||||
if (lhs.sourceOrder != rhs.sourceOrder)
|
if (lhs.sourceOrder != rhs.sourceOrder)
|
||||||
return lhs.sourceOrder < rhs.sourceOrder;
|
return lhs.sourceOrder < rhs.sourceOrder;
|
||||||
return lhs.channelInfo.channelId < rhs.channelInfo.channelId;
|
return lhs.channelInfo.channelId < rhs.channelInfo.channelId;
|
||||||
@@ -344,8 +343,8 @@ private:
|
|||||||
SmallVector<Type> resultTypes;
|
SmallVector<Type> resultTypes;
|
||||||
resultTypes.reserve(cpuExternalOutputs[cpu].size());
|
resultTypes.reserve(cpuExternalOutputs[cpu].size());
|
||||||
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||||
ScheduledTask task = taskByKey.at(outputRef.instance);
|
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||||
resultTypes.push_back(getComputeInstanceOutputTypes(task.key)[outputRef.resultIndex]);
|
resultTypes.push_back(getComputeInstanceOutputTypes(task.computeInstance)[outputRef.resultIndex]);
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(returnOp);
|
rewriter.setInsertionPoint(returnOp);
|
||||||
@@ -362,7 +361,7 @@ private:
|
|||||||
blockArgTypes.push_back(input.getType());
|
blockArgTypes.push_back(input.getType());
|
||||||
blockArgLocs.push_back(loc);
|
blockArgLocs.push_back(loc);
|
||||||
}
|
}
|
||||||
Block *newBlock =
|
Block* newBlock =
|
||||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
|
||||||
CpuProgram program;
|
CpuProgram program;
|
||||||
@@ -373,19 +372,19 @@ private:
|
|||||||
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu]))
|
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu]))
|
||||||
program.externalInputMap[input] = newBlock->getArgument(inputIndex);
|
program.externalInputMap[input] = newBlock->getArgument(inputIndex);
|
||||||
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) {
|
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) {
|
||||||
ScheduledTask task = taskByKey.at(outputRef.instance);
|
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||||
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.key)[outputRef.resultIndex]] =
|
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
|
||||||
newCompute.getResult(resultIndex);
|
newCompute.getResult(resultIndex);
|
||||||
}
|
}
|
||||||
cpuPrograms[cpu] = std::move(program);
|
cpuPrograms[cpu] = std::move(program);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<Value> receiveThroughInput(IRRewriter &rewriter,
|
FailureOr<Value> receiveThroughInput(IRRewriter& rewriter,
|
||||||
size_t cpu,
|
size_t cpu,
|
||||||
DenseMap<uint64_t, size_t> &receiveQueueIndices,
|
DenseMap<uint64_t, size_t>& receiveQueueIndices,
|
||||||
DenseMap<ComputeInstance, SmallVector<Value>> &preReceivedInputsByTask,
|
DenseMap<ComputeInstance, SmallVector<Value>>& preReceivedInputsByTask,
|
||||||
const ChannelInfo &requestedChannelInfo,
|
const ChannelInfo& requestedChannelInfo,
|
||||||
ComputeInstance requestedConsumer,
|
ComputeInstance requestedConsumer,
|
||||||
size_t requestedInputIndex) {
|
size_t requestedInputIndex) {
|
||||||
uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo);
|
uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo);
|
||||||
@@ -396,26 +395,25 @@ private:
|
|||||||
if (queueIt == cpuQueuesIt->second.end())
|
if (queueIt == cpuQueuesIt->second.end())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto &queue = queueIt->second;
|
auto& queue = queueIt->second;
|
||||||
size_t &queueIndex = receiveQueueIndices[pairKey];
|
size_t& queueIndex = receiveQueueIndices[pairKey];
|
||||||
while (queueIndex < queue.size()) {
|
while (queueIndex < queue.size()) {
|
||||||
const RemoteReceiveEntry &entry = queue[queueIndex++];
|
const RemoteReceiveEntry& entry = queue[queueIndex++];
|
||||||
auto consumerTaskIt = taskByKey.find(entry.consumer);
|
auto consumerTaskIt = taskByComputeInstance.find(entry.consumer);
|
||||||
if (consumerTaskIt == taskByKey.end())
|
if (consumerTaskIt == taskByComputeInstance.end())
|
||||||
return failure();
|
return failure();
|
||||||
SmallVector<Value> consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.key);
|
SmallVector<Value> consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.computeInstance);
|
||||||
if (consumerInputs.size() <= entry.inputIndex)
|
if (consumerInputs.size() <= entry.inputIndex)
|
||||||
return failure();
|
return failure();
|
||||||
Type inputType = consumerInputs[entry.inputIndex].getType();
|
Type inputType = consumerInputs[entry.inputIndex].getType();
|
||||||
auto receive =
|
auto receive = spatial::SpatChannelReceiveOp::create(rewriter,
|
||||||
spatial::SpatChannelReceiveOp::create(rewriter,
|
loc,
|
||||||
loc,
|
inputType,
|
||||||
inputType,
|
rewriter.getI64IntegerAttr(entry.channelInfo.channelId),
|
||||||
rewriter.getI64IntegerAttr(entry.channelInfo.channelId),
|
rewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId),
|
||||||
rewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId),
|
rewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId));
|
||||||
rewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId));
|
|
||||||
|
|
||||||
auto &receivedInputs = preReceivedInputsByTask[entry.consumer];
|
auto& receivedInputs = preReceivedInputsByTask[entry.consumer];
|
||||||
if (receivedInputs.size() <= entry.inputIndex)
|
if (receivedInputs.size() <= entry.inputIndex)
|
||||||
receivedInputs.resize(entry.inputIndex + 1);
|
receivedInputs.resize(entry.inputIndex + 1);
|
||||||
receivedInputs[entry.inputIndex] = receive.getResult();
|
receivedInputs[entry.inputIndex] = receive.getResult();
|
||||||
@@ -428,7 +426,7 @@ private:
|
|||||||
|
|
||||||
LogicalResult cloneTaskBodies() {
|
LogicalResult cloneTaskBodies() {
|
||||||
for (size_t cpu : orderedCpus) {
|
for (size_t cpu : orderedCpus) {
|
||||||
CpuProgram &program = cpuPrograms[cpu];
|
CpuProgram& program = cpuPrograms[cpu];
|
||||||
IRRewriter rewriter(func.getContext());
|
IRRewriter rewriter(func.getContext());
|
||||||
rewriter.setInsertionPointToEnd(program.block);
|
rewriter.setInsertionPointToEnd(program.block);
|
||||||
DenseMap<uint64_t, size_t> receiveQueueIndices;
|
DenseMap<uint64_t, size_t> receiveQueueIndices;
|
||||||
@@ -444,25 +442,24 @@ private:
|
|||||||
return value;
|
return value;
|
||||||
};
|
};
|
||||||
|
|
||||||
for (const ScheduledTask &task : tasksByCpu[cpu]) {
|
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||||
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.key);
|
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||||
auto taskWeights = getComputeInstanceWeights(task.key);
|
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||||
Block &templateBlock = getComputeInstanceTemplateBlock(task.key);
|
Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance);
|
||||||
|
|
||||||
SmallVector<Value> resolvedInputs;
|
SmallVector<Value> resolvedInputs;
|
||||||
resolvedInputs.reserve(taskInputs.size());
|
resolvedInputs.reserve(taskInputs.size());
|
||||||
auto remoteInputsIt = remoteInputsByTask.find(task.key);
|
auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance);
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||||
auto producerRef = getProducerValueRef(input);
|
auto producerRef = getProducerValueRef(input);
|
||||||
if (producerRef) {
|
if (producerRef) {
|
||||||
auto producerIt = taskByKey.find(producerRef->instance);
|
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||||
if (producerIt != taskByKey.end()) {
|
if (producerIt != taskByComputeInstance.end()) {
|
||||||
if (producerIt->second.cpu == cpu) {
|
if (producerIt->second.cpu == cpu) {
|
||||||
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
||||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
||||||
task.sourceOp->emitOpError("missing local producer value during per-cpu merge materialization")
|
task.sourceOp->emitOpError("missing local producer value during per-cpu merge materialization")
|
||||||
<< " consumerCpu=" << cpu << " consumerSlot=" << task.slot
|
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
|
||||||
<< " producerCpu=" << producerIt->second.cpu << " producerSlot=" << producerIt->second.slot
|
|
||||||
<< " producerLaneStart=" << producerRef->instance.laneStart
|
<< " producerLaneStart=" << producerRef->instance.laneStart
|
||||||
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
||||||
return failure();
|
return failure();
|
||||||
@@ -470,20 +467,24 @@ private:
|
|||||||
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
|
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const ChannelInfo &channelInfo = *remoteInputsIt->second[inputIndex];
|
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
||||||
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
||||||
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
||||||
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.key, inputIndex)) {
|
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) {
|
||||||
resolvedInputs.push_back(*preReceived);
|
resolvedInputs.push_back(*preReceived);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
FailureOr<Value> received = receiveThroughInput(
|
FailureOr<Value> received = receiveThroughInput(rewriter,
|
||||||
rewriter, cpu, receiveQueueIndices, preReceivedInputsByTask, channelInfo, task.key, inputIndex);
|
cpu,
|
||||||
|
receiveQueueIndices,
|
||||||
|
preReceivedInputsByTask,
|
||||||
|
channelInfo,
|
||||||
|
task.computeInstance,
|
||||||
|
inputIndex);
|
||||||
if (failed(received)) {
|
if (failed(received)) {
|
||||||
task.sourceOp->emitOpError("failed to materialize reordered remote receive")
|
task.sourceOp->emitOpError("failed to materialize reordered remote receive")
|
||||||
<< " consumerCpu=" << cpu << " consumerSlot=" << task.slot
|
<< " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId
|
||||||
<< " sourceCoreId=" << channelInfo.sourceCoreId << " targetCoreId=" << channelInfo.targetCoreId
|
<< " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId;
|
||||||
<< " channelId=" << channelInfo.channelId;
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
resolvedInputs.push_back(*received);
|
resolvedInputs.push_back(*received);
|
||||||
@@ -510,14 +511,14 @@ private:
|
|||||||
for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments()))
|
for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments()))
|
||||||
mapper.map(oldArg, resolvedInputs[argIndex]);
|
mapper.map(oldArg, resolvedInputs[argIndex]);
|
||||||
|
|
||||||
for (Operation &op : templateBlock) {
|
for (Operation& op : templateBlock) {
|
||||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||||
for (Value yieldOperand : yield.getOperands())
|
for (Value yieldOperand : yield.getOperands())
|
||||||
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation *clonedOp = rewriter.clone(op, mapper);
|
Operation* clonedOp = rewriter.clone(op, mapper);
|
||||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||||
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||||
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
|
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
|
||||||
@@ -529,24 +530,24 @@ private:
|
|||||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
for (size_t laneOffset = 0; laneOffset < task.key.laneCount; ++laneOffset) {
|
else {
|
||||||
|
for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) {
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
if (templateBlock.getNumArguments() == 1)
|
if (templateBlock.getNumArguments() == 1)
|
||||||
mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]);
|
mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]);
|
||||||
|
|
||||||
for (Operation &op : templateBlock) {
|
for (Operation& op : templateBlock) {
|
||||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||||
for (Value yieldOperand : yield.getOperands())
|
for (Value yieldOperand : yield.getOperands())
|
||||||
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation *clonedOp = rewriter.clone(op, mapper);
|
Operation* clonedOp = rewriter.clone(op, mapper);
|
||||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||||
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
||||||
task.sourceOp->emitOpError(
|
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||||
"batched per-cpu merge materialization expects lane-local weight index 0");
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||||
@@ -554,8 +555,7 @@ private:
|
|||||||
}
|
}
|
||||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||||
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
||||||
task.sourceOp->emitOpError(
|
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||||
"batched per-cpu merge materialization expects lane-local weight index 0");
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||||
@@ -565,13 +565,13 @@ private:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
producedValuesByTask[task.key] = taskYieldValues;
|
producedValuesByTask[task.computeInstance] = taskYieldValues;
|
||||||
if (auto sendsIt = remoteSendsByTask.find(task.key); sendsIt != remoteSendsByTask.end()) {
|
if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) {
|
||||||
for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) {
|
for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) {
|
||||||
if (sendInfos.empty())
|
if (sendInfos.empty())
|
||||||
continue;
|
continue;
|
||||||
Value producedValue = taskYieldValues[resultIndex];
|
Value producedValue = taskYieldValues[resultIndex];
|
||||||
for (const RemoteSendInfo &sendInfo : sendInfos) {
|
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
spatial::SpatChannelSendOp::create(rewriter,
|
spatial::SpatChannelSendOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
rewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId),
|
rewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId),
|
||||||
@@ -588,9 +588,9 @@ private:
|
|||||||
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||||
auto producedIt = producedValuesByTask.find(outputRef.instance);
|
auto producedIt = producedValuesByTask.find(outputRef.instance);
|
||||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
|
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
|
||||||
ScheduledTask task = taskByKey.at(outputRef.instance);
|
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||||
task.sourceOp->emitOpError("missing yielded external value during per-cpu merge materialization")
|
task.sourceOp->emitOpError("missing yielded external value during per-cpu merge materialization")
|
||||||
<< " cpu=" << cpu << " slot=" << task.slot << " laneStart=" << outputRef.instance.laneStart;
|
<< " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart;
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
|
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
|
||||||
@@ -603,31 +603,31 @@ private:
|
|||||||
|
|
||||||
void replaceExternalUses() {
|
void replaceExternalUses() {
|
||||||
for (auto [oldValue, newValue] : oldToNewExternalValueMap) {
|
for (auto [oldValue, newValue] : oldToNewExternalValueMap) {
|
||||||
for (auto &use : llvm::make_early_inc_range(oldValue.getUses()))
|
for (auto& use : llvm::make_early_inc_range(oldValue.getUses()))
|
||||||
if (!toEraseSet.contains(use.getOwner()))
|
if (!oldComputeOps.contains(use.getOwner()))
|
||||||
use.assign(newValue);
|
use.assign(newValue);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult eraseOldScheduledOps() {
|
LogicalResult eraseOldScheduledOps() {
|
||||||
DenseSet<Operation *> allOpsToErase = toEraseSet;
|
DenseSet<Operation*> allOpsToErase = oldComputeOps;
|
||||||
for (Operation *op : internalInputOpsToErase)
|
for (Operation* op : internalInputOpsToErase)
|
||||||
allOpsToErase.insert(op);
|
allOpsToErase.insert(op);
|
||||||
|
|
||||||
SmallVector<Operation *> orderedOpsToErase;
|
SmallVector<Operation*> orderedOpsToErase;
|
||||||
for (Operation &op : func.getBody().front())
|
for (Operation& op : func.getBody().front())
|
||||||
if (allOpsToErase.contains(&op))
|
if (allOpsToErase.contains(&op))
|
||||||
orderedOpsToErase.push_back(&op);
|
orderedOpsToErase.push_back(&op);
|
||||||
|
|
||||||
for (Operation *op : llvm::reverse(orderedOpsToErase)) {
|
for (Operation* op : llvm::reverse(orderedOpsToErase)) {
|
||||||
SmallVector<Operation *> remainingUsers;
|
SmallVector<Operation*> remainingUsers;
|
||||||
for (Value result : op->getResults())
|
for (Value result : op->getResults())
|
||||||
for (Operation *user : result.getUsers())
|
for (Operation* user : result.getUsers())
|
||||||
remainingUsers.push_back(user);
|
remainingUsers.push_back(user);
|
||||||
if (!remainingUsers.empty()) {
|
if (!remainingUsers.empty()) {
|
||||||
InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup")
|
InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup")
|
||||||
<< "; erase-set=" << (allOpsToErase.contains(op) ? "yes" : "no");
|
<< "; erase-set=" << (allOpsToErase.contains(op) ? "yes" : "no");
|
||||||
for (Operation *user : remainingUsers) {
|
for (Operation* user : remainingUsers) {
|
||||||
diagnostic.attachNote(user->getLoc())
|
diagnostic.attachNote(user->getLoc())
|
||||||
<< "remaining user " << user->getName() << "; erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no");
|
<< "remaining user " << user->getName() << "; erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no");
|
||||||
}
|
}
|
||||||
@@ -640,32 +640,32 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void moveExternalUsersBeforeReturn() {
|
void moveExternalUsersBeforeReturn() {
|
||||||
SmallVector<Operation *> orderedUsersToMove;
|
SmallVector<Operation*> orderedUsersToMove;
|
||||||
for (Operation &op : func.getBody().front()) {
|
for (Operation& op : func.getBody().front()) {
|
||||||
if (&op == returnOp.getOperation())
|
if (&op == returnOp.getOperation())
|
||||||
break;
|
break;
|
||||||
if (externalUsersToMove.contains(&op))
|
if (externalUsersToMove.contains(&op))
|
||||||
orderedUsersToMove.push_back(&op);
|
orderedUsersToMove.push_back(&op);
|
||||||
}
|
}
|
||||||
for (Operation *op : orderedUsersToMove)
|
for (Operation* op : orderedUsersToMove)
|
||||||
op->moveBefore(returnOp);
|
op->moveBefore(returnOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
func::FuncOp func;
|
func::FuncOp func;
|
||||||
const MergeScheduleResult *schedule = nullptr;
|
const MergeScheduleResult* schedule = nullptr;
|
||||||
int64_t *nextChannelId = nullptr;
|
int64_t* nextChannelId = nullptr;
|
||||||
Location loc;
|
Location loc;
|
||||||
func::ReturnOp returnOp;
|
func::ReturnOp returnOp;
|
||||||
|
|
||||||
SmallVector<ScheduledTask> scheduledTasks;
|
SmallVector<ScheduledTask> scheduledTasks;
|
||||||
DenseSet<Operation *> toEraseSet;
|
DenseSet<Operation*> oldComputeOps;
|
||||||
DenseMap<ComputeInstance, ScheduledTask> taskByKey;
|
DenseMap<ComputeInstance, ScheduledTask> taskByComputeInstance;
|
||||||
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
||||||
SmallVector<size_t> orderedCpus;
|
SmallVector<size_t> orderedCpus;
|
||||||
DenseSet<size_t> seenCpus;
|
DenseSet<size_t> seenCpus;
|
||||||
DenseSet<Operation *> internalInputOpsToErase;
|
DenseSet<Operation*> internalInputOpsToErase;
|
||||||
DenseMap<Operation *, bool> isInternalInputOpCache;
|
DenseMap<Operation*, bool> isInternalInputOpCache;
|
||||||
DenseSet<Operation *> externalUsersToMove;
|
DenseSet<Operation*> externalUsersToMove;
|
||||||
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
||||||
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
||||||
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
||||||
@@ -683,7 +683,7 @@ private:
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId) {
|
MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) {
|
||||||
return MergeScheduleMaterializerImpl(func).run(schedule, nextChannelId);
|
return MergeScheduleMaterializerImpl(func).run(schedule, nextChannelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user