Materialize modification
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-18 17:22:13 +02:00
parent aa088e2ba5
commit 34c29fdec4
@@ -1,5 +1,3 @@
#include "MaterializeMergeSchedule.hpp"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
@@ -16,6 +14,7 @@
#include <optional> #include <optional>
#include <utility> #include <utility>
#include "MaterializeMergeSchedule.hpp"
#include "Scheduling/ComputeInstanceUtils.hpp" #include "Scheduling/ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -41,11 +40,9 @@ static int32_t getPhysicalCoreId(size_t schedulerCpu) { return static_cast<int32
class MergeScheduleMaterializerImpl { class MergeScheduleMaterializerImpl {
public: public:
explicit MergeScheduleMaterializerImpl(func::FuncOp funcOp) explicit MergeScheduleMaterializerImpl(func::FuncOp funcOp)
: func(funcOp), : func(funcOp), loc(funcOp.getLoc()), returnOp(cast<func::ReturnOp>(funcOp.getBody().front().getTerminator())) {}
loc(funcOp.getLoc()),
returnOp(cast<func::ReturnOp>(funcOp.getBody().front().getTerminator())) {}
LogicalResult run(const MergeScheduleResult &scheduleResult, int64_t &nextChannelIdRef) { LogicalResult run(const MergeScheduleResult& scheduleResult, int64_t& nextChannelIdRef) {
schedule = &scheduleResult; schedule = &scheduleResult;
nextChannelId = &nextChannelIdRef; nextChannelId = &nextChannelIdRef;
@@ -66,10 +63,9 @@ public:
private: private:
struct ScheduledTask { struct ScheduledTask {
ComputeInstance key; ComputeInstance computeInstance;
Operation *sourceOp = nullptr; Operation* sourceOp = nullptr;
size_t cpu = 0; size_t cpu = 0;
size_t slot = 0;
size_t order = 0; size_t order = 0;
size_t executionOrder = 0; size_t executionOrder = 0;
}; };
@@ -82,7 +78,7 @@ private:
struct CpuProgram { struct CpuProgram {
SpatCompute op; SpatCompute op;
Block *block = nullptr; Block* block = nullptr;
DenseMap<Value, Value> externalInputMap; DenseMap<Value, Value> externalInputMap;
DenseMap<Value, size_t> weightToIndex; DenseMap<Value, size_t> weightToIndex;
}; };
@@ -102,17 +98,17 @@ private:
size_t sourceOrder = 0; size_t sourceOrder = 0;
}; };
static uint64_t getRemoteSendPairKey(const ChannelInfo &channelInfo) { static uint64_t getRemoteSendPairKey(const ChannelInfo& channelInfo) {
return (static_cast<uint64_t>(static_cast<uint32_t>(channelInfo.sourceCoreId)) << 32) return (static_cast<uint64_t>(static_cast<uint32_t>(channelInfo.sourceCoreId)) << 32)
| static_cast<uint32_t>(channelInfo.targetCoreId); | static_cast<uint32_t>(channelInfo.targetCoreId);
} }
static void appendUniqueValue(SmallVectorImpl<Value> &values, DenseSet<Value> &seen, Value value) { static void appendUniqueValue(SmallVectorImpl<Value>& values, DenseSet<Value>& seen, Value value) {
if (seen.insert(value).second) if (seen.insert(value).second)
values.push_back(value); values.push_back(value);
} }
bool isInternalInputOp(Operation *op) { bool isOldComputeResult(Operation* op) {
auto it = isInternalInputOpCache.find(op); auto it = isInternalInputOpCache.find(op);
if (it != isInternalInputOpCache.end()) if (it != isInternalInputOpCache.end())
return it->second; return it->second;
@@ -122,10 +118,10 @@ private:
return isInternalInputOpCache[op] = false; return isInternalInputOpCache[op] = false;
for (Value result : extract->getResults()) { for (Value result : extract->getResults()) {
for (Operation *user : result.getUsers()) { for (Operation* user : result.getUsers()) {
if (toEraseSet.contains(user)) if (oldComputeOps.contains(user))
continue; continue;
if (isInternalInputOp(user)) if (isOldComputeResult(user))
continue; continue;
return isInternalInputOpCache[op] = false; return isInternalInputOpCache[op] = false;
} }
@@ -134,21 +130,22 @@ private:
} }
void collectInternalInputOps(Value value) { void collectInternalInputOps(Value value) {
Operation *op = value.getDefiningOp(); Operation* op = value.getDefiningOp();
//TODO ExtractSliceOp is not the only legal host op to traverse! dio
while (auto extract = dyn_cast_if_present<tensor::ExtractSliceOp>(op)) { while (auto extract = dyn_cast_if_present<tensor::ExtractSliceOp>(op)) {
if (isInternalInputOp(extract.getOperation())) if (isOldComputeResult(extract.getOperation()))
internalInputOpsToErase.insert(extract.getOperation()); internalInputOpsToErase.insert(extract.getOperation());
value = extract.getSource(); value = extract.getSource();
op = value.getDefiningOp(); op = value.getDefiningOp();
} }
} }
void collectExternalUsers(Operation *op) { void collectExternalUsers(Operation* op) {
if (!externalUsersToMove.insert(op).second) if (!externalUsersToMove.insert(op).second)
return; return;
for (Value result : op->getResults()) { for (Value result : op->getResults()) {
for (Operation *user : result.getUsers()) { for (Operation* user : result.getUsers()) {
if (toEraseSet.contains(user) || isa<func::ReturnOp>(user)) if (oldComputeOps.contains(user) || isa<func::ReturnOp>(user))
continue; continue;
collectExternalUsers(user); collectExternalUsers(user);
} }
@@ -158,10 +155,12 @@ private:
void collectScheduledTasks() { void collectScheduledTasks() {
size_t nextOrder = 0; size_t nextOrder = 0;
for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) { for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) {
toEraseSet.insert(scheduledInstance.op); oldComputeOps.insert(scheduledInstance.op);
scheduledTasks.push_back( scheduledTasks.push_back({scheduledInstance,
{scheduledInstance, scheduledInstance.op, schedule->computeToCpuMap.lookup(scheduledInstance), scheduledInstance.op,
schedule->computeToCpuSlotMap.lookup(scheduledInstance), nextOrder++}); schedule->computeToCpuMap.lookup(scheduledInstance),
schedule->computeToCpuSlotMap.lookup(scheduledInstance),
nextOrder++});
} }
} }
@@ -171,42 +170,42 @@ private:
orderedCpus.push_back(cpu); orderedCpus.push_back(cpu);
}; };
for (const ScheduledTask &task : scheduledTasks) { for (const ScheduledTask& task : scheduledTasks) {
taskByKey[task.key] = task; taskByComputeInstance[task.computeInstance] = task;
tasksByCpu[task.cpu].push_back(task); tasksByCpu[task.cpu].push_back(task);
markCpuSeen(task.cpu); markCpuSeen(task.cpu);
} }
llvm::sort(orderedCpus); llvm::sort(orderedCpus);
for (size_t cpu : orderedCpus) { for (size_t cpu : orderedCpus) {
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask &lhs, const ScheduledTask &rhs) { llvm::stable_sort(tasksByCpu[cpu],
if (lhs.slot != rhs.slot) [&](const ScheduledTask& lhs, const ScheduledTask& rhs) { return lhs.order < rhs.order; });
return lhs.slot < rhs.slot;
return lhs.order < rhs.order;
});
for (auto [executionOrder, task] : llvm::enumerate(tasksByCpu[cpu])) { for (auto [executionOrder, task] : llvm::enumerate(tasksByCpu[cpu])) {
task.executionOrder = executionOrder; task.executionOrder = executionOrder;
taskByKey[task.key].executionOrder = executionOrder; taskByComputeInstance[task.computeInstance].executionOrder = executionOrder;
} }
} }
} }
void collectExternalInputsAndWeights() { void collectExternalInputsAndWeights() {
for (size_t cpu : orderedCpus) { for (size_t cpu : orderedCpus) {
for (const ScheduledTask &task : tasksByCpu[cpu]) { for (const ScheduledTask& task : tasksByCpu[cpu]) {
auto taskWeights = getComputeInstanceWeights(task.key); auto& thisCpuWeights = cpuWeights[cpu];
auto& thisSeenWeights = seenWeightsByCpu[cpu];
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
for (Value weight : taskWeights) for (Value weight : taskWeights)
appendUniqueValue(cpuWeights[cpu], seenWeightsByCpu[cpu], weight); if (thisSeenWeights.insert(weight).second)
thisCpuWeights.push_back(weight);
auto taskInputs = getComputeInstanceInputs(task.key); auto taskInputs = getComputeInstanceInputs(task.computeInstance);
auto &remoteInputs = remoteInputsByTask[task.key]; auto& remoteInputs = remoteInputsByTask[task.computeInstance];
remoteInputs.resize(taskInputs.size()); remoteInputs.resize(taskInputs.size());
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
auto producerRef = getProducerValueRef(input); auto producerRef = getProducerValueRef(input);
if (producerRef) { if (producerRef) {
collectInternalInputOps(input); collectInternalInputOps(input);
auto producerIt = taskByKey.find(producerRef->instance); auto producerIt = taskByComputeInstance.find(producerRef->instance);
if (producerIt != taskByKey.end()) { if (producerIt != taskByComputeInstance.end()) {
if (producerIt->second.cpu != cpu) { if (producerIt->second.cpu != cpu) {
ChannelInfo info { ChannelInfo info {
(*nextChannelId)++, (*nextChannelId)++,
@@ -214,11 +213,11 @@ private:
getPhysicalCoreId(cpu), getPhysicalCoreId(cpu),
}; };
remoteInputs[inputIndex] = info; remoteInputs[inputIndex] = info;
auto &perResultChannels = remoteSendsByTask[producerRef->instance]; auto& perResultChannels = remoteSendsByTask[producerRef->instance];
if (perResultChannels.empty()) if (perResultChannels.empty())
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.key).size()); perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size());
perResultChannels[producerRef->resultIndex].push_back( perResultChannels[producerRef->resultIndex].push_back(
{info, task.key, inputIndex, task.executionOrder, 0}); {info, task.computeInstance, inputIndex, task.executionOrder, 0});
} }
continue; continue;
} }
@@ -226,19 +225,19 @@ private:
appendUniqueValue(cpuExternalInputs[cpu], seenExternalInputsByCpu[cpu], input); appendUniqueValue(cpuExternalInputs[cpu], seenExternalInputsByCpu[cpu], input);
} }
auto taskOutputs = getComputeInstanceOutputValues(task.key); auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) { for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
bool hasExternalUser = false; bool hasExternalUser = false;
for (auto &use : output.getUses()) { for (auto& use : output.getUses()) {
Operation *useOwner = use.getOwner(); Operation* useOwner = use.getOwner();
if (toEraseSet.contains(useOwner)) if (oldComputeOps.contains(useOwner))
continue; continue;
hasExternalUser = true; hasExternalUser = true;
if (!isa<func::ReturnOp>(useOwner)) if (!isa<func::ReturnOp>(useOwner))
collectExternalUsers(useOwner); collectExternalUsers(useOwner);
} }
if (hasExternalUser) if (hasExternalUser)
cpuExternalOutputs[cpu].push_back({task.key, resultIndex}); cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex});
} }
} }
} }
@@ -248,12 +247,12 @@ private:
for (size_t cpu : orderedCpus) { for (size_t cpu : orderedCpus) {
DenseMap<uint64_t, size_t> nextSourceOrderByPair; DenseMap<uint64_t, size_t> nextSourceOrderByPair;
DenseMap<uint64_t, size_t> lastConsumerOrderByPair; DenseMap<uint64_t, size_t> lastConsumerOrderByPair;
for (const ScheduledTask &task : tasksByCpu[cpu]) { for (const ScheduledTask& task : tasksByCpu[cpu]) {
auto sendsIt = remoteSendsByTask.find(task.key); auto sendsIt = remoteSendsByTask.find(task.computeInstance);
if (sendsIt == remoteSendsByTask.end()) if (sendsIt == remoteSendsByTask.end())
continue; continue;
for (auto &sendInfos : sendsIt->second) { for (auto& sendInfos : sendsIt->second) {
for (RemoteSendInfo &sendInfo : sendInfos) { for (RemoteSendInfo& sendInfo : sendInfos) {
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++; sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++;
auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder); auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder);
@@ -269,10 +268,10 @@ private:
} }
void planReceiveReordering() { void planReceiveReordering() {
DenseMap<uint64_t, SmallVector<RemoteSendInfo *>> reorderedSendsByPair; DenseMap<uint64_t, SmallVector<RemoteSendInfo*>> reorderedSendsByPair;
for (auto &taskSends : remoteSendsByTask) { for (auto& taskSends : remoteSendsByTask) {
for (auto &sendInfos : taskSends.second) { for (auto& sendInfos : taskSends.second) {
for (RemoteSendInfo &sendInfo : sendInfos) { for (RemoteSendInfo& sendInfo : sendInfos) {
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
if (pairsNeedingReceiveReorder.contains(pairKey)) if (pairsNeedingReceiveReorder.contains(pairKey))
reorderedSendsByPair[pairKey].push_back(&sendInfo); reorderedSendsByPair[pairKey].push_back(&sendInfo);
@@ -280,13 +279,13 @@ private:
} }
} }
for (auto &pairSends : reorderedSendsByPair) { for (auto& pairSends : reorderedSendsByPair) {
llvm::stable_sort(pairSends.second, [](const RemoteSendInfo *lhs, const RemoteSendInfo *rhs) { llvm::stable_sort(pairSends.second, [](const RemoteSendInfo* lhs, const RemoteSendInfo* rhs) {
if (lhs->sourceOrder != rhs->sourceOrder) if (lhs->sourceOrder != rhs->sourceOrder)
return lhs->sourceOrder < rhs->sourceOrder; return lhs->sourceOrder < rhs->sourceOrder;
return lhs->channelInfo.channelId < rhs->channelInfo.channelId; return lhs->channelInfo.channelId < rhs->channelInfo.channelId;
}); });
for (RemoteSendInfo *sendInfo : pairSends.second) { for (RemoteSendInfo* sendInfo : pairSends.second) {
int64_t channelId = (*nextChannelId)++; int64_t channelId = (*nextChannelId)++;
sendInfo->channelInfo.channelId = channelId; sendInfo->channelInfo.channelId = channelId;
auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer); auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer);
@@ -297,9 +296,9 @@ private:
} }
} }
for (const auto &taskSends : remoteSendsByTask) { for (const auto& taskSends : remoteSendsByTask) {
for (const auto &sendInfos : taskSends.second) { for (const auto& sendInfos : taskSends.second) {
for (const RemoteSendInfo &sendInfo : sendInfos) { for (const RemoteSendInfo& sendInfo : sendInfos) {
auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer); auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer);
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send"); assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send");
assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range"); assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
@@ -309,9 +308,9 @@ private:
} }
} }
for (auto &taskSends : remoteSendsByTask) { for (auto& taskSends : remoteSendsByTask) {
for (const auto &sendInfos : taskSends.second) { for (const auto& sendInfos : taskSends.second) {
for (const RemoteSendInfo &sendInfo : sendInfos) { for (const RemoteSendInfo& sendInfo : sendInfos) {
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
if (!pairsNeedingReceiveReorder.contains(pairKey)) if (!pairsNeedingReceiveReorder.contains(pairKey))
continue; continue;
@@ -322,9 +321,9 @@ private:
} }
} }
for (auto &cpuQueues : receiveQueuesByCpu) { for (auto& cpuQueues : receiveQueuesByCpu) {
for (auto &pairQueue : cpuQueues.second) { for (auto& pairQueue : cpuQueues.second) {
llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry &lhs, const RemoteReceiveEntry &rhs) { llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry& lhs, const RemoteReceiveEntry& rhs) {
if (lhs.sourceOrder != rhs.sourceOrder) if (lhs.sourceOrder != rhs.sourceOrder)
return lhs.sourceOrder < rhs.sourceOrder; return lhs.sourceOrder < rhs.sourceOrder;
return lhs.channelInfo.channelId < rhs.channelInfo.channelId; return lhs.channelInfo.channelId < rhs.channelInfo.channelId;
@@ -344,8 +343,8 @@ private:
SmallVector<Type> resultTypes; SmallVector<Type> resultTypes;
resultTypes.reserve(cpuExternalOutputs[cpu].size()); resultTypes.reserve(cpuExternalOutputs[cpu].size());
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) { for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
ScheduledTask task = taskByKey.at(outputRef.instance); ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
resultTypes.push_back(getComputeInstanceOutputTypes(task.key)[outputRef.resultIndex]); resultTypes.push_back(getComputeInstanceOutputTypes(task.computeInstance)[outputRef.resultIndex]);
} }
rewriter.setInsertionPoint(returnOp); rewriter.setInsertionPoint(returnOp);
@@ -362,7 +361,7 @@ private:
blockArgTypes.push_back(input.getType()); blockArgTypes.push_back(input.getType());
blockArgLocs.push_back(loc); blockArgLocs.push_back(loc);
} }
Block *newBlock = Block* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
CpuProgram program; CpuProgram program;
@@ -373,19 +372,19 @@ private:
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu])) for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu]))
program.externalInputMap[input] = newBlock->getArgument(inputIndex); program.externalInputMap[input] = newBlock->getArgument(inputIndex);
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) { for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) {
ScheduledTask task = taskByKey.at(outputRef.instance); ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.key)[outputRef.resultIndex]] = oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
newCompute.getResult(resultIndex); newCompute.getResult(resultIndex);
} }
cpuPrograms[cpu] = std::move(program); cpuPrograms[cpu] = std::move(program);
} }
} }
FailureOr<Value> receiveThroughInput(IRRewriter &rewriter, FailureOr<Value> receiveThroughInput(IRRewriter& rewriter,
size_t cpu, size_t cpu,
DenseMap<uint64_t, size_t> &receiveQueueIndices, DenseMap<uint64_t, size_t>& receiveQueueIndices,
DenseMap<ComputeInstance, SmallVector<Value>> &preReceivedInputsByTask, DenseMap<ComputeInstance, SmallVector<Value>>& preReceivedInputsByTask,
const ChannelInfo &requestedChannelInfo, const ChannelInfo& requestedChannelInfo,
ComputeInstance requestedConsumer, ComputeInstance requestedConsumer,
size_t requestedInputIndex) { size_t requestedInputIndex) {
uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo); uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo);
@@ -396,26 +395,25 @@ private:
if (queueIt == cpuQueuesIt->second.end()) if (queueIt == cpuQueuesIt->second.end())
return failure(); return failure();
auto &queue = queueIt->second; auto& queue = queueIt->second;
size_t &queueIndex = receiveQueueIndices[pairKey]; size_t& queueIndex = receiveQueueIndices[pairKey];
while (queueIndex < queue.size()) { while (queueIndex < queue.size()) {
const RemoteReceiveEntry &entry = queue[queueIndex++]; const RemoteReceiveEntry& entry = queue[queueIndex++];
auto consumerTaskIt = taskByKey.find(entry.consumer); auto consumerTaskIt = taskByComputeInstance.find(entry.consumer);
if (consumerTaskIt == taskByKey.end()) if (consumerTaskIt == taskByComputeInstance.end())
return failure(); return failure();
SmallVector<Value> consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.key); SmallVector<Value> consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.computeInstance);
if (consumerInputs.size() <= entry.inputIndex) if (consumerInputs.size() <= entry.inputIndex)
return failure(); return failure();
Type inputType = consumerInputs[entry.inputIndex].getType(); Type inputType = consumerInputs[entry.inputIndex].getType();
auto receive = auto receive = spatial::SpatChannelReceiveOp::create(rewriter,
spatial::SpatChannelReceiveOp::create(rewriter,
loc, loc,
inputType, inputType,
rewriter.getI64IntegerAttr(entry.channelInfo.channelId), rewriter.getI64IntegerAttr(entry.channelInfo.channelId),
rewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId), rewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId),
rewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId)); rewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId));
auto &receivedInputs = preReceivedInputsByTask[entry.consumer]; auto& receivedInputs = preReceivedInputsByTask[entry.consumer];
if (receivedInputs.size() <= entry.inputIndex) if (receivedInputs.size() <= entry.inputIndex)
receivedInputs.resize(entry.inputIndex + 1); receivedInputs.resize(entry.inputIndex + 1);
receivedInputs[entry.inputIndex] = receive.getResult(); receivedInputs[entry.inputIndex] = receive.getResult();
@@ -428,7 +426,7 @@ private:
LogicalResult cloneTaskBodies() { LogicalResult cloneTaskBodies() {
for (size_t cpu : orderedCpus) { for (size_t cpu : orderedCpus) {
CpuProgram &program = cpuPrograms[cpu]; CpuProgram& program = cpuPrograms[cpu];
IRRewriter rewriter(func.getContext()); IRRewriter rewriter(func.getContext());
rewriter.setInsertionPointToEnd(program.block); rewriter.setInsertionPointToEnd(program.block);
DenseMap<uint64_t, size_t> receiveQueueIndices; DenseMap<uint64_t, size_t> receiveQueueIndices;
@@ -444,25 +442,24 @@ private:
return value; return value;
}; };
for (const ScheduledTask &task : tasksByCpu[cpu]) { for (const ScheduledTask& task : tasksByCpu[cpu]) {
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.key); SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance);
auto taskWeights = getComputeInstanceWeights(task.key); auto taskWeights = getComputeInstanceWeights(task.computeInstance);
Block &templateBlock = getComputeInstanceTemplateBlock(task.key); Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance);
SmallVector<Value> resolvedInputs; SmallVector<Value> resolvedInputs;
resolvedInputs.reserve(taskInputs.size()); resolvedInputs.reserve(taskInputs.size());
auto remoteInputsIt = remoteInputsByTask.find(task.key); auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance);
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
auto producerRef = getProducerValueRef(input); auto producerRef = getProducerValueRef(input);
if (producerRef) { if (producerRef) {
auto producerIt = taskByKey.find(producerRef->instance); auto producerIt = taskByComputeInstance.find(producerRef->instance);
if (producerIt != taskByKey.end()) { if (producerIt != taskByComputeInstance.end()) {
if (producerIt->second.cpu == cpu) { if (producerIt->second.cpu == cpu) {
auto producedIt = producedValuesByTask.find(producerRef->instance); auto producedIt = producedValuesByTask.find(producerRef->instance);
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) { if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
task.sourceOp->emitOpError("missing local producer value during per-cpu merge materialization") task.sourceOp->emitOpError("missing local producer value during per-cpu merge materialization")
<< " consumerCpu=" << cpu << " consumerSlot=" << task.slot << " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
<< " producerCpu=" << producerIt->second.cpu << " producerSlot=" << producerIt->second.slot
<< " producerLaneStart=" << producerRef->instance.laneStart << " producerLaneStart=" << producerRef->instance.laneStart
<< " producerLaneCount=" << producerRef->instance.laneCount; << " producerLaneCount=" << producerRef->instance.laneCount;
return failure(); return failure();
@@ -470,20 +467,24 @@ private:
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]); resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
continue; continue;
} }
const ChannelInfo &channelInfo = *remoteInputsIt->second[inputIndex]; const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
uint64_t pairKey = getRemoteSendPairKey(channelInfo); uint64_t pairKey = getRemoteSendPairKey(channelInfo);
if (pairsNeedingReceiveReorder.contains(pairKey)) { if (pairsNeedingReceiveReorder.contains(pairKey)) {
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.key, inputIndex)) { if (std::optional<Value> preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) {
resolvedInputs.push_back(*preReceived); resolvedInputs.push_back(*preReceived);
continue; continue;
} }
FailureOr<Value> received = receiveThroughInput( FailureOr<Value> received = receiveThroughInput(rewriter,
rewriter, cpu, receiveQueueIndices, preReceivedInputsByTask, channelInfo, task.key, inputIndex); cpu,
receiveQueueIndices,
preReceivedInputsByTask,
channelInfo,
task.computeInstance,
inputIndex);
if (failed(received)) { if (failed(received)) {
task.sourceOp->emitOpError("failed to materialize reordered remote receive") task.sourceOp->emitOpError("failed to materialize reordered remote receive")
<< " consumerCpu=" << cpu << " consumerSlot=" << task.slot << " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId
<< " sourceCoreId=" << channelInfo.sourceCoreId << " targetCoreId=" << channelInfo.targetCoreId << " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId;
<< " channelId=" << channelInfo.channelId;
return failure(); return failure();
} }
resolvedInputs.push_back(*received); resolvedInputs.push_back(*received);
@@ -510,14 +511,14 @@ private:
for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments())) for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments()))
mapper.map(oldArg, resolvedInputs[argIndex]); mapper.map(oldArg, resolvedInputs[argIndex]);
for (Operation &op : templateBlock) { for (Operation& op : templateBlock) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) { if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
for (Value yieldOperand : yield.getOperands()) for (Value yieldOperand : yield.getOperands())
taskYieldValues.push_back(mapper.lookup(yieldOperand)); taskYieldValues.push_back(mapper.lookup(yieldOperand));
continue; continue;
} }
Operation *clonedOp = rewriter.clone(op, mapper); Operation* clonedOp = rewriter.clone(op, mapper);
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) { if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp); auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()]; Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
@@ -529,24 +530,24 @@ private:
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight)); newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
} }
} }
} else { }
for (size_t laneOffset = 0; laneOffset < task.key.laneCount; ++laneOffset) { else {
for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) {
IRMapping mapper; IRMapping mapper;
if (templateBlock.getNumArguments() == 1) if (templateBlock.getNumArguments() == 1)
mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]); mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]);
for (Operation &op : templateBlock) { for (Operation& op : templateBlock) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) { if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
for (Value yieldOperand : yield.getOperands()) for (Value yieldOperand : yield.getOperands())
taskYieldValues.push_back(mapper.lookup(yieldOperand)); taskYieldValues.push_back(mapper.lookup(yieldOperand));
continue; continue;
} }
Operation *clonedOp = rewriter.clone(op, mapper); Operation* clonedOp = rewriter.clone(op, mapper);
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) { if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
if (oldWeightedMvmOp.getWeightIndex() != 0) { if (oldWeightedMvmOp.getWeightIndex() != 0) {
task.sourceOp->emitOpError( task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
"batched per-cpu merge materialization expects lane-local weight index 0");
return failure(); return failure();
} }
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp); auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
@@ -554,8 +555,7 @@ private:
} }
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) { if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
if (oldWeightedVmmOp.getWeightIndex() != 0) { if (oldWeightedVmmOp.getWeightIndex() != 0) {
task.sourceOp->emitOpError( task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
"batched per-cpu merge materialization expects lane-local weight index 0");
return failure(); return failure();
} }
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp); auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
@@ -565,13 +565,13 @@ private:
} }
} }
producedValuesByTask[task.key] = taskYieldValues; producedValuesByTask[task.computeInstance] = taskYieldValues;
if (auto sendsIt = remoteSendsByTask.find(task.key); sendsIt != remoteSendsByTask.end()) { if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) {
for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) { for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) {
if (sendInfos.empty()) if (sendInfos.empty())
continue; continue;
Value producedValue = taskYieldValues[resultIndex]; Value producedValue = taskYieldValues[resultIndex];
for (const RemoteSendInfo &sendInfo : sendInfos) { for (const RemoteSendInfo& sendInfo : sendInfos) {
spatial::SpatChannelSendOp::create(rewriter, spatial::SpatChannelSendOp::create(rewriter,
loc, loc,
rewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId), rewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId),
@@ -588,9 +588,9 @@ private:
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) { for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
auto producedIt = producedValuesByTask.find(outputRef.instance); auto producedIt = producedValuesByTask.find(outputRef.instance);
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) { if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
ScheduledTask task = taskByKey.at(outputRef.instance); ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
task.sourceOp->emitOpError("missing yielded external value during per-cpu merge materialization") task.sourceOp->emitOpError("missing yielded external value during per-cpu merge materialization")
<< " cpu=" << cpu << " slot=" << task.slot << " laneStart=" << outputRef.instance.laneStart; << " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart;
return failure(); return failure();
} }
yieldValues.push_back(producedIt->second[outputRef.resultIndex]); yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
@@ -603,31 +603,31 @@ private:
void replaceExternalUses() { void replaceExternalUses() {
for (auto [oldValue, newValue] : oldToNewExternalValueMap) { for (auto [oldValue, newValue] : oldToNewExternalValueMap) {
for (auto &use : llvm::make_early_inc_range(oldValue.getUses())) for (auto& use : llvm::make_early_inc_range(oldValue.getUses()))
if (!toEraseSet.contains(use.getOwner())) if (!oldComputeOps.contains(use.getOwner()))
use.assign(newValue); use.assign(newValue);
} }
} }
LogicalResult eraseOldScheduledOps() { LogicalResult eraseOldScheduledOps() {
DenseSet<Operation *> allOpsToErase = toEraseSet; DenseSet<Operation*> allOpsToErase = oldComputeOps;
for (Operation *op : internalInputOpsToErase) for (Operation* op : internalInputOpsToErase)
allOpsToErase.insert(op); allOpsToErase.insert(op);
SmallVector<Operation *> orderedOpsToErase; SmallVector<Operation*> orderedOpsToErase;
for (Operation &op : func.getBody().front()) for (Operation& op : func.getBody().front())
if (allOpsToErase.contains(&op)) if (allOpsToErase.contains(&op))
orderedOpsToErase.push_back(&op); orderedOpsToErase.push_back(&op);
for (Operation *op : llvm::reverse(orderedOpsToErase)) { for (Operation* op : llvm::reverse(orderedOpsToErase)) {
SmallVector<Operation *> remainingUsers; SmallVector<Operation*> remainingUsers;
for (Value result : op->getResults()) for (Value result : op->getResults())
for (Operation *user : result.getUsers()) for (Operation* user : result.getUsers())
remainingUsers.push_back(user); remainingUsers.push_back(user);
if (!remainingUsers.empty()) { if (!remainingUsers.empty()) {
InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup") InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup")
<< "; erase-set=" << (allOpsToErase.contains(op) ? "yes" : "no"); << "; erase-set=" << (allOpsToErase.contains(op) ? "yes" : "no");
for (Operation *user : remainingUsers) { for (Operation* user : remainingUsers) {
diagnostic.attachNote(user->getLoc()) diagnostic.attachNote(user->getLoc())
<< "remaining user " << user->getName() << "; erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no"); << "remaining user " << user->getName() << "; erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no");
} }
@@ -640,32 +640,32 @@ private:
} }
void moveExternalUsersBeforeReturn() { void moveExternalUsersBeforeReturn() {
SmallVector<Operation *> orderedUsersToMove; SmallVector<Operation*> orderedUsersToMove;
for (Operation &op : func.getBody().front()) { for (Operation& op : func.getBody().front()) {
if (&op == returnOp.getOperation()) if (&op == returnOp.getOperation())
break; break;
if (externalUsersToMove.contains(&op)) if (externalUsersToMove.contains(&op))
orderedUsersToMove.push_back(&op); orderedUsersToMove.push_back(&op);
} }
for (Operation *op : orderedUsersToMove) for (Operation* op : orderedUsersToMove)
op->moveBefore(returnOp); op->moveBefore(returnOp);
} }
func::FuncOp func; func::FuncOp func;
const MergeScheduleResult *schedule = nullptr; const MergeScheduleResult* schedule = nullptr;
int64_t *nextChannelId = nullptr; int64_t* nextChannelId = nullptr;
Location loc; Location loc;
func::ReturnOp returnOp; func::ReturnOp returnOp;
SmallVector<ScheduledTask> scheduledTasks; SmallVector<ScheduledTask> scheduledTasks;
DenseSet<Operation *> toEraseSet; DenseSet<Operation*> oldComputeOps;
DenseMap<ComputeInstance, ScheduledTask> taskByKey; DenseMap<ComputeInstance, ScheduledTask> taskByComputeInstance;
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu; DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
SmallVector<size_t> orderedCpus; SmallVector<size_t> orderedCpus;
DenseSet<size_t> seenCpus; DenseSet<size_t> seenCpus;
DenseSet<Operation *> internalInputOpsToErase; DenseSet<Operation*> internalInputOpsToErase;
DenseMap<Operation *, bool> isInternalInputOpCache; DenseMap<Operation*, bool> isInternalInputOpCache;
DenseSet<Operation *> externalUsersToMove; DenseSet<Operation*> externalUsersToMove;
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask; DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask; DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs; DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
@@ -683,7 +683,7 @@ private:
} // namespace } // namespace
/// Public entry point: builds a one-shot `MergeScheduleMaterializerImpl` for
/// `func` and forwards the schedule and the channel-id counter to its `run`.
/// Returns whatever the implementation's `run` returns.
LogicalResult
MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) {
  MergeScheduleMaterializerImpl materializer(func);
  return materializer.run(schedule, nextChannelId);
}