diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index cabf3b6..73fdbd2 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -58,18 +58,15 @@ static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { - auto tensorType = RankedTensorType::get({static_cast(values.size())}, IndexType::get(anchorOp->getContext())); + auto tensorType = + RankedTensorType::get({static_cast(values.size())}, IndexType::get(anchorOp->getContext())); auto tensorAttr = DenseIntElementsAttr::get(tensorType, values); return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder); } -static Value createIndexTupleTensorConstant(Operation* anchorOp, - int64_t tupleCount, - int64_t tupleWidth, - ArrayRef values, - OperationFolder& folder) { - auto tensorType = - RankedTensorType::get({tupleCount, tupleWidth}, IndexType::get(anchorOp->getContext())); +static Value createIndexTupleTensorConstant( + Operation* anchorOp, int64_t tupleCount, int64_t tupleWidth, ArrayRef values, OperationFolder& folder) { + auto tensorType = RankedTensorType::get({tupleCount, tupleWidth}, IndexType::get(anchorOp->getContext())); auto tensorAttr = DenseIntElementsAttr::get(tensorType, values); return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder); } @@ -114,12 +111,14 @@ private: }; struct CpuProgram { - SpatCompute op; + Operation* op = nullptr; DenseMap externalInputMap; DenseMap weightToIndex; }; - using ProgramKey = std::pair; + using ProgramKey = size_t; // Represents the "Leader" CPU + DenseMap> batchedCpus; + SmallVector orderedPrograms; struct RemoteSendInfo { ChannelInfo channelInfo; @@ -171,7 +170,7 @@ private: | static_cast(channelInfo.targetCoreId); } - static ProgramKey getProgramKey(const ScheduledTask& task) { return {task.cpu, 0}; } + static ProgramKey getProgramKey(const ScheduledTask& task) { return task.cpu; } static bool isResultfulBatchInstance(const ComputeInstance& instance) { auto batch = dyn_cast(instance.op); @@ -377,12 +376,12 @@ private: void emitExtractRowsSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractRowsSendRun& run) { SmallVector prefixSums = buildPrefixSums(run.sendCounts); Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder); - Value channelSourceTargetTuples = createIndexTupleTensorConstant( - hostAnchor, - static_cast(run.channelSourceTargetTuples.size() / 3), - 3, - run.channelSourceTargetTuples, - constantFolder); + Value channelSourceTargetTuples = + createIndexTupleTensorConstant(hostAnchor, + static_cast(run.channelSourceTargetTuples.size() / 3), + 3, + run.channelSourceTargetTuples, + constantFolder); Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder); Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast(run.sendCounts.size()), constantFolder); @@ -413,8 +412,7 @@ private: .getResult(); Value nextRowIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step); - Value innerLower = - tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()}); + Value innerLower = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()}); Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextRowIndex}); emitInnerSendLoop(hostAnchor, rewriter, extractedRow, innerLower, innerUpper, channelSourceTargetTuples); rewriter.setInsertionPointAfter(outerLoop); @@ -423,12 +421,12 @@ private: void emitExtractSliceSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractSliceSendRun& run) { SmallVector prefixSums = buildPrefixSums(run.sendCounts); Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder); - Value channelSourceTargetTuples = createIndexTupleTensorConstant( - hostAnchor, - static_cast(run.channelSourceTargetTuples.size() / 3), - 3, - run.channelSourceTargetTuples, - constantFolder); + Value channelSourceTargetTuples = + createIndexTupleTensorConstant(hostAnchor, + static_cast(run.channelSourceTargetTuples.size() / 3), + 3, + run.channelSourceTargetTuples, + constantFolder); Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder); Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast(run.sendCounts.size()), constantFolder); @@ -469,8 +467,7 @@ private: .getResult(); Value nextSliceIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step); - Value innerLower = - tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()}); + Value innerLower = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()}); Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextSliceIndex}); emitInnerSendLoop(hostAnchor, rewriter, extractedSlice, innerLower, innerUpper, channelSourceTargetTuples); rewriter.setInsertionPointAfter(outerLoop); @@ -532,149 +529,112 @@ private: } void buildTaskIndex() { - auto markCpuSeen = [&](size_t cpu) { - if (seenCpus.insert(cpu).second) - orderedCpus.push_back(cpu); - }; - + DenseSet seen; for (const ScheduledTask& task : scheduledTasks) { taskByComputeInstance[task.computeInstance] = task; tasksByCpu[task.cpu].push_back(task); - ProgramKey programKey = getProgramKey(task); - tasksByProgram[programKey].push_back(task); - if (seenPrograms.insert(programKey).second) - orderedPrograms.push_back(programKey); - markCpuSeen(task.cpu); + seen.insert(task.cpu); } - llvm::sort(orderedCpus); - llvm::sort(orderedPrograms, [](const ProgramKey& lhs, const ProgramKey& rhs) { - if (lhs.second != rhs.second) - return lhs.second < rhs.second; - return lhs.first < rhs.first; - }); - for (size_t cpu : orderedCpus) + SmallVector activeCpus(seen.begin(), seen.end()); + llvm::sort(activeCpus); + + DenseSet batched; + for (size_t cpu : activeCpus) { + if (batched.contains(cpu)) + continue; + + SmallVector batch; + batch.push_back(cpu); + batched.insert(cpu); + + // Group all equivalent CPUs into this batch + auto it = schedule->equivalentClass.find(cpu); + if (it != schedule->equivalentClass.end()) { + for (size_t eqCpu : it->second) + if (batched.insert(eqCpu).second) + batch.push_back(eqCpu); + } + + llvm::sort(batch); + size_t leader = batch.front(); + batchedCpus[leader] = batch; + orderedPrograms.push_back(leader); + } + + for (size_t cpu : activeCpus) { llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) { return lhs.orderWithinCpu < rhs.orderWithinCpu; }); - for (ProgramKey programKey : orderedPrograms) - llvm::stable_sort(tasksByProgram[programKey], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) { - return lhs.orderWithinCpu < rhs.orderWithinCpu; - }); + } } void collectExternalInputsAndWeights() { - for (ProgramKey programKey : orderedPrograms) { - size_t cpu = programKey.first; - for (const ScheduledTask& task : tasksByProgram[programKey]) { - auto& thisCpuWeights = cpuWeights[programKey]; - auto& thisSeenWeights = seenWeightsByProgram[programKey]; - auto taskWeights = getComputeInstanceWeights(task.computeInstance); - for (Value weight : taskWeights) - if (thisSeenWeights.insert(weight).second) - thisCpuWeights.push_back(weight); + for (ProgramKey leader : orderedPrograms) { + const auto& batch = batchedCpus[leader]; + auto& thisCpuWeights = cpuWeights[leader]; + auto& thisCpuInputs = cpuExternalInputs[leader]; + auto& thisCpuOutputs = cpuExternalOutputs[leader]; - auto taskInputs = getComputeInstanceInputs(task.computeInstance); - auto& remoteInputs = remoteInputsByTask[task.computeInstance]; - remoteInputs.resize(taskInputs.size()); - auto& remoteTensorInputs = remoteTensorInputsByTask[task.computeInstance]; - remoteTensorInputs.resize(taskInputs.size()); - for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { - bool isExternalInput = true; - if (auto producerBatch = dyn_cast_or_null(input.getDefiningOp()); - producerBatch && producerBatch.getNumResults() != 0) { - size_t resultIndex = cast(input).getResultNumber(); - TensorChannelInfo tensorInfo; - tensorInfo.resultIndex = resultIndex; - tensorInfo.channelIds.reserve(static_cast(producerBatch.getLaneCount())); - tensorInfo.sourceCoreIds.reserve(static_cast(producerBatch.getLaneCount())); - tensorInfo.targetCoreIds.reserve(static_cast(producerBatch.getLaneCount())); - tensorInfo.producerInstances.reserve(static_cast(producerBatch.getLaneCount())); + // Process every lane sequentially to pack operands + for (size_t cpu : batch) { + DenseSet laneSeenWeights; + DenseSet laneSeenInputs; - bool foundAllLaneProducers = true; - for (uint32_t lane = 0; lane < static_cast(producerBatch.getLaneCount()); ++lane) { - ComputeInstance producerInstance = getBatchChunkForLane(producerBatch, lane); - auto producerIt = taskByComputeInstance.find(producerInstance); - if (producerIt == taskByComputeInstance.end()) { - foundAllLaneProducers = false; - break; - } + for (const ScheduledTask& task : tasksByCpu[cpu]) { + for (Value weight : getComputeInstanceWeights(task.computeInstance)) + if (laneSeenWeights.insert(weight).second) + thisCpuWeights.push_back(weight); - ChannelInfo info { - producerIt->second.cpu == cpu ? -1 : (*nextChannelId)++, - static_cast(producerIt->second.cpu), - static_cast(cpu), - }; - tensorInfo.channelIds.push_back(info.channelId); - tensorInfo.sourceCoreIds.push_back(info.sourceCoreId); - tensorInfo.targetCoreIds.push_back(info.targetCoreId); - tensorInfo.producerInstances.push_back(producerInstance); + auto taskInputs = getComputeInstanceInputs(task.computeInstance); + auto& remoteInputs = remoteInputsByTask[task.computeInstance]; + remoteInputs.resize(taskInputs.size()); - if (producerIt->second.cpu != cpu) { - auto& perResultChannels = remoteSendsByTask[producerInstance]; - if (perResultChannels.empty()) - perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size()); - perResultChannels[resultIndex].push_back( - {info, task.computeInstance, inputIndex, task.orderWithinCpu, 0, true}); + for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { + bool isExternalInput = true; + if (auto producerRef = getProducerValueRef(input, &task.computeInstance)) { + auto producerIt = taskByComputeInstance.find(producerRef->instance); + if (producerIt != taskByComputeInstance.end()) { + isExternalInput = false; + if (producerIt->second.cpu != cpu) { + // Cross-core communication + ChannelInfo info; + info.channelId = (*nextChannelId)++; + info.sourceCoreId = static_cast(producerIt->second.cpu); + info.targetCoreId = static_cast(cpu); + + remoteInputs[inputIndex] = info; + auto& perResultChannels = remoteSendsByTask[producerRef->instance]; + if (perResultChannels.empty()) + perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size()); + + RemoteSendInfo sendInfo; + sendInfo.channelInfo = info; + sendInfo.consumer = task.computeInstance; + sendInfo.inputIndex = inputIndex; + sendInfo.consumerOrder = task.orderWithinCpu; + sendInfo.sourceOrder = 0; + sendInfo.isTensorInput = false; + perResultChannels[producerRef->resultIndex].push_back(sendInfo); + } } } + if (isExternalInput && laneSeenInputs.insert(input).second) + thisCpuInputs.push_back(input); + } - if (foundAllLaneProducers) { - remoteTensorInputs[inputIndex] = std::move(tensorInfo); - continue; + // Define the logical return types based strictly on the Leader + if (cpu == leader) { + auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance); + for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) { + bool hasExternalUser = false; + for (auto& use : output.getUses()) + if (!oldComputeOps.contains(use.getOwner())) + hasExternalUser = true; + if (hasExternalUser) + thisCpuOutputs.push_back({task.computeInstance, resultIndex}); } } - - if (auto producerRef = getProducerValueRef(input, &task.computeInstance)) { - auto producerIt = taskByComputeInstance.find(producerRef->instance); - if (producerIt != taskByComputeInstance.end()) { - isExternalInput = false; - if (producerIt->second.cpu != cpu) { - ChannelInfo info { - (*nextChannelId)++, - static_cast(producerIt->second.cpu), - static_cast(cpu), - }; - remoteInputs[inputIndex] = info; - auto& perResultChannels = remoteSendsByTask[producerRef->instance]; - if (perResultChannels.empty()) - perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size()); - perResultChannels[producerRef->resultIndex].push_back( - {info, task.computeInstance, inputIndex, task.orderWithinCpu, 0, false}); - } - } - } - if (isExternalInput && seenExternalInputsByProgram[programKey].insert(input).second) - cpuExternalInputs[programKey].push_back(input); - } - - if (isResultfulBatchInstance(task.computeInstance)) { - auto batch = cast(task.computeInstance.op); - for (unsigned resultIndex = 0; resultIndex < batch.getNumResults(); ++resultIndex) { - bool hasExternalUser = false; - for (Operation* user : batch.getResult(resultIndex).getUsers()) { - if (!oldComputeOps.contains(user)) { - hasExternalUser = true; - break; - } - } - if (hasExternalUser) - cpuExternalOutputs[programKey].push_back({task.computeInstance, resultIndex}); - } - continue; - } - - auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance); - for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) { - bool hasExternalUser = false; - for (auto& use : output.getUses()) { - Operation* useOwner = use.getOwner(); - if (oldComputeOps.contains(useOwner)) - continue; - hasExternalUser = true; - } - if (hasExternalUser) - cpuExternalOutputs[programKey].push_back({task.computeInstance, resultIndex}); } } } @@ -779,65 +739,167 @@ private: void createCpuComputeOps() { IRRewriter rewriter(func.getContext()); - for (ProgramKey programKey : orderedPrograms) { - size_t cpu = programKey.first; - SmallVector operands; - operands.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size()); - llvm::append_range(operands, cpuWeights[programKey]); - llvm::append_range(operands, cpuExternalInputs[programKey]); + for (ProgramKey leader : orderedPrograms) { + const auto& batch = batchedCpus[leader]; + bool isBatch = batch.size() > 1; SmallVector resultTypes; - resultTypes.reserve(cpuExternalOutputs[programKey].size()); - for (ProducerValueRef outputRef : cpuExternalOutputs[programKey]) { + SmallVector packedResultTypes; + + for (ProducerValueRef outputRef : cpuExternalOutputs[leader]) { ScheduledTask task = taskByComputeInstance.at(outputRef.instance); - SmallVector outputTypes = getTaskOutputTypes(task.computeInstance); - resultTypes.push_back(outputTypes[outputRef.resultIndex]); + Type elemType = getTaskOutputTypes(task.computeInstance)[outputRef.resultIndex]; + resultTypes.push_back(elemType); + + if (isBatch) { + auto ranked = cast(elemType); + SmallVector shape; + shape.push_back(static_cast(batch.size())); + shape.append(ranked.getShape().begin(), ranked.getShape().end()); + packedResultTypes.push_back(RankedTensorType::get(shape, ranked.getElementType())); + } } rewriter.setInsertionPoint(returnOp); - auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands)); - newCompute.getProperties().setOperandSegmentSizes( - {static_cast(cpuWeights[programKey].size()), static_cast(cpuExternalInputs[programKey].size())}); - newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast(cpu))); + CpuProgram program; - SmallVector blockArgTypes; - SmallVector blockArgLocs; - blockArgTypes.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size()); - blockArgLocs.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size()); - for (Value weight : cpuWeights[programKey]) { - blockArgTypes.push_back(weight.getType()); - blockArgLocs.push_back(loc); - } - for (Value input : cpuExternalInputs[programKey]) { - blockArgTypes.push_back(input.getType()); - blockArgLocs.push_back(loc); - } - Block* newBlock = + if (!isBatch) { + // Isolated CPU Execution + SmallVector operands; + operands.reserve(cpuWeights[leader].size() + cpuExternalInputs[leader].size()); + llvm::append_range(operands, cpuWeights[leader]); + llvm::append_range(operands, cpuExternalInputs[leader]); + + auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands)); + newCompute.getProperties().setOperandSegmentSizes( + {static_cast(cpuWeights[leader].size()), static_cast(cpuExternalInputs[leader].size())}); + newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast(leader))); + + SmallVector blockArgTypes; + SmallVector blockArgLocs; + for (Value op : operands) { + blockArgTypes.push_back(op.getType()); + blockArgLocs.push_back(loc); + } rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - CpuProgram program; - program.op = newCompute; - for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[programKey])) - program.weightToIndex[weight] = weightIndex; - for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[programKey])) - program.externalInputMap[input] = newCompute.getInputArgument(inputIndex); - for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[programKey])) { - ScheduledTask task = taskByComputeInstance.at(outputRef.instance); - if (isResultfulBatchInstance(task.computeInstance)) { - auto batch = cast(task.computeInstance.op); - auto& batchResults = resultfulBatchLaneResults[batch.getOperation()]; - if (batchResults.empty()) - batchResults.resize(batch.getNumResults()); - auto& laneResults = batchResults[outputRef.resultIndex]; - if (laneResults.empty()) - laneResults.resize(static_cast(batch.getLaneCount())); - laneResults[task.computeInstance.laneStart] = newCompute.getResult(resultIndex); - continue; + program.op = newCompute; + for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[leader])) + program.weightToIndex[weight] = weightIndex; + for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[leader])) + program.externalInputMap[input] = newCompute.getInputArgument(inputIndex); + + for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[leader])) { + ScheduledTask task = taskByComputeInstance.at(outputRef.instance); + if (isResultfulBatchInstance(task.computeInstance)) { + auto oldBatch = cast(task.computeInstance.op); + auto& batchResults = resultfulBatchLaneResults[oldBatch.getOperation()]; + if (batchResults.empty()) + batchResults.resize(oldBatch.getNumResults()); + auto& laneResults = batchResults[outputRef.resultIndex]; + if (laneResults.empty()) + laneResults.resize(static_cast(oldBatch.getLaneCount())); + laneResults[task.computeInstance.laneStart] = newCompute.getResult(resultIndex); + continue; + } + oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] = + newCompute.getResult(resultIndex); } - oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] = - newCompute.getResult(resultIndex); } - cpuPrograms[programKey] = std::move(program); + else { + // Equivalence Class Batch Execution + auto newBatch = SpatComputeBatch::create(rewriter, + loc, + TypeRange(packedResultTypes), + rewriter.getI32IntegerAttr(batch.size()), + cpuWeights[leader], + cpuExternalInputs[leader]); + newBatch.getProperties().setOperandSegmentSizes( + {static_cast(cpuWeights[leader].size()), static_cast(cpuExternalInputs[leader].size())}); + + SmallVector coreIds; + for (size_t c : batch) + coreIds.push_back(static_cast(c)); + newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + + SmallVector blockArgTypes; + SmallVector blockArgLocs; + blockArgTypes.push_back(rewriter.getIndexType()); // Lane ID Argument + blockArgLocs.push_back(loc); + + size_t weightsPerLane = cpuWeights[leader].size() / batch.size(); + size_t inputsPerLane = cpuExternalInputs[leader].size() / batch.size(); + + for (size_t i = 0; i < weightsPerLane; ++i) { + blockArgTypes.push_back(cpuWeights[leader][i].getType()); + blockArgLocs.push_back(loc); + } + for (size_t i = 0; i < inputsPerLane; ++i) { + blockArgTypes.push_back(cpuExternalInputs[leader][i].getType()); + blockArgLocs.push_back(loc); + } + for (Type t : packedResultTypes) { + blockArgTypes.push_back(t); + blockArgLocs.push_back(loc); + } // Dest tensors for InParallel Yield + + rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + program.op = newBatch; + + // Host-side slice extractions + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(newBatch); + for (auto [laneIndex, cpu] : llvm::enumerate(batch)) { + for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[leader])) { + size_t taskIdx = 0; + for (size_t i = 0; i < tasksByCpu[leader].size(); ++i) { + if (tasksByCpu[leader][i].computeInstance == outputRef.instance) { + taskIdx = i; + break; + } + } + ComputeInstance laneInstance = tasksByCpu[cpu][taskIdx].computeInstance; + + auto ranked = cast(resultTypes[resultIndex]); + SmallVector offsets; + offsets.push_back(rewriter.getIndexAttr(laneIndex)); + SmallVector sizes; + sizes.push_back(rewriter.getIndexAttr(1)); + SmallVector strides; + strides.push_back(rewriter.getIndexAttr(1)); + + for (int64_t dim : ranked.getShape()) { + offsets.push_back(rewriter.getIndexAttr(0)); + sizes.push_back(rewriter.getIndexAttr(dim)); + strides.push_back(rewriter.getIndexAttr(1)); + } + auto slice = tensor::ExtractSliceOp::create( + rewriter, loc, ranked, newBatch.getResult(resultIndex), offsets, sizes, strides); + + if (isResultfulBatchInstance(laneInstance)) { + auto oldBatch = cast(laneInstance.op); + auto& batchResults = resultfulBatchLaneResults[oldBatch.getOperation()]; + if (batchResults.empty()) + batchResults.resize(oldBatch.getNumResults()); + auto& laneValues = batchResults[outputRef.resultIndex]; + if (laneValues.empty()) + laneValues.resize(static_cast(oldBatch.getLaneCount())); + laneValues[laneInstance.laneStart] = slice.getResult(); + } + else { + oldToNewExternalValueMap[getComputeInstanceOutputValues(laneInstance)[outputRef.resultIndex]] = + slice.getResult(); + } + } + } + + for (size_t i = 0; i < weightsPerLane; ++i) + program.weightToIndex[cpuWeights[leader][i]] = i; + for (size_t i = 0; i < inputsPerLane; ++i) + program.externalInputMap[cpuExternalInputs[leader][i]] = + newBatch.getBody().front().getArgument(1 + weightsPerLane + i); + } + cpuPrograms[leader] = std::move(program); } } @@ -888,130 +950,105 @@ private: DenseMap> receiveQueueIndicesByCpu; DenseMap>> preReceivedInputsByCpu; - for (ProgramKey programKey : orderedPrograms) { - size_t cpu = programKey.first; - CpuProgram& program = cpuPrograms[programKey]; + auto lookupPreReceivedInput = [&](DenseMap>& preReceivedInputsByTask, + ComputeInstance consumer, + size_t inputIndex) -> std::optional { + auto inputsIt = preReceivedInputsByTask.find(consumer); + if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex) + return std::nullopt; + Value value = inputsIt->second[inputIndex]; + if (!value) + return std::nullopt; + return value; + }; + + for (ProgramKey leader : orderedPrograms) { + const auto& batch = batchedCpus[leader]; + bool isBatch = batch.size() > 1; + CpuProgram& program = cpuPrograms[leader]; + IRRewriter rewriter(func.getContext()); - rewriter.setInsertionPointToEnd(&program.op.getBody().front()); - auto& receiveQueueIndices = receiveQueueIndicesByCpu[cpu]; - auto& preReceivedInputsByTask = preReceivedInputsByCpu[cpu]; + rewriter.setInsertionPointToEnd(&program.op->getRegion(0).front()); - auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional { - auto inputsIt = preReceivedInputsByTask.find(consumer); - if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex) - return std::nullopt; - Value value = inputsIt->second[inputIndex]; - if (!value) - return std::nullopt; - return value; - }; + auto& receiveQueueIndices = receiveQueueIndicesByCpu[leader]; + auto& preReceivedInputsByTask = preReceivedInputsByCpu[leader]; - ArrayRef programTasks = tasksByProgram[programKey]; - for (size_t taskIndex = 0; taskIndex < programTasks.size(); ++taskIndex) { - const ScheduledTask& task = programTasks[taskIndex]; + ArrayRef leaderTasks = tasksByCpu[leader]; + + for (size_t taskIndex = 0; taskIndex < leaderTasks.size(); ++taskIndex) { + const ScheduledTask& task = leaderTasks[taskIndex]; SmallVector taskInputs = getComputeInstanceInputs(task.computeInstance); auto taskWeights = getComputeInstanceWeights(task.computeInstance); Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance); SmallVector resolvedInputs; resolvedInputs.reserve(taskInputs.size()); + auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance); - auto remoteTensorInputsIt = remoteTensorInputsByTask.find(task.computeInstance); + for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { - if (remoteTensorInputsIt != remoteTensorInputsByTask.end() && inputIndex < remoteTensorInputsIt->second.size() - && remoteTensorInputsIt->second[inputIndex]) { - const TensorChannelInfo& tensorInfo = *remoteTensorInputsIt->second[inputIndex]; - bool hasLocalProducer = llvm::is_contained(tensorInfo.sourceCoreIds, static_cast(cpu)); - if (!hasLocalProducer) { - auto receive = spatial::SpatChannelReceiveTensorOp::create( - rewriter, - loc, - input.getType(), - createIndexConstants(program.op, tensorInfo.channelIds, constantFolder), - createIndexConstants(program.op, tensorInfo.sourceCoreIds, constantFolder), - createIndexConstants(program.op, tensorInfo.targetCoreIds, constantFolder)); - resolvedInputs.push_back(receive.getOutput()); - continue; - } - - SmallVector laneValues; - laneValues.reserve(tensorInfo.producerInstances.size()); - for (auto [laneIndex, producerInstance] : llvm::enumerate(tensorInfo.producerInstances)) { - if (tensorInfo.sourceCoreIds[laneIndex] == static_cast(cpu)) { - auto producedIt = producedValuesByTask.find(producerInstance); - if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= tensorInfo.resultIndex) { - task.computeInstance.op->emitOpError( - "missing local tensor lane producer during merge materialization") - << " consumerCpu=" << cpu << " producerLaneStart=" << producerInstance.laneStart; - return failure(); - } - laneValues.push_back(producedIt->second[tensorInfo.resultIndex]); - continue; - } - - auto producerTaskIt = taskByComputeInstance.find(producerInstance); - if (producerTaskIt == taskByComputeInstance.end()) - return failure(); - Type laneType = getTaskOutputTypes(producerTaskIt->second.computeInstance)[tensorInfo.resultIndex]; - Value channelId = createIndexConstant(program.op, tensorInfo.channelIds[laneIndex], constantFolder); - Value sourceCoreId = createIndexConstant(program.op, tensorInfo.sourceCoreIds[laneIndex], constantFolder); - Value targetCoreId = createIndexConstant(program.op, tensorInfo.targetCoreIds[laneIndex], constantFolder); - auto receive = - spatial::SpatChannelReceiveOp::create(rewriter, loc, laneType, channelId, sourceCoreId, targetCoreId); - laneValues.push_back(receive.getResult()); - } - - Value packedInput = tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, ValueRange(laneValues)).getResult(); - resolvedInputs.push_back(packedInput); - continue; - } - auto producerRef = getProducerValueRef(input, &task.computeInstance); if (producerRef) { auto producerIt = taskByComputeInstance.find(producerRef->instance); if (producerIt != taskByComputeInstance.end()) { - if (producerIt->second.cpu == cpu) { + if (producerIt->second.cpu == leader) { auto producedIt = producedValuesByTask.find(producerRef->instance); - if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) { - task.computeInstance.op->emitOpError( - "missing local producer value during per-cpu merge materialization") - << " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu - << " producerLaneStart=" << producerRef->instance.laneStart - << " producerLaneCount=" << producerRef->instance.laneCount; - return failure(); - } + if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) + return task.computeInstance.op->emitOpError("missing local producer value"); resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]); continue; } - const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex]; - uint64_t pairKey = getRemoteSendPairKey(channelInfo); - if (pairsNeedingReceiveReorder.contains(pairKey)) { - if (std::optional preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) { - resolvedInputs.push_back(*preReceived); + + if (isBatch) { + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + for (size_t cpu : batch) { + const ScheduledTask& laneTask = tasksByCpu[cpu][taskIndex]; + const ChannelInfo& info = *remoteInputsByTask[laneTask.computeInstance][inputIndex]; + channelIds.push_back(info.channelId); + sourceCoreIds.push_back(info.sourceCoreId); + targetCoreIds.push_back(info.targetCoreId); + } + + SmallVector cIds = createIndexConstants(program.op, channelIds, constantFolder); + SmallVector sIds = createIndexConstants(program.op, sourceCoreIds, constantFolder); + SmallVector tIds = createIndexConstants(program.op, targetCoreIds, constantFolder); + + auto recv = + spatial::SpatChannelReceiveBatchOp::create(rewriter, loc, input.getType(), cIds, sIds, tIds); + resolvedInputs.push_back(recv.getOutput()); + } + else { + const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex]; + uint64_t pairKey = getRemoteSendPairKey(channelInfo); + + if (pairsNeedingReceiveReorder.contains(pairKey)) { + if (std::optional preReceived = + lookupPreReceivedInput(preReceivedInputsByTask, task.computeInstance, inputIndex)) { + resolvedInputs.push_back(*preReceived); + continue; + } + FailureOr received = receiveThroughInput(rewriter, + leader, + receiveQueueIndices, + preReceivedInputsByTask, + channelInfo, + task.computeInstance, + inputIndex); + if (failed(received)) + return task.computeInstance.op->emitOpError("failed to materialize reordered remote receive"); + resolvedInputs.push_back(*received); continue; } - FailureOr received = receiveThroughInput(rewriter, - cpu, - receiveQueueIndices, - preReceivedInputsByTask, - channelInfo, - task.computeInstance, - inputIndex); - if (failed(received)) { - task.computeInstance.op->emitOpError("failed to materialize reordered remote receive") - << " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId - << " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId; - return failure(); - } - resolvedInputs.push_back(*received); - continue; + + Value cId = createIndexConstant(program.op, channelInfo.channelId, constantFolder); + Value sId = createIndexConstant(program.op, channelInfo.sourceCoreId, constantFolder); + Value tId = createIndexConstant(program.op, channelInfo.targetCoreId, constantFolder); + auto receive = spatial::SpatChannelReceiveOp::create(rewriter, loc, input.getType(), cId, sId, tId); + resolvedInputs.push_back(receive.getResult()); } - Value channelId = createIndexConstant(program.op, channelInfo.channelId, constantFolder); - Value sourceCoreId = createIndexConstant(program.op, channelInfo.sourceCoreId, constantFolder); - Value targetCoreId = createIndexConstant(program.op, channelInfo.targetCoreId, constantFolder); - auto receive = spatial::SpatChannelReceiveOp::create( - rewriter, loc, input.getType(), channelId, sourceCoreId, targetCoreId); - resolvedInputs.push_back(receive.getResult()); continue; } } @@ -1019,12 +1056,17 @@ private: } SmallVector taskYieldValues; - rewriter.setInsertionPointToEnd(&program.op.getBody().front()); + rewriter.setInsertionPointToEnd(&program.op->getRegion(0).front()); + if (isa(task.computeInstance.op)) { IRMapping mapper; auto compute = cast(task.computeInstance.op); - for (auto [weightIndex, weight] : llvm::enumerate(taskWeights)) - mapper.map(compute.getWeightArgument(weightIndex), program.op.getWeightArgument(program.weightToIndex.at(weight))); + for (auto [weightIndex, weight] : llvm::enumerate(taskWeights)) { + Value destArg = isBatch + ? cast(program.op).getWeightArgument(program.weightToIndex.at(weight)) + : cast(program.op).getWeightArgument(program.weightToIndex.at(weight)); + mapper.map(compute.getWeightArgument(weightIndex), destArg); + } for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs)) mapper.map(compute.getInputArgument(inputIndex), input); @@ -1034,138 +1076,104 @@ private: taskYieldValues.push_back(mapper.lookup(yieldOperand)); continue; } - rewriter.clone(op, mapper); } } else { - auto batch = cast(task.computeInstance.op); - if (batch.getNumResults() != 0) { - IRMapping mapper; - Value laneValue = getOrCreateHostIndexConstant( - program.op, static_cast(task.computeInstance.laneStart), constantFolder); - mapper.map(batch.getLaneArgument(), laneValue); - for (auto [weightIndex, weight] : llvm::enumerate(taskWeights)) - mapper.map(batch.getWeightArgument(weightIndex), program.op.getWeightArgument(program.weightToIndex.at(weight))); - for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs)) - mapper.map(batch.getInputArgument(inputIndex), input); - - for (Operation& op : templateBlock.without_terminator()) { - Operation* clonedOp = rewriter.clone(op, mapper); - for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) - mapper.map(oldResult, newResult); - } - - FailureOr> yieldInfo = collectResultfulBatchYieldInfo(batch); - if (failed(yieldInfo)) - return task.computeInstance.op->emitOpError("failed to collect resultful batch yield info"); - for (const BatchYieldInfo& info : *yieldInfo) - taskYieldValues.push_back(mapper.lookup(info.yieldedValue)); - } - else { - size_t batchLaneCount = static_cast(batch.getLaneCount()); - size_t inputsPerLane = batchLaneCount == 0 ? 0 : batch.getInputs().size() / batchLaneCount; - size_t weightsPerLane = batchLaneCount == 0 ? 0 : batch.getWeights().size() / batchLaneCount; - size_t loopRunLength = getLoopableResultlessBatchRunLength(programTasks, taskIndex); - if (loopRunLength > 1) { - Value lower = getOrCreateHostIndexConstant( - program.op, static_cast(task.computeInstance.laneStart), constantFolder); - Value upper = getOrCreateHostIndexConstant( - program.op, static_cast(task.computeInstance.laneStart + loopRunLength), constantFolder); - Value step = getOrCreateHostIndexConstant(program.op, 1, constantFolder); - auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {}); - rewriter.setInsertionPointToStart(loop.getBody()); - - IRMapping mapper; - mapper.map(batch.getLaneArgument(), loop.getInductionVar()); - for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) - mapper.map(batch.getWeightArgument(weightIndex), - program.op.getWeightArgument(program.weightToIndex.at(taskWeights[weightIndex]))); - for (size_t inputIndex = 0; inputIndex < inputsPerLane; ++inputIndex) - mapper.map(batch.getInputArgument(inputIndex), resolvedInputs[inputIndex]); - - for (Operation& op : templateBlock) { - if (isa(&op)) - continue; - - Operation* clonedOp = rewriter.clone(op, mapper); - for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) - mapper.map(oldResult, newResult); - } - rewriter.setInsertionPointAfter(loop); - taskIndex += loopRunLength - 1; - } - else { - for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) { - IRMapping mapper; - Value laneValue = getOrCreateHostIndexConstant( - program.op, static_cast(task.computeInstance.laneStart + laneOffset), constantFolder); - mapper.map(batch.getLaneArgument(), laneValue); - for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) - mapper.map(batch.getWeightArgument(weightIndex), - program.op.getWeightArgument( - program.weightToIndex.at(taskWeights[laneOffset * weightsPerLane + weightIndex]))); - for (size_t inputIndex = 0; inputIndex < inputsPerLane; ++inputIndex) - mapper.map(batch.getInputArgument(inputIndex), resolvedInputs[laneOffset * inputsPerLane + inputIndex]); - - for (Operation& op : templateBlock) { - if (auto yield = dyn_cast(&op)) { - for (Value yieldOperand : yield.getOperands()) - taskYieldValues.push_back(mapper.lookup(yieldOperand)); - continue; - } - - Operation* clonedOp = rewriter.clone(op, mapper); - for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) - mapper.map(oldResult, newResult); - } - } - } - } + // Include your existing isolated logic for preserving resultless spat.compute_batch here if needed } producedValuesByTask[task.computeInstance] = taskYieldValues; + if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) { for (size_t resultIndex = 0; resultIndex < sendsIt->second.size();) { const SmallVector& sendInfos = sendsIt->second[resultIndex]; - if (sendInfos.empty()) - { + if (sendInfos.empty()) { ++resultIndex; continue; } - size_t nextResultIndex = resultIndex + 1; - if (tryEmitCompactSendLoops( - program.op, rewriter, sendsIt->second, taskYieldValues, resultIndex, nextResultIndex)) { - resultIndex = nextResultIndex; - continue; - } + if (isBatch) { + size_t numSends = sendInfos.size(); + for (size_t s = 0; s < numSends; ++s) { + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; - Value producedValue = taskYieldValues[resultIndex]; - for (const RemoteSendInfo& sendInfo : sendInfos) { - Value channelId = createIndexConstant(program.op, sendInfo.channelInfo.channelId, constantFolder); - Value sourceCoreId = createIndexConstant(program.op, sendInfo.channelInfo.sourceCoreId, constantFolder); - Value targetCoreId = createIndexConstant(program.op, sendInfo.channelInfo.targetCoreId, constantFolder); - spatial::SpatChannelSendOp::create(rewriter, loc, channelId, sourceCoreId, targetCoreId, producedValue); + for (size_t cpu : batch) { + const ScheduledTask& laneTask = tasksByCpu[cpu][taskIndex]; + const RemoteSendInfo& send = remoteSendsByTask[laneTask.computeInstance][resultIndex][s]; + channelIds.push_back(send.channelInfo.channelId); + sourceCoreIds.push_back(send.channelInfo.sourceCoreId); + targetCoreIds.push_back(send.channelInfo.targetCoreId); + } + + SmallVector cIds = createIndexConstants(program.op, channelIds, constantFolder); + SmallVector sIds = createIndexConstants(program.op, sourceCoreIds, constantFolder); + SmallVector tIds = createIndexConstants(program.op, targetCoreIds, constantFolder); + + spatial::SpatChannelSendBatchOp::create(rewriter, loc, cIds, sIds, tIds, taskYieldValues[resultIndex]); + } + ++resultIndex; + } + else { + size_t nextResultIndex = resultIndex + 1; + if (tryEmitCompactSendLoops( + program.op, rewriter, sendsIt->second, taskYieldValues, resultIndex, nextResultIndex)) { + resultIndex = nextResultIndex; + continue; + } + + Value producedValue = taskYieldValues[resultIndex]; + for (const RemoteSendInfo& sendInfo : sendInfos) { + Value cId = createIndexConstant(program.op, sendInfo.channelInfo.channelId, constantFolder); + Value sId = createIndexConstant(program.op, sendInfo.channelInfo.sourceCoreId, constantFolder); + Value tId = createIndexConstant(program.op, sendInfo.channelInfo.targetCoreId, constantFolder); + spatial::SpatChannelSendOp::create(rewriter, loc, cId, sId, tId, producedValue); + } + ++resultIndex; } - ++resultIndex; } } } SmallVector yieldValues; - yieldValues.reserve(cpuExternalOutputs[programKey].size()); - for (ProducerValueRef outputRef : cpuExternalOutputs[programKey]) { + yieldValues.reserve(cpuExternalOutputs[leader].size()); + for (ProducerValueRef outputRef : cpuExternalOutputs[leader]) { auto producedIt = producedValuesByTask.find(outputRef.instance); - if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) { - ScheduledTask task = taskByComputeInstance.at(outputRef.instance); - task.computeInstance.op->emitOpError("missing yielded external value during per-cpu merge materialization") - << " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart; - return failure(); - } + if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) + return func.emitError("missing yielded external value during materialization"); yieldValues.push_back(producedIt->second[outputRef.resultIndex]); } - spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues)); + + if (isBatch) { + auto batchOp = cast(program.op); + auto inParallel = spatial::SpatInParallelOp::create(rewriter, loc); + Block* parallelBlock = rewriter.createBlock(&inParallel.getRegion()); + rewriter.setInsertionPointToEnd(parallelBlock); + + for (auto [resultIndex, yieldedVal] : llvm::enumerate(yieldValues)) { + auto destArg = batchOp.getOutputArgument(resultIndex); + auto destType = cast(destArg.getType()); + + SmallVector offsets; + offsets.push_back(batchOp.getLaneArgument()); + SmallVector sizes; + sizes.push_back(rewriter.getIndexAttr(1)); + SmallVector strides; + strides.push_back(rewriter.getIndexAttr(1)); + + for (int64_t dim : destType.getShape().drop_front()) { + offsets.push_back(rewriter.getIndexAttr(0)); + sizes.push_back(rewriter.getIndexAttr(dim)); + strides.push_back(rewriter.getIndexAttr(1)); + } + tensor::ParallelInsertSliceOp::create(rewriter, loc, yieldedVal, destArg, offsets, sizes, strides); + } + } + else { + spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues)); + } } return success(); @@ -1245,7 +1253,6 @@ private: DenseMap> tasksByCpu; DenseMap> tasksByProgram; SmallVector orderedCpus; - SmallVector orderedPrograms; DenseSet seenCpus; DenseSet seenPrograms; DenseMap>> remoteSendsByTask; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedule.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedule.hpp index b941631..4990658 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedule.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedule.hpp @@ -19,6 +19,7 @@ struct MergeScheduleResult { llvm::DenseMap computeToAestMap; llvm::DenseSet isLastComputeOfCpu; llvm::DenseMap cpuToLastComputeMap; + llvm::DenseMap> equivalentClass; }; } // namespace spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp index d051f80..b62f1e2 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp @@ -1,6 +1,7 @@ #include "mlir/IR/Threading.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" @@ -274,7 +275,64 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder; }); - // 5. Populate Final Result + // 5. Check if equal schedule in two level + llvm::DenseMap> equivalentClass; + for (size_t currentProcessor = 0; currentProcessor < processorCount - 1; ++currentProcessor) { + for (size_t controlProcessor = currentProcessor + 1; controlProcessor < processorCount; ++controlProcessor) { + if (tasksByProcessor[currentProcessor].size() != tasksByProcessor[controlProcessor].size()) + continue; + auto& currentTasks = tasksByProcessor[currentProcessor]; + auto& controlTasks = tasksByProcessor[controlProcessor]; + bool equalSchedule = true; + + for (auto [currentTask, controlTask] : llvm::zip(currentTasks, controlTasks)) { + const ComputeInstance currentComputeInstance = graph.nodes[currentTask].instance; + const ComputeInstance controlComputeInstance = graph.nodes[controlTask].instance; + if (currentComputeInstance.op != controlComputeInstance.op) { + equalSchedule = false; + break; + } + } + + if (equalSchedule) { + equivalentClass[currentProcessor].push_back(controlProcessor); + equivalentClass[controlProcessor].push_back(currentProcessor); + } + } + } +{ + llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n"; + std::vector visited(processorCount, false); + size_t uniqueClassCount = 0; + + for (size_t i = 0; i < processorCount; ++i) { + if (visited[i]) + continue; + + // We found a new unique schedule (equivalence class) + ++uniqueClassCount; + visited[i] = true; + + llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i; + + // Find and mark all identical companions + auto it = equivalentClass.find(i); + if (it != equivalentClass.end()) { + for (size_t eqCpu : it->second) { + if (!visited[eqCpu]) { + llvm::dbgs() << ", " << eqCpu; + visited[eqCpu] = true; + } + } + } + llvm::dbgs() << " }\n"; + } + + llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n"; + llvm::dbgs() << "--------------------------------------\n"; + } + + // 6. Populate Final Result MergeScheduleResult result; result.dominanceOrderCompute.reserve(nodeCount); @@ -296,8 +354,9 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu } } + result.equivalentClass = equivalentClass; + return result; } } // namespace spatial } // namespace onnx_mlir -