diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 298e9c8..8dabf3f 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -1,5 +1,3 @@ -#include "MaterializeMergeSchedule.hpp" - #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" @@ -16,6 +14,7 @@ #include #include +#include "MaterializeMergeSchedule.hpp" #include "Scheduling/ComputeInstanceUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -41,11 +40,9 @@ static int32_t getPhysicalCoreId(size_t schedulerCpu) { return static_cast(funcOp.getBody().front().getTerminator())) {} + : func(funcOp), loc(funcOp.getLoc()), returnOp(cast(funcOp.getBody().front().getTerminator())) {} - LogicalResult run(const MergeScheduleResult &scheduleResult, int64_t &nextChannelIdRef) { + LogicalResult run(const MergeScheduleResult& scheduleResult, int64_t& nextChannelIdRef) { schedule = &scheduleResult; nextChannelId = &nextChannelIdRef; @@ -66,10 +63,9 @@ public: private: struct ScheduledTask { - ComputeInstance key; - Operation *sourceOp = nullptr; + ComputeInstance computeInstance; + Operation* sourceOp = nullptr; size_t cpu = 0; - size_t slot = 0; size_t order = 0; size_t executionOrder = 0; }; @@ -82,7 +78,7 @@ private: struct CpuProgram { SpatCompute op; - Block *block = nullptr; + Block* block = nullptr; DenseMap externalInputMap; DenseMap weightToIndex; }; @@ -102,17 +98,17 @@ private: size_t sourceOrder = 0; }; - static uint64_t getRemoteSendPairKey(const ChannelInfo &channelInfo) { + static uint64_t getRemoteSendPairKey(const ChannelInfo& channelInfo) { return (static_cast(static_cast(channelInfo.sourceCoreId)) << 32) | static_cast(channelInfo.targetCoreId); } - static void appendUniqueValue(SmallVectorImpl &values, DenseSet &seen, Value value) { + static void appendUniqueValue(SmallVectorImpl& values, DenseSet& seen, Value value) { if (seen.insert(value).second) values.push_back(value); } - bool isInternalInputOp(Operation *op) { + bool isOldComputeResult(Operation* op) { auto it = isInternalInputOpCache.find(op); if (it != isInternalInputOpCache.end()) return it->second; @@ -122,10 +118,10 @@ private: return isInternalInputOpCache[op] = false; for (Value result : extract->getResults()) { - for (Operation *user : result.getUsers()) { - if (toEraseSet.contains(user)) + for (Operation* user : result.getUsers()) { + if (oldComputeOps.contains(user)) continue; - if (isInternalInputOp(user)) + if (isOldComputeResult(user)) continue; return isInternalInputOpCache[op] = false; } @@ -134,21 +130,22 @@ private: } void collectInternalInputOps(Value value) { - Operation *op = value.getDefiningOp(); + Operation* op = value.getDefiningOp(); + //TODO ExtractSliceOp is not the only legal host op to traverse! dio while (auto extract = dyn_cast_if_present(op)) { - if (isInternalInputOp(extract.getOperation())) + if (isOldComputeResult(extract.getOperation())) internalInputOpsToErase.insert(extract.getOperation()); value = extract.getSource(); op = value.getDefiningOp(); } } - void collectExternalUsers(Operation *op) { + void collectExternalUsers(Operation* op) { if (!externalUsersToMove.insert(op).second) return; for (Value result : op->getResults()) { - for (Operation *user : result.getUsers()) { - if (toEraseSet.contains(user) || isa(user)) + for (Operation* user : result.getUsers()) { + if (oldComputeOps.contains(user) || isa(user)) continue; collectExternalUsers(user); } @@ -158,10 +155,12 @@ private: void collectScheduledTasks() { size_t nextOrder = 0; for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) { - toEraseSet.insert(scheduledInstance.op); - scheduledTasks.push_back( - {scheduledInstance, scheduledInstance.op, schedule->computeToCpuMap.lookup(scheduledInstance), - schedule->computeToCpuSlotMap.lookup(scheduledInstance), nextOrder++}); + oldComputeOps.insert(scheduledInstance.op); + scheduledTasks.push_back({scheduledInstance, + scheduledInstance.op, + schedule->computeToCpuMap.lookup(scheduledInstance), + schedule->computeToCpuSlotMap.lookup(scheduledInstance), + nextOrder++}); } } @@ -171,42 +170,42 @@ private: orderedCpus.push_back(cpu); }; - for (const ScheduledTask &task : scheduledTasks) { - taskByKey[task.key] = task; + for (const ScheduledTask& task : scheduledTasks) { + taskByComputeInstance[task.computeInstance] = task; tasksByCpu[task.cpu].push_back(task); markCpuSeen(task.cpu); } llvm::sort(orderedCpus); for (size_t cpu : orderedCpus) { - llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask &lhs, const ScheduledTask &rhs) { - if (lhs.slot != rhs.slot) - return lhs.slot < rhs.slot; - return lhs.order < rhs.order; - }); + llvm::stable_sort(tasksByCpu[cpu], + [&](const ScheduledTask& lhs, const ScheduledTask& rhs) { return lhs.order < rhs.order; }); for (auto [executionOrder, task] : llvm::enumerate(tasksByCpu[cpu])) { task.executionOrder = executionOrder; - taskByKey[task.key].executionOrder = executionOrder; + taskByComputeInstance[task.computeInstance].executionOrder = executionOrder; } } } void collectExternalInputsAndWeights() { for (size_t cpu : orderedCpus) { - for (const ScheduledTask &task : tasksByCpu[cpu]) { - auto taskWeights = getComputeInstanceWeights(task.key); + for (const ScheduledTask& task : tasksByCpu[cpu]) { + auto& thisCpuWeights = cpuWeights[cpu]; + auto& thisSeenWeights = seenWeightsByCpu[cpu]; + auto taskWeights = getComputeInstanceWeights(task.computeInstance); for (Value weight : taskWeights) - appendUniqueValue(cpuWeights[cpu], seenWeightsByCpu[cpu], weight); + if (thisSeenWeights.insert(weight).second) + thisCpuWeights.push_back(weight); - auto taskInputs = getComputeInstanceInputs(task.key); - auto &remoteInputs = remoteInputsByTask[task.key]; + auto taskInputs = getComputeInstanceInputs(task.computeInstance); + auto& remoteInputs = remoteInputsByTask[task.computeInstance]; remoteInputs.resize(taskInputs.size()); for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { auto producerRef = getProducerValueRef(input); if (producerRef) { collectInternalInputOps(input); - auto producerIt = taskByKey.find(producerRef->instance); - if (producerIt != taskByKey.end()) { + auto producerIt = taskByComputeInstance.find(producerRef->instance); + if (producerIt != taskByComputeInstance.end()) { if (producerIt->second.cpu != cpu) { ChannelInfo info { (*nextChannelId)++, @@ -214,11 +213,11 @@ private: getPhysicalCoreId(cpu), }; remoteInputs[inputIndex] = info; - auto &perResultChannels = remoteSendsByTask[producerRef->instance]; + auto& perResultChannels = remoteSendsByTask[producerRef->instance]; if (perResultChannels.empty()) - perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.key).size()); + perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size()); perResultChannels[producerRef->resultIndex].push_back( - {info, task.key, inputIndex, task.executionOrder, 0}); + {info, task.computeInstance, inputIndex, task.executionOrder, 0}); } continue; } @@ -226,19 +225,19 @@ private: appendUniqueValue(cpuExternalInputs[cpu], seenExternalInputsByCpu[cpu], input); } - auto taskOutputs = getComputeInstanceOutputValues(task.key); + auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance); for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) { bool hasExternalUser = false; - for (auto &use : output.getUses()) { - Operation *useOwner = use.getOwner(); - if (toEraseSet.contains(useOwner)) + for (auto& use : output.getUses()) { + Operation* useOwner = use.getOwner(); + if (oldComputeOps.contains(useOwner)) continue; hasExternalUser = true; if (!isa(useOwner)) collectExternalUsers(useOwner); } if (hasExternalUser) - cpuExternalOutputs[cpu].push_back({task.key, resultIndex}); + cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex}); } } } @@ -248,12 +247,12 @@ private: for (size_t cpu : orderedCpus) { DenseMap nextSourceOrderByPair; DenseMap lastConsumerOrderByPair; - for (const ScheduledTask &task : tasksByCpu[cpu]) { - auto sendsIt = remoteSendsByTask.find(task.key); + for (const ScheduledTask& task : tasksByCpu[cpu]) { + auto sendsIt = remoteSendsByTask.find(task.computeInstance); if (sendsIt == remoteSendsByTask.end()) continue; - for (auto &sendInfos : sendsIt->second) { - for (RemoteSendInfo &sendInfo : sendInfos) { + for (auto& sendInfos : sendsIt->second) { + for (RemoteSendInfo& sendInfo : sendInfos) { uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++; auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder); @@ -269,10 +268,10 @@ private: } void planReceiveReordering() { - DenseMap> reorderedSendsByPair; - for (auto &taskSends : remoteSendsByTask) { - for (auto &sendInfos : taskSends.second) { - for (RemoteSendInfo &sendInfo : sendInfos) { + DenseMap> reorderedSendsByPair; + for (auto& taskSends : remoteSendsByTask) { + for (auto& sendInfos : taskSends.second) { + for (RemoteSendInfo& sendInfo : sendInfos) { uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); if (pairsNeedingReceiveReorder.contains(pairKey)) reorderedSendsByPair[pairKey].push_back(&sendInfo); @@ -280,13 +279,13 @@ private: } } - for (auto &pairSends : reorderedSendsByPair) { - llvm::stable_sort(pairSends.second, [](const RemoteSendInfo *lhs, const RemoteSendInfo *rhs) { + for (auto& pairSends : reorderedSendsByPair) { + llvm::stable_sort(pairSends.second, [](const RemoteSendInfo* lhs, const RemoteSendInfo* rhs) { if (lhs->sourceOrder != rhs->sourceOrder) return lhs->sourceOrder < rhs->sourceOrder; return lhs->channelInfo.channelId < rhs->channelInfo.channelId; }); - for (RemoteSendInfo *sendInfo : pairSends.second) { + for (RemoteSendInfo* sendInfo : pairSends.second) { int64_t channelId = (*nextChannelId)++; sendInfo->channelInfo.channelId = channelId; auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer); @@ -297,9 +296,9 @@ private: } } - for (const auto &taskSends : remoteSendsByTask) { - for (const auto &sendInfos : taskSends.second) { - for (const RemoteSendInfo &sendInfo : sendInfos) { + for (const auto& taskSends : remoteSendsByTask) { + for (const auto& sendInfos : taskSends.second) { + for (const RemoteSendInfo& sendInfo : sendInfos) { auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer); assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send"); assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range"); @@ -309,9 +308,9 @@ private: } } - for (auto &taskSends : remoteSendsByTask) { - for (const auto &sendInfos : taskSends.second) { - for (const RemoteSendInfo &sendInfo : sendInfos) { + for (auto& taskSends : remoteSendsByTask) { + for (const auto& sendInfos : taskSends.second) { + for (const RemoteSendInfo& sendInfo : sendInfos) { uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); if (!pairsNeedingReceiveReorder.contains(pairKey)) continue; @@ -322,9 +321,9 @@ private: } } - for (auto &cpuQueues : receiveQueuesByCpu) { - for (auto &pairQueue : cpuQueues.second) { - llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry &lhs, const RemoteReceiveEntry &rhs) { + for (auto& cpuQueues : receiveQueuesByCpu) { + for (auto& pairQueue : cpuQueues.second) { + llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry& lhs, const RemoteReceiveEntry& rhs) { if (lhs.sourceOrder != rhs.sourceOrder) return lhs.sourceOrder < rhs.sourceOrder; return lhs.channelInfo.channelId < rhs.channelInfo.channelId; @@ -344,8 +343,8 @@ private: SmallVector resultTypes; resultTypes.reserve(cpuExternalOutputs[cpu].size()); for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) { - ScheduledTask task = taskByKey.at(outputRef.instance); - resultTypes.push_back(getComputeInstanceOutputTypes(task.key)[outputRef.resultIndex]); + ScheduledTask task = taskByComputeInstance.at(outputRef.instance); + resultTypes.push_back(getComputeInstanceOutputTypes(task.computeInstance)[outputRef.resultIndex]); } rewriter.setInsertionPoint(returnOp); @@ -362,7 +361,7 @@ private: blockArgTypes.push_back(input.getType()); blockArgLocs.push_back(loc); } - Block *newBlock = + Block* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); CpuProgram program; @@ -373,19 +372,19 @@ private: for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu])) program.externalInputMap[input] = newBlock->getArgument(inputIndex); for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) { - ScheduledTask task = taskByKey.at(outputRef.instance); - oldToNewExternalValueMap[getComputeInstanceOutputValues(task.key)[outputRef.resultIndex]] = + ScheduledTask task = taskByComputeInstance.at(outputRef.instance); + oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] = newCompute.getResult(resultIndex); } cpuPrograms[cpu] = std::move(program); } } - FailureOr receiveThroughInput(IRRewriter &rewriter, + FailureOr receiveThroughInput(IRRewriter& rewriter, size_t cpu, - DenseMap &receiveQueueIndices, - DenseMap> &preReceivedInputsByTask, - const ChannelInfo &requestedChannelInfo, + DenseMap& receiveQueueIndices, + DenseMap>& preReceivedInputsByTask, + const ChannelInfo& requestedChannelInfo, ComputeInstance requestedConsumer, size_t requestedInputIndex) { uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo); @@ -396,26 +395,25 @@ private: if (queueIt == cpuQueuesIt->second.end()) return failure(); - auto &queue = queueIt->second; - size_t &queueIndex = receiveQueueIndices[pairKey]; + auto& queue = queueIt->second; + size_t& queueIndex = receiveQueueIndices[pairKey]; while (queueIndex < queue.size()) { - const RemoteReceiveEntry &entry = queue[queueIndex++]; - auto consumerTaskIt = taskByKey.find(entry.consumer); - if (consumerTaskIt == taskByKey.end()) + const RemoteReceiveEntry& entry = queue[queueIndex++]; + auto consumerTaskIt = taskByComputeInstance.find(entry.consumer); + if (consumerTaskIt == taskByComputeInstance.end()) return failure(); - SmallVector consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.key); + SmallVector consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.computeInstance); if (consumerInputs.size() <= entry.inputIndex) return failure(); Type inputType = consumerInputs[entry.inputIndex].getType(); - auto receive = - spatial::SpatChannelReceiveOp::create(rewriter, - loc, - inputType, - rewriter.getI64IntegerAttr(entry.channelInfo.channelId), - rewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId), - rewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId)); + auto receive = spatial::SpatChannelReceiveOp::create(rewriter, + loc, + inputType, + rewriter.getI64IntegerAttr(entry.channelInfo.channelId), + rewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId), + rewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId)); - auto &receivedInputs = preReceivedInputsByTask[entry.consumer]; + auto& receivedInputs = preReceivedInputsByTask[entry.consumer]; if (receivedInputs.size() <= entry.inputIndex) receivedInputs.resize(entry.inputIndex + 1); receivedInputs[entry.inputIndex] = receive.getResult(); @@ -428,7 +426,7 @@ private: LogicalResult cloneTaskBodies() { for (size_t cpu : orderedCpus) { - CpuProgram &program = cpuPrograms[cpu]; + CpuProgram& program = cpuPrograms[cpu]; IRRewriter rewriter(func.getContext()); rewriter.setInsertionPointToEnd(program.block); DenseMap receiveQueueIndices; @@ -444,25 +442,24 @@ private: return value; }; - for (const ScheduledTask &task : tasksByCpu[cpu]) { - SmallVector taskInputs = getComputeInstanceInputs(task.key); - auto taskWeights = getComputeInstanceWeights(task.key); - Block &templateBlock = getComputeInstanceTemplateBlock(task.key); + for (const ScheduledTask& task : tasksByCpu[cpu]) { + SmallVector taskInputs = getComputeInstanceInputs(task.computeInstance); + auto taskWeights = getComputeInstanceWeights(task.computeInstance); + Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance); SmallVector resolvedInputs; resolvedInputs.reserve(taskInputs.size()); - auto remoteInputsIt = remoteInputsByTask.find(task.key); + auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance); for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { auto producerRef = getProducerValueRef(input); if (producerRef) { - auto producerIt = taskByKey.find(producerRef->instance); - if (producerIt != taskByKey.end()) { + auto producerIt = taskByComputeInstance.find(producerRef->instance); + if (producerIt != taskByComputeInstance.end()) { if (producerIt->second.cpu == cpu) { auto producedIt = producedValuesByTask.find(producerRef->instance); if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) { task.sourceOp->emitOpError("missing local producer value during per-cpu merge materialization") - << " consumerCpu=" << cpu << " consumerSlot=" << task.slot - << " producerCpu=" << producerIt->second.cpu << " producerSlot=" << producerIt->second.slot + << " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu << " producerLaneStart=" << producerRef->instance.laneStart << " producerLaneCount=" << producerRef->instance.laneCount; return failure(); @@ -470,20 +467,24 @@ private: resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]); continue; } - const ChannelInfo &channelInfo = *remoteInputsIt->second[inputIndex]; + const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex]; uint64_t pairKey = getRemoteSendPairKey(channelInfo); if (pairsNeedingReceiveReorder.contains(pairKey)) { - if (std::optional preReceived = lookupPreReceivedInput(task.key, inputIndex)) { + if (std::optional preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) { resolvedInputs.push_back(*preReceived); continue; } - FailureOr received = receiveThroughInput( - rewriter, cpu, receiveQueueIndices, preReceivedInputsByTask, channelInfo, task.key, inputIndex); + FailureOr received = receiveThroughInput(rewriter, + cpu, + receiveQueueIndices, + preReceivedInputsByTask, + channelInfo, + task.computeInstance, + inputIndex); if (failed(received)) { task.sourceOp->emitOpError("failed to materialize reordered remote receive") - << " consumerCpu=" << cpu << " consumerSlot=" << task.slot - << " sourceCoreId=" << channelInfo.sourceCoreId << " targetCoreId=" << channelInfo.targetCoreId - << " channelId=" << channelInfo.channelId; + << " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId + << " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId; return failure(); } resolvedInputs.push_back(*received); @@ -510,14 +511,14 @@ private: for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments())) mapper.map(oldArg, resolvedInputs[argIndex]); - for (Operation &op : templateBlock) { + for (Operation& op : templateBlock) { if (auto yield = dyn_cast(&op)) { for (Value yieldOperand : yield.getOperands()) taskYieldValues.push_back(mapper.lookup(yieldOperand)); continue; } - Operation *clonedOp = rewriter.clone(op, mapper); + Operation* clonedOp = rewriter.clone(op, mapper); if (auto oldWeightedMvmOp = dyn_cast(&op)) { auto newWeightedMvmOp = cast(clonedOp); Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()]; @@ -529,24 +530,24 @@ private: newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight)); } } - } else { - for (size_t laneOffset = 0; laneOffset < task.key.laneCount; ++laneOffset) { + } + else { + for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) { IRMapping mapper; if (templateBlock.getNumArguments() == 1) mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]); - for (Operation &op : templateBlock) { + for (Operation& op : templateBlock) { if (auto yield = dyn_cast(&op)) { for (Value yieldOperand : yield.getOperands()) taskYieldValues.push_back(mapper.lookup(yieldOperand)); continue; } - Operation *clonedOp = rewriter.clone(op, mapper); + Operation* clonedOp = rewriter.clone(op, mapper); if (auto oldWeightedMvmOp = dyn_cast(&op)) { if (oldWeightedMvmOp.getWeightIndex() != 0) { - task.sourceOp->emitOpError( - "batched per-cpu merge materialization expects lane-local weight index 0"); + task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0"); return failure(); } auto newWeightedMvmOp = cast(clonedOp); @@ -554,8 +555,7 @@ private: } if (auto oldWeightedVmmOp = dyn_cast(&op)) { if (oldWeightedVmmOp.getWeightIndex() != 0) { - task.sourceOp->emitOpError( - "batched per-cpu merge materialization expects lane-local weight index 0"); + task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0"); return failure(); } auto newWeightedVmmOp = cast(clonedOp); @@ -565,13 +565,13 @@ private: } } - producedValuesByTask[task.key] = taskYieldValues; - if (auto sendsIt = remoteSendsByTask.find(task.key); sendsIt != remoteSendsByTask.end()) { + producedValuesByTask[task.computeInstance] = taskYieldValues; + if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) { for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) { if (sendInfos.empty()) continue; Value producedValue = taskYieldValues[resultIndex]; - for (const RemoteSendInfo &sendInfo : sendInfos) { + for (const RemoteSendInfo& sendInfo : sendInfos) { spatial::SpatChannelSendOp::create(rewriter, loc, rewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId), @@ -588,9 +588,9 @@ private: for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) { auto producedIt = producedValuesByTask.find(outputRef.instance); if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) { - ScheduledTask task = taskByKey.at(outputRef.instance); + ScheduledTask task = taskByComputeInstance.at(outputRef.instance); task.sourceOp->emitOpError("missing yielded external value during per-cpu merge materialization") - << " cpu=" << cpu << " slot=" << task.slot << " laneStart=" << outputRef.instance.laneStart; + << " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart; return failure(); } yieldValues.push_back(producedIt->second[outputRef.resultIndex]); @@ -603,31 +603,31 @@ private: void replaceExternalUses() { for (auto [oldValue, newValue] : oldToNewExternalValueMap) { - for (auto &use : llvm::make_early_inc_range(oldValue.getUses())) - if (!toEraseSet.contains(use.getOwner())) + for (auto& use : llvm::make_early_inc_range(oldValue.getUses())) + if (!oldComputeOps.contains(use.getOwner())) use.assign(newValue); } } LogicalResult eraseOldScheduledOps() { - DenseSet allOpsToErase = toEraseSet; - for (Operation *op : internalInputOpsToErase) + DenseSet allOpsToErase = oldComputeOps; + for (Operation* op : internalInputOpsToErase) allOpsToErase.insert(op); - SmallVector orderedOpsToErase; - for (Operation &op : func.getBody().front()) + SmallVector orderedOpsToErase; + for (Operation& op : func.getBody().front()) if (allOpsToErase.contains(&op)) orderedOpsToErase.push_back(&op); - for (Operation *op : llvm::reverse(orderedOpsToErase)) { - SmallVector remainingUsers; + for (Operation* op : llvm::reverse(orderedOpsToErase)) { + SmallVector remainingUsers; for (Value result : op->getResults()) - for (Operation *user : result.getUsers()) + for (Operation* user : result.getUsers()) remainingUsers.push_back(user); if (!remainingUsers.empty()) { InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup") << "; erase-set=" << (allOpsToErase.contains(op) ? "yes" : "no"); - for (Operation *user : remainingUsers) { + for (Operation* user : remainingUsers) { diagnostic.attachNote(user->getLoc()) << "remaining user " << user->getName() << "; erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no"); } @@ -640,32 +640,32 @@ private: } void moveExternalUsersBeforeReturn() { - SmallVector orderedUsersToMove; - for (Operation &op : func.getBody().front()) { + SmallVector orderedUsersToMove; + for (Operation& op : func.getBody().front()) { if (&op == returnOp.getOperation()) break; if (externalUsersToMove.contains(&op)) orderedUsersToMove.push_back(&op); } - for (Operation *op : orderedUsersToMove) + for (Operation* op : orderedUsersToMove) op->moveBefore(returnOp); } func::FuncOp func; - const MergeScheduleResult *schedule = nullptr; - int64_t *nextChannelId = nullptr; + const MergeScheduleResult* schedule = nullptr; + int64_t* nextChannelId = nullptr; Location loc; func::ReturnOp returnOp; SmallVector scheduledTasks; - DenseSet toEraseSet; - DenseMap taskByKey; + DenseSet oldComputeOps; + DenseMap taskByComputeInstance; DenseMap> tasksByCpu; SmallVector orderedCpus; DenseSet seenCpus; - DenseSet internalInputOpsToErase; - DenseMap isInternalInputOpCache; - DenseSet externalUsersToMove; + DenseSet internalInputOpsToErase; + DenseMap isInternalInputOpCache; + DenseSet externalUsersToMove; DenseMap>> remoteSendsByTask; DenseMap>> remoteInputsByTask; DenseMap> cpuExternalInputs; @@ -683,7 +683,7 @@ private: } // namespace LogicalResult -MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId) { +MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) { return MergeScheduleMaterializerImpl(func).run(schedule, nextChannelId); }