From e33f517221bc8798eb0e95fdfe5cb6f21ac0f30d Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Wed, 3 Jun 2026 19:40:34 +0200 Subject: [PATCH] faster scheduling: split batches into numCores tasks before scheduling instead of numLanes tasks --- .../MaterializeMergeSchedule.cpp | 838 +++++++++--------- .../Scheduling/ComputeGraph.cpp | 142 ++- .../Scheduling/ComputeInstanceUtils.cpp | 60 +- .../Scheduling/ComputeInstanceUtils.hpp | 10 + .../Scheduling/PeftScheduler.cpp | 2 - 5 files changed, 606 insertions(+), 446 deletions(-) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 2253a75..650d4bb 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -117,7 +117,6 @@ struct ProducerKeyInfo { static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; } }; -using CpuSlotKey = std::pair; using ClassSlotKey = std::pair; struct MaterializedClass { @@ -168,6 +167,11 @@ struct IndexedBatchRunValue { MessageVector messages; }; +struct LogicalSlotRange { + SlotId start = 0; + SlotId count = 0; +}; + struct MaterializationRunSlot { SmallVector peers; }; @@ -265,8 +269,10 @@ struct MaterializerState { SmallVector classes; DenseMap cpuToClass; - DenseMap cpuSlotToInstance; - DenseSet materializedSlots; + DenseMap> logicalInstancesByCpu; + DenseMap scheduledInstanceToLogicalSlots; + DenseMap logicalInstanceToScheduledChunk; + DenseSet materializedLogicalSlots; DenseMap, ProducerKeyInfo> producerDestClasses; DenseMap, ProducerKeyInfo> sameClassConsumers; @@ -345,19 +351,7 @@ bool isWholeBatchProducerKey(ProducerKey key) { && key.instance.laneCount == static_cast(batch.getLaneCount()); } -SmallVector expandWholeBatchProducerKey(ProducerKey key) { - if (!isWholeBatchProducerKey(key)) - return SmallVector {key}; - - auto batch = cast(key.instance.op); - SmallVector keys; - keys.reserve(batch.getLaneCount()); - for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) - keys.push_back(getBatchLaneProducerKey(batch, lane, 1, key.resultIndex)); - return keys; -} - -std::optional getContiguousProducerKeyForKeys(ArrayRef keys) { +std::optional getContiguousProducerRangeForKeys(ArrayRef keys) { if (keys.empty()) return std::nullopt; @@ -366,15 +360,23 @@ std::optional getContiguousProducerKeyForKeys(ArrayRef if (!batch) return std::nullopt; - uint32_t laneStart = first.instance.laneStart; - for (auto [index, key] : llvm::enumerate(keys)) { - if (key.instance.op != first.instance.op || key.resultIndex != first.resultIndex || key.instance.laneCount != 1) + SmallVector sorted(keys.begin(), keys.end()); + llvm::sort(sorted, [](ProducerKey lhs, ProducerKey rhs) { + return std::tie(lhs.instance.laneStart, lhs.instance.laneCount, lhs.resultIndex) + < std::tie(rhs.instance.laneStart, rhs.instance.laneCount, rhs.resultIndex); + }); + + uint32_t laneStart = sorted.front().instance.laneStart; + uint32_t nextLane = laneStart; + for (ProducerKey key : sorted) { + if (key.instance.op != first.instance.op || key.resultIndex != first.resultIndex || key.instance.laneCount == 0) return std::nullopt; - if (key.instance.laneStart != laneStart + static_cast(index)) + if (key.instance.laneStart != nextLane) return std::nullopt; + nextLane += key.instance.laneCount; } - uint32_t laneCount = static_cast(keys.size()); + uint32_t laneCount = nextLane - laneStart; if (laneStart + laneCount > static_cast(batch.getLaneCount())) return std::nullopt; @@ -397,7 +399,87 @@ LogicalResult verifyPackableFragmentType(Operation* anchor, Type fragmentType, s return success(); } -std::optional getProducerKey(Value value, const ComputeInstance* consumerInstance = nullptr) { +ComputeInstance getScheduledChunkForLogicalInstance(MaterializerState& state, ComputeInstance logicalInstance) { + auto it = state.logicalInstanceToScheduledChunk.find(logicalInstance); + if (it != state.logicalInstanceToScheduledChunk.end()) + return it->second; + return logicalInstance; +} + +SmallVector +collectProducerKeysForDestinations(Value value, std::optional logicalConsumer = std::nullopt) { + // Destination collection works in the materializer's logical one-lane key domain. + // Whole-batch resultful producers are expanded into per-lane producer keys here. + SmallVector keys; + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return keys; + + while (auto extract = dyn_cast(definingOp)) { + Value source = extract.getSource(); + auto batch = dyn_cast_or_null(source.getDefiningOp()); + if (batch && batch.getNumResults() != 0) { + auto result = dyn_cast(source); + if (!result) + return {}; + + if (std::optional lane = getConstantFirstSliceOffset(extract)) { + if (*lane >= static_cast(batch.getLaneCount())) + return {}; + keys.push_back(getBatchLaneProducerKey(batch, *lane, 1, result.getResultNumber())); + return keys; + } + + if (logicalConsumer && isa(logicalConsumer->op)) { + keys.push_back(getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber())); + return keys; + } + + return {}; + } + + value = source; + definingOp = value.getDefiningOp(); + if (!definingOp) + return {}; + } + + if (auto compute = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return {}; + keys.push_back({{compute.getOperation(), 0, 1}, result.getResultNumber()}); + return keys; + } + + if (auto batch = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return {}; + + if (batch.getNumResults() != 0) { + if (logicalConsumer && isa(logicalConsumer->op)) { + keys.push_back(getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber())); + return keys; + } + + for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) + keys.push_back(getBatchLaneProducerKey(batch, lane, 1, result.getResultNumber())); + return keys; + } + + ComputeInstance chunk = getBatchChunkForLane(batch, result.getResultNumber()); + keys.push_back({chunk, static_cast(result.getResultNumber() - chunk.laneStart)}); + return keys; + } + + return keys; +} + +std::optional +getInputRequestProducerKey(Value value, std::optional logicalConsumer = std::nullopt) { + // Input resolution may request a whole-batch key for scalar consumers that read + // a complete resultful compute_batch value. Operation* definingOp = value.getDefiningOp(); if (!definingOp) return std::nullopt; @@ -410,23 +492,13 @@ std::optional getProducerKey(Value value, const ComputeInstance* co if (!result) return std::nullopt; - uint32_t laneStart = 0; - uint32_t laneCount = 1; - if (std::optional lane = getConstantFirstSliceOffset(extract)) { - laneStart = *lane; - } - else if (consumerInstance && isa(consumerInstance->op)) { - laneStart = consumerInstance->laneStart; - laneCount = consumerInstance->laneCount; - } - else { - return std::nullopt; - } + if (std::optional lane = getConstantFirstSliceOffset(extract)) + return getBatchLaneProducerKey(batch, *lane, 1, result.getResultNumber()); - if (laneStart + laneCount > static_cast(batch.getLaneCount())) - return std::nullopt; + if (logicalConsumer && isa(logicalConsumer->op)) + return getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber()); - return getBatchLaneProducerKey(batch, laneStart, laneCount, result.getResultNumber()); + return std::nullopt; } value = source; @@ -439,10 +511,7 @@ std::optional getProducerKey(Value value, const ComputeInstance* co auto result = dyn_cast(value); if (!result) return std::nullopt; - return ProducerKey { - {compute.getOperation(), 0, 1}, - result.getResultNumber() - }; + return ProducerKey {{compute.getOperation(), 0, 1}, result.getResultNumber()}; } if (auto batch = dyn_cast(definingOp)) { @@ -451,10 +520,8 @@ std::optional getProducerKey(Value value, const ComputeInstance* co return std::nullopt; if (batch.getNumResults() != 0) { - if (consumerInstance && isa(consumerInstance->op)) - return getBatchLaneProducerKey( - batch, consumerInstance->laneStart, consumerInstance->laneCount, result.getResultNumber()); - + if (logicalConsumer && isa(logicalConsumer->op)) + return getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber()); return getWholeBatchProducerKey(batch, result.getResultNumber()); } @@ -492,7 +559,51 @@ private: DenseMap parent; }; -LogicalResult buildEquivalenceClasses(MaterializerState& state) { +LogicalResult buildMaterializationWorkStreams(MaterializerState& state) { + DenseMap> scheduledInstancesByCpu; + for (const auto& [instance, cpu] : state.schedule.computeToCpuMap) { + state.oldComputeOps.insert(instance.op); + scheduledInstancesByCpu[cpu].push_back(instance); + state.logicalInstancesByCpu.try_emplace(cpu); + } + + for (auto& [cpu, scheduledInstances] : scheduledInstancesByCpu) { + llvm::sort(scheduledInstances, [&](const ComputeInstance& lhs, const ComputeInstance& rhs) { + auto lhsIt = state.schedule.computeToCpuSlotMap.find(lhs); + auto rhsIt = state.schedule.computeToCpuSlotMap.find(rhs); + assert(lhsIt != state.schedule.computeToCpuSlotMap.end() && "missing scheduler slot"); + assert(rhsIt != state.schedule.computeToCpuSlotMap.end() && "missing scheduler slot"); + return lhsIt->second < rhsIt->second; + }); + + SmallVector& logicalInstances = state.logicalInstancesByCpu[cpu]; + SlotId logicalSlot = 0; + for (const ComputeInstance& instance : scheduledInstances) { + LogicalSlotRange range {logicalSlot, 1}; + if (isa(instance.op)) + range.count = instance.laneCount; + + state.scheduledInstanceToLogicalSlots[instance] = range; + + if (isa(instance.op)) { + for (uint32_t localLane = 0; localLane < instance.laneCount; ++localLane, ++logicalSlot) { + uint32_t logicalLane = instance.laneStart + localLane; + ComputeInstance logicalInstance {instance.op, logicalLane, 1}; + logicalInstances.push_back(logicalInstance); + state.logicalInstanceToScheduledChunk[logicalInstance] = instance; + } + continue; + } + + logicalInstances.push_back(instance); + ++logicalSlot; + } + } + + return success(); +} + +LogicalResult buildMaterializationClassesFromScheduleEquivalence(MaterializerState& state) { DenseSet usedCpus; for (const auto& entry : state.schedule.cpuToLastComputeMap) usedCpus.insert(entry.first); @@ -535,12 +646,69 @@ LogicalResult buildEquivalenceClasses(MaterializerState& state) { state.classes.push_back(std::move(materializedClass)); } - for (const auto& [instance, cpu] : state.schedule.computeToCpuMap) { - auto slotIt = state.schedule.computeToCpuSlotMap.find(instance); - if (slotIt == state.schedule.computeToCpuSlotMap.end()) - return instance.op->emitError("schedule materialization expected a CPU slot for every compute instance"); - state.cpuSlotToInstance[{cpu, slotIt->second}] = instance; - state.oldComputeOps.insert(instance.op); + return success(); +} + +LogicalResult verifyScheduleEquivalenceMatchesLogicalStreams(MaterializerState& state) { + for (const MaterializedClass& materializedClass : state.classes) { + if (materializedClass.cpus.empty()) + continue; + + auto referenceIt = state.logicalInstancesByCpu.find(materializedClass.cpus.front()); + if (referenceIt == state.logicalInstancesByCpu.end()) + return state.func.emitError("missing logical stream for materialized class reference CPU"); + + ArrayRef referenceStream(referenceIt->second); + for (CpuId cpu : materializedClass.cpus) { + auto streamIt = state.logicalInstancesByCpu.find(cpu); + if (streamIt == state.logicalInstancesByCpu.end()) + return state.func.emitError("missing logical stream for materialized class CPU"); + + ArrayRef stream(streamIt->second); + if (stream.size() != referenceStream.size()) + return state.func.emitError("materialized class CPUs have mismatched logical stream lengths"); + + for (auto [slot, zipped] : llvm::enumerate(llvm::zip(referenceStream, stream))) { + const ComputeInstance& referenceInstance = std::get<0>(zipped); + const ComputeInstance& currentInstance = std::get<1>(zipped); + if (referenceInstance.op != currentInstance.op) + return state.func.emitError("materialized class logical slot source op mismatch"); + if (isa(referenceInstance.op) != isa(currentInstance.op)) + return state.func.emitError("materialized class logical slot batch/scalar mismatch"); + (void)slot; + } + } + } + + return success(); +} + +LogicalResult forEachLogicalConsumerInMaterializationOrder( + MaterializerState& state, + llvm::function_ref + callback) { + for (const ComputeInstance& scheduledInstance : state.schedule.dominanceOrderCompute) { + auto cpuIt = state.schedule.computeToCpuMap.find(scheduledInstance); + if (cpuIt == state.schedule.computeToCpuMap.end()) + return scheduledInstance.op->emitError("missing CPU assignment for scheduled logical-slot iteration"); + + auto rangeIt = state.scheduledInstanceToLogicalSlots.find(scheduledInstance); + if (rangeIt == state.scheduledInstanceToLogicalSlots.end()) + return scheduledInstance.op->emitError("missing logical slot range for scheduled logical-slot iteration"); + + CpuId cpu = cpuIt->second; + ClassId classId = state.cpuToClass.lookup(cpu); + LogicalSlotRange range = rangeIt->second; + auto streamIt = state.logicalInstancesByCpu.find(cpu); + if (streamIt == state.logicalInstancesByCpu.end()) + return scheduledInstance.op->emitError("missing logical stream for CPU"); + for (SlotId logicalSlot = range.start; logicalSlot < range.start + range.count; ++logicalSlot) { + if (logicalSlot >= streamIt->second.size()) + return scheduledInstance.op->emitError("missing logical slot materialization instance"); + if (failed(callback(cpu, classId, scheduledInstance, streamIt->second[logicalSlot], logicalSlot))) + return failure(); + } } return success(); @@ -932,7 +1100,7 @@ std::optional AvailableValueStore::lookupPackedRun(MaterializerState& sta continue; for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) { - std::optional slotKey = getContiguousProducerKeyForKeys(slot.keys); + std::optional slotKey = getContiguousProducerRangeForKeys(slot.keys); if (!slotKey || !containsProducerKey(*slotKey, key)) continue; @@ -990,9 +1158,6 @@ std::optional AvailableValueStore::lookup(MaterializerState& state, Produ if (std::optional packedRunValue = lookupPackedRun(state, key, classId)) return packedRunValue; - if (key.instance.laneCount != 1) - return std::nullopt; - MaterializedClass& materializedClass = state.classes[classId]; ProducerKey containingKey; @@ -1202,14 +1367,14 @@ Value createLaneIndexedIndexValue(MaterializerState& state, } FailureOr> -getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) { +getPeerLogicalInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId logicalSlot) { SmallVector peers; peers.reserve(materializedClass.cpus.size()); for (CpuId cpu : materializedClass.cpus) { - auto it = state.cpuSlotToInstance.find({cpu, slot}); - if (it == state.cpuSlotToInstance.end()) + auto streamIt = state.logicalInstancesByCpu.find(cpu); + if (streamIt == state.logicalInstancesByCpu.end() || logicalSlot >= streamIt->second.size()) return failure(); - peers.push_back(it->second); + peers.push_back(streamIt->second[logicalSlot]); } return peers; } @@ -1279,34 +1444,30 @@ void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSetemitError("schedule materialization expected a CPU assignment for every compute instance"); - ClassId targetClass = state.cpuToClass.lookup(cpuIt->second); + return forEachLogicalConsumerInMaterializationOrder( + state, + [&](CpuId, ClassId targetClass, ComputeInstance scheduledConsumer, ComputeInstance logicalConsumer, SlotId) + -> LogicalResult { + for (Value input : getComputeInstanceInputs(scheduledConsumer)) { + for (ProducerKey producerKey : collectProducerKeysForDestinations(input, logicalConsumer)) { + ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, producerKey.instance); + auto producerCpuIt = state.schedule.computeToCpuMap.find(scheduledProducer); + if (producerCpuIt == state.schedule.computeToCpuMap.end()) + return logicalConsumer.op->emitError( + "schedule materialization found an input produced by an unscheduled compute"); - for (Value input : getComputeInstanceInputs(consumer)) { - std::optional producer = getProducerKey(input, &consumer); - if (!producer) - continue; + ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second); + if (sourceClass == targetClass) { + state.sameClassConsumers[producerKey].insert(targetClass); + continue; + } - for (ProducerKey producerKey : expandWholeBatchProducerKey(*producer)) { - auto producerCpuIt = state.schedule.computeToCpuMap.find(producerKey.instance); - if (producerCpuIt == state.schedule.computeToCpuMap.end()) - return consumer.op->emitError("schedule materialization found an input produced by an unscheduled compute"); - - ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second); - if (sourceClass == targetClass) { - state.sameClassConsumers[producerKey].insert(targetClass); - continue; + appendDestinationClass(state, producerKey, targetClass); } - - appendDestinationClass(state, producerKey, targetClass); } - } - } - return success(); + return success(); + }); } static bool isLaneProjectedOffsetValue(Value value, Value expected, bool& usesExpected) { @@ -1470,82 +1631,89 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { DenseMap, ProducerKeyInfo> pending; - for (const ComputeInstance& consumer : state.schedule.dominanceOrderCompute) { - auto batch = dyn_cast(consumer.op); - if (!batch || consumer.laneCount != 1) - continue; + if (failed(forEachLogicalConsumerInMaterializationOrder( + state, + [&](CpuId cpu, + ClassId targetClassId, + ComputeInstance consumer, + ComputeInstance logicalConsumer, + SlotId logicalSlot) -> LogicalResult { + auto batch = dyn_cast(consumer.op); + if (!batch) + return success(); - auto cpuIt = state.schedule.computeToCpuMap.find(consumer); - if (cpuIt == state.schedule.computeToCpuMap.end()) - return consumer.op->emitError("projected transfer collection expected scheduled consumer"); + MaterializedClass& targetClass = state.classes[targetClassId]; + if (!targetClass.isBatch) + return success(); - ClassId targetClassId = state.cpuToClass.lookup(cpuIt->second); - MaterializedClass& targetClass = state.classes[targetClassId]; - if (!targetClass.isBatch) - continue; + auto targetLaneIt = targetClass.cpuToLane.find(cpu); + if (targetLaneIt == targetClass.cpuToLane.end()) + return consumer.op->emitError("projected transfer collection could not recover target lane"); - auto targetLaneIt = targetClass.cpuToLane.find(cpuIt->second); - if (targetLaneIt == targetClass.cpuToLane.end()) - return consumer.op->emitError("projected transfer collection could not recover target lane"); + unsigned targetLane = targetLaneIt->second; - unsigned targetLane = targetLaneIt->second; + for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) { + SmallVector producers = collectProducerKeysForDestinations(input, logicalConsumer); + if (producers.size() != 1) + continue; + ProducerKey producer = producers.front(); - for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) { - std::optional producer = getProducerKey(input, &consumer); - if (!producer) - continue; + ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, producer.instance); + auto producerCpuIt = state.schedule.computeToCpuMap.find(scheduledProducer); + if (producerCpuIt == state.schedule.computeToCpuMap.end()) + continue; - auto producerCpuIt = state.schedule.computeToCpuMap.find(producer->instance); - if (producerCpuIt == state.schedule.computeToCpuMap.end()) - continue; + ClassId sourceClassId = state.cpuToClass.lookup(producerCpuIt->second); + if (sourceClassId == targetClassId) + continue; - ClassId sourceClassId = state.cpuToClass.lookup(producerCpuIt->second); - if (sourceClassId == targetClassId) - continue; + std::optional extract = + matchSimpleLaneProjectedInput(batch, static_cast(inputIndex)); + if (!extract) + continue; - std::optional extract = - matchSimpleLaneProjectedInput(batch, static_cast(inputIndex)); - if (!extract) - continue; + auto inputType = cast(extract->getSource().getType()); + auto fragmentType = cast((*extract).getResult().getType()); + SmallVector offsets = extract->getMixedOffsets(); + std::optional sourceProjectedDim = getLaneProjectedDim(offsets, *batch.getLaneArgument()); + if (!sourceProjectedDim) + continue; - auto inputType = cast(extract->getSource().getType()); - auto fragmentType = cast((*extract).getResult().getType()); - SmallVector offsets = extract->getMixedOffsets(); - std::optional sourceProjectedDim = getLaneProjectedDim(offsets, *batch.getLaneArgument()); - if (!sourceProjectedDim) - continue; + PendingProjectedTransferDescriptor& descriptor = pending[producer][targetClassId]; + if (descriptor.offsetsByLane.empty()) { + descriptor.inputKey = {batch.getOperation(), static_cast(inputIndex)}; + descriptor.extractOp = extract->getOperation(); + descriptor.fragmentType = fragmentType; + descriptor.sourceProjectedDim = *sourceProjectedDim; + descriptor.offsetsByLane.resize(targetClass.cpus.size()); + } - PendingProjectedTransferDescriptor& descriptor = pending[*producer][targetClassId]; - if (descriptor.offsetsByLane.empty()) { - descriptor.inputKey = {batch.getOperation(), static_cast(inputIndex)}; - descriptor.extractOp = extract->getOperation(); - descriptor.fragmentType = fragmentType; - descriptor.sourceProjectedDim = *sourceProjectedDim; - descriptor.offsetsByLane.resize(targetClass.cpus.size()); - } + ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast(inputIndex)}; + if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != extract->getOperation() + || descriptor.fragmentType != fragmentType || descriptor.sourceProjectedDim != *sourceProjectedDim) { + descriptor.invalid = true; + continue; + } - ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast(inputIndex)}; - if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != extract->getOperation() - || descriptor.fragmentType != fragmentType || descriptor.sourceProjectedDim != *sourceProjectedDim) { - descriptor.invalid = true; - continue; - } + if (targetLane >= descriptor.offsetsByLane.size()) { + descriptor.invalid = true; + continue; + } - if (targetLane >= descriptor.offsetsByLane.size()) { - descriptor.invalid = true; - continue; - } + FailureOr offset = evaluateProjectedOffsetForLane( + offsets[*sourceProjectedDim], *batch.getLaneArgument(), logicalConsumer.laneStart); + if (failed(offset) || !isStaticSliceInBounds(*offset, inputType, fragmentType, *sourceProjectedDim)) { + descriptor.invalid = true; + continue; + } - FailureOr offset = - evaluateProjectedOffsetForLane(offsets[*sourceProjectedDim], *batch.getLaneArgument(), consumer.laneStart); - if (failed(offset) || !isStaticSliceInBounds(*offset, inputType, fragmentType, *sourceProjectedDim)) { - descriptor.invalid = true; - continue; - } + (void)logicalSlot; + descriptor.offsetsByLane[targetLane].push_back(*offset); + } - descriptor.offsetsByLane[targetLane].push_back(*offset); - } - } + return success(); + }))) + return failure(); for (auto& producerEntry : pending) { ProducerKey producer = producerEntry.first; @@ -1595,14 +1763,6 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { return success(); } -SmallVector getOutputKeysForPeers(ArrayRef peers, size_t resultIndex) { - SmallVector keys; - keys.reserve(peers.size()); - for (const ComputeInstance& peer : peers) - keys.push_back({peer, resultIndex}); - return keys; -} - bool haveSameDestinationClasses(MaterializerState& state, ArrayRef keys) { if (keys.empty()) return true; @@ -2166,7 +2326,7 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, return sourceClass.op->emitError("scalar-source communication must be emitted through the scalar fanout planner"); if (!targetClass.isBatch) { - std::optional packedKey = getContiguousProducerKeyForKeys(keys); + std::optional packedKey = getContiguousProducerRangeForKeys(keys); if (!packedKey) return sourceClass.op->emitError( "cannot materialize batch-to-scalar communication because source lanes are not contiguous"); @@ -2472,7 +2632,6 @@ FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& SmallVector keys = flattenPackedScalarRunKeys(run); if (keys.empty()) return failure(); - FailureOr packedType = getPackedBatchTensorType(run.fragmentType, keys.size()); if (failed(packedType)) return targetClass.op->emitError("cannot materialize deferred local packed run for non-static ranked tensor"); @@ -2527,207 +2686,6 @@ FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& return run.packed; } -[[maybe_unused]] FailureOr insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - const WholeBatchAssemblyPlan& plan, - PackedScalarRunValue& run, - Location loc) { - assert(isDeferredLocalPackedScalarRun(run) && "expected deferred local packed scalar run"); - - if (run.fragmentType.getDimSize(0) != plan.rowsPerLane) - return failure(); - - SmallVector keys = flattenPackedScalarRunKeys(run); - if (keys.empty()) - return failure(); - - SmallVector sourceLanes; - SmallVector outputOffsets; - sourceLanes.reserve(keys.size()); - outputOffsets.reserve(keys.size()); - - for (ProducerKey key : keys) { - if (key.instance.laneCount != 1) - return failure(); - - sourceLanes.push_back(key.instance.laneStart); - outputOffsets.push_back(static_cast(key.instance.laneStart) * plan.rowsPerLane); - } - - SmallVector resultIndices {run.resultIndex}; - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(keys.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {destination}, - [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value acc = iterArgs.front(); - Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - keys.front().instance, - sourceLane, - resultIndices, - CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = loopIndex}); - if (failed(produced) || produced->size() != 1) - return failure(); - - Value outputOffset = createIndexedIndexValue(state, targetClass.op, outputOffsets, loopIndex, loc); - Value next = insertFragmentIntoWholeBatch(state, produced->front(), acc, outputOffset, loc); - yielded.push_back(next); - return success(); - }); - if (failed(loop)) - return failure(); - return loop->results.front(); -} - -[[maybe_unused]] FailureOr insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - const WholeBatchAssemblyPlan& plan, - PackedScalarRunValue& run, - Location loc) { - assert(run.kind == PackedScalarRunKind::DeferredReceive && "expected deferred receive packed scalar run"); - - if (failed(validatePackedScalarRunMetadata(targetClass.op, run))) - return failure(); - - if (run.fragmentType.getDimSize(0) != plan.rowsPerLane) - return failure(); - - SmallVector outputOffsets; - outputOffsets.reserve(getPackedScalarRunReceiveCount(run)); - - for (const PackedScalarRunSlot& slot : run.slots) { - for (ProducerKey key : slot.keys) { - if (key.instance.laneCount == 0) - return failure(); - - outputOffsets.push_back(static_cast(key.instance.laneStart) * plan.rowsPerLane); - } - } - - if (outputOffsets.size() != run.messages.size()) - return failure(); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.messages.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {destination}, - [&](OpBuilder&, Location, Value index, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value acc = iterArgs.front(); - Value channelId = createIndexedChannelId(state, targetClass.op, run.messages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, run.messages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, run.messages, index, loc); - - Value received = - SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); - Value outputOffset = createIndexedIndexValue(state, targetClass.op, outputOffsets, index, loc); - Value next = insertFragmentIntoWholeBatch(state, received, acc, outputOffset, loc); - yielded.push_back(next); - return success(); - }); - if (failed(loop)) - return failure(); - return loop->results.front(); -} - -[[maybe_unused]] FailureOr insertPackedScalarRunIntoWholeBatch(MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - const WholeBatchAssemblyPlan& plan, - PackedScalarRunValue& run, - Location loc) { - if (run.slots.empty()) - return destination; - - if (!run.packed) { - if (isDeferredLocalPackedScalarRun(run)) - return insertDeferredLocalPackedScalarRunIntoWholeBatch(state, targetClass, destination, plan, run, loc); - return insertDeferredPackedScalarRunIntoWholeBatch(state, targetClass, destination, plan, run, loc); - } - - auto sourceBatch = dyn_cast_or_null(run.sourceOp); - if (!sourceBatch) - return failure(); - - int64_t batchLaneCount = sourceBatch.getLaneCount(); - if (batchLaneCount <= 0) - return failure(); - - if (run.fragmentType.getDimSize(0) != plan.rowsPerLane) - return failure(); - - size_t slotLaneCount = run.slots.front().keys.size(); - if (slotLaneCount == 0) - return failure(); - - FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slotLaneCount); - if (failed(slotPackedType)) - return failure(); - - SmallVector slotRowOffsets; - slotRowOffsets.reserve(run.slots.size()); - - for (const PackedScalarRunSlot& slot : run.slots) { - if (slot.keys.size() != slotLaneCount) - return failure(); - - std::optional slotKey = getContiguousProducerKeyForKeys(slot.keys); - if (!slotKey) - return failure(); - - slotRowOffsets.push_back(static_cast(slotKey->instance.laneStart) * plan.rowsPerLane); - } - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.slots.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {destination}, - [&](OpBuilder&, Location, Value slotIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value acc = iterArgs.front(); - Value slotPacked = extractPackedSlotForIndex(state, targetClass.op, run.packed, *slotPackedType, slotIndex, loc); - Value outputOffset = createIndexedIndexValue(state, targetClass.op, slotRowOffsets, slotIndex, loc); - Value next = insertFragmentIntoWholeBatch(state, slotPacked, acc, outputOffset, loc); - yielded.push_back(next); - return success(); - }); - if (failed(loop)) - return failure(); - return loop->results.front(); -} - LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, @@ -2905,7 +2863,7 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state, } for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) { - std::optional slotKey = getContiguousProducerKeyForKeys(slot.keys); + std::optional slotKey = getContiguousProducerRangeForKeys(slot.keys); if (!slotKey) return failure(); groupIt->slotIndices.push_back(slotIndex); @@ -3079,7 +3037,7 @@ FailureOr resolveInputValue(MaterializerState& state, if (isConstantLike(input)) return input; - if (std::optional producer = getProducerKey(input, &consumerInstance)) { + if (std::optional producer = getInputRequestProducerKey(input, consumerInstance)) { if (indexing.runSlotIndex) { if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) return materializeIndexedBatchRunReceive( @@ -3104,9 +3062,19 @@ FailureOr resolveInputValue(MaterializerState& state, } } - if (isWholeBatchProducerKey(*producer)) - return materializeWholeBatchInput(state, targetClass, *producer, input.getType(), consumerInstance.op->getLoc()); + if (isWholeBatchProducerKey(*producer)) { + FailureOr wholeBatch = + materializeWholeBatchInput(state, targetClass, *producer, input.getType(), consumerInstance.op->getLoc()); + if (failed(wholeBatch)) + consumerInstance.op->emitError("failed to materialize whole-batch input") + << " from '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart + << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; + return wholeBatch; + } + consumerInstance.op->emitError("failed to resolve producer value") + << " from op '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart + << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; return failure(); } @@ -3159,8 +3127,15 @@ LogicalResult mapInputs(MaterializerState& state, if (auto compute = dyn_cast(op)) { for (auto [index, input] : llvm::enumerate(compute.getInputs())) { FailureOr mapped = resolveInputValue(state, targetClass, input, instance, indexing); - if (failed(mapped)) - return compute.emitOpError("failed to resolve materialized compute input"); + if (failed(mapped)) { + std::optional producer = getInputRequestProducerKey(input, instance); + auto diagnostic = compute.emitOpError("failed to resolve materialized compute input") << " #" << index; + if (producer) { + diagnostic << " from '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart + << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; + } + return failure(); + } auto inputArg = compute.getInputArgument(index); if (!inputArg) return compute.emitOpError("expected compute input block argument while materializing inputs"); @@ -3293,7 +3268,9 @@ FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, } FailureOr> -cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef peers) { +cloneInstanceBody(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef peers) { assert(!peers.empty() && "expected at least one peer instance"); const ComputeInstance& instance = peers.front(); Operation* sourceOp = instance.op; @@ -3308,10 +3285,6 @@ cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, Arra sourceOp->emitError("equivalence class slot contains different source compute_batch operations"); return failure(); } - if (peer.laneCount != 1) { - sourceOp->emitError("schedule materialization currently expects one original batch lane per CPU"); - return failure(); - } } auto laneArg = batch.getLaneArgument(); if (!laneArg) { @@ -3512,7 +3485,7 @@ LogicalResult registerPackedRunValue(MaterializerState& state, return materializedClass.op->emitError("packed run registration expects one lane per packed fragment"); } - if (std::optional contiguousKey = getContiguousProducerKeyForKeys(keys)) { + if (std::optional contiguousKey = getContiguousProducerRangeForKeys(keys)) { state.availableValues.record(*contiguousKey, materializedClass.id, packed); return success(); } @@ -3643,11 +3616,12 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta if (run.front().peers.size() != 1) return sourceOp->emitError("scalar batch output loop expects exactly one peer in singleton slot"); - const ComputeInstance& instance = run.front().peers.front(); + const ComputeInstance& item = run.front().peers.front(); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value laneValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, instance.laneStart); - return cloneBatchBodyForLane(state, targetClass, instance, laneValue, group.resultIndices, {}); + Value laneValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, item.laneStart); + return cloneBatchBodyForLane( + state, targetClass, item, laneValue, group.resultIndices, {}); } state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); @@ -3669,18 +3643,18 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult()); } - SmallVector laneStarts; - laneStarts.reserve(run.size()); + SmallVector logicalLanes; + logicalLanes.reserve(run.size()); for (const MaterializationRunSlot& slot : run) { if (slot.peers.size() != 1) return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot"); - const ComputeInstance& instance = slot.peers.front(); - if (instance.op != sourceOp) + const ComputeInstance& item = slot.peers.front(); + if (item.op != sourceOp) return sourceOp->emitError("materialization run contains different source operations"); - laneStarts.push_back(instance.laneStart); + logicalLanes.push_back(item.laneStart); } Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); @@ -3696,7 +3670,7 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta step, ValueRange(initValues), [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value sourceLane = createIndexedIndexValue(state, targetClass.op, laneStarts, loopIndex, loc); + Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc); FailureOr> produced = cloneBatchBodyForLane(state, @@ -3737,15 +3711,15 @@ SmallVector getMaterializationRunSlotOutputKeys(const Materializ } FailureOr> -getMaterializationRunSlotPeers(MaterializerState& state, MaterializedClass& targetClass, SlotId slot) { +getMaterializationRunSlotPeers(MaterializerState& state, MaterializedClass& targetClass, SlotId logicalSlot) { if (targetClass.isBatch) - return getPeerInstances(state, targetClass, slot); + return getPeerLogicalInstances(state, targetClass, logicalSlot); - auto instanceIt = state.cpuSlotToInstance.find({targetClass.cpus.front(), slot}); - if (instanceIt == state.cpuSlotToInstance.end()) + auto streamIt = state.logicalInstancesByCpu.find(targetClass.cpus.front()); + if (streamIt == state.logicalInstancesByCpu.end() || logicalSlot >= streamIt->second.size()) return failure(); - return SmallVector {instanceIt->second}; + return SmallVector {streamIt->second[logicalSlot]}; } FailureOr collectBatchMaterializationRun(MaterializerState& state, @@ -3756,10 +3730,11 @@ FailureOr collectBatchMaterializationRun(MaterializerState& for (SlotId slot = startSlot;; ++slot) { ClassSlotKey classSlot {targetClass.id, slot}; - if (state.materializedSlots.contains(classSlot)) + if (state.materializedLogicalSlots.contains(classSlot)) break; - FailureOr> peers = getMaterializationRunSlotPeers(state, targetClass, slot); + FailureOr> peers = + getMaterializationRunSlotPeers(state, targetClass, slot); if (failed(peers) || peers->empty()) break; @@ -3769,9 +3744,6 @@ FailureOr collectBatchMaterializationRun(MaterializerState& validSlot = false; break; } - - if (peer.laneCount != 1) - return peer.op->emitError("batch run materialization expects one scheduled source lane per materialized lane"); } if (!validSlot) @@ -3838,12 +3810,30 @@ bool hasMaterializationRunGroupLiveExternalUse(MaterializerState& state, return false; } +bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId); + +bool hasMaterializationRunGroupSameClassConsumer(MaterializerState& state, + ClassId classId, + ArrayRef run, + const OutputDestinationGroup& group) { + for (size_t resultIndex : group.resultIndices) { + for (const MaterializationRunSlot& slot : run) { + for (const ComputeInstance& peer : slot.peers) { + if (hasSameClassConsumer(state, {peer, resultIndex}, classId)) + return true; + } + } + } + + return false; +} + void markMaterializationRunSlots(MaterializerState& state, ClassId classId, SlotId startSlot, ArrayRef run) { for (auto slotIndex : llvm::seq(0, run.size())) - state.materializedSlots.insert({classId, startSlot + static_cast(slotIndex)}); + state.materializedLogicalSlots.insert({classId, startSlot + static_cast(slotIndex)}); } LogicalResult materializeScalarBatchRun(MaterializerState& state, @@ -3864,7 +3854,8 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, for (const OutputDestinationGroup& group : groups) { if (run.size() > 1 && group.destinationClasses.empty() - && !hasMaterializationRunGroupLiveExternalUse(state, run, group)) { + && !hasMaterializationRunGroupLiveExternalUse(state, run, group) + && !hasMaterializationRunGroupSameClassConsumer(state, targetClass.id, run, group)) { for (size_t resultIndex : group.resultIndices) { if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) return sourceBatch.emitOpError("failed to recover per-lane output type for deferred local packed run"); @@ -3924,8 +3915,15 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, } bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) { - auto it = state.sameClassConsumers.find(producerKey); - return it != state.sameClassConsumers.end() && it->second.contains(classId); + for (const auto& [key, consumers] : state.sameClassConsumers) { + if (!consumers.contains(classId)) + continue; + if (!sameProducerResult(key, producerKey)) + continue; + if (containsProducerKey(key, producerKey) || containsProducerKey(producerKey, key)) + return true; + } + return false; } bool canCompactBatchClassRun(MaterializerState& state, @@ -4021,7 +4019,8 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state, plan.messages.sourceCoreIds.reserve(messageCount); plan.messages.targetCoreIds.reserve(messageCount); - for ([[maybe_unused]] const MaterializationRunSlot& slot : run) { + for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) { + (void)slotIndex; for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id"); if (failed(checkedSourceCpu)) @@ -4058,21 +4057,6 @@ void appendBatchRunSend(MaterializerState& state, SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); } -[[maybe_unused]] ArrayRef -sliceChannelsForRunSlot(const BatchRunSendPlan& plan, size_t slotIndex, size_t laneCount) { - return ArrayRef(plan.messages.channelIds).slice(slotIndex * laneCount, laneCount); -} - -[[maybe_unused]] ArrayRef -sliceSourcesForRunSlot(const BatchRunSendPlan& plan, size_t slotIndex, size_t laneCount) { - return ArrayRef(plan.messages.sourceCoreIds).slice(slotIndex * laneCount, laneCount); -} - -[[maybe_unused]] ArrayRef -sliceTargetsForRunSlot(const BatchRunSendPlan& plan, size_t slotIndex, size_t laneCount) { - return ArrayRef(plan.messages.targetCoreIds).slice(slotIndex * laneCount, laneCount); -} - LogicalResult appendPackedScalarRunReceives(MaterializerState& state, MaterializedClass& sourceClass, ArrayRef run, @@ -4199,7 +4183,7 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, FailureOr> produced = cloneBatchBodyForLane(state, targetClass, - run.front().peers.front(), + getScheduledChunkForLogicalInstance(state, run.front().peers.front()), sourceLane, group.resultIndices, CloneIndexingContext {.runSlotIndex = slotIndex, .projectionSlotIndex = slotIndex}); @@ -4235,35 +4219,41 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns auto cpuIt = state.schedule.computeToCpuMap.find(instance); if (cpuIt == state.schedule.computeToCpuMap.end()) return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); - auto slotIt = state.schedule.computeToCpuSlotMap.find(instance); - if (slotIt == state.schedule.computeToCpuSlotMap.end()) - return instance.op->emitError("schedule materialization expected a CPU slot for every compute instance"); + auto logicalRangeIt = state.scheduledInstanceToLogicalSlots.find(instance); + if (logicalRangeIt == state.scheduledInstanceToLogicalSlots.end()) + return instance.op->emitError("schedule materialization expected logical slots for every compute instance"); ClassId classId = state.cpuToClass.lookup(cpuIt->second); MaterializedClass& targetClass = state.classes[classId]; - ClassSlotKey classSlot {classId, slotIt->second}; - if (state.materializedSlots.contains(classSlot)) + LogicalSlotRange logicalRange = logicalRangeIt->second; + SlotId startLogicalSlot = logicalRange.start; + while (startLogicalSlot < logicalRange.start + logicalRange.count + && state.materializedLogicalSlots.contains({classId, startLogicalSlot})) { + ++startLogicalSlot; + } + if (startLogicalSlot == logicalRange.start + logicalRange.count) return success(); if (isa(instance.op)) { - FailureOr run = collectBatchMaterializationRun(state, targetClass, slotIt->second, instance.op); + FailureOr run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op); if (succeeded(run)) { if (!targetClass.isBatch) - return materializeScalarBatchRun(state, targetClass, slotIt->second, *run); + return materializeScalarBatchRun(state, targetClass, startLogicalSlot, *run); - if (succeeded(materializeBatchClassRun(state, targetClass, slotIt->second, *run))) + if (succeeded(materializeBatchClassRun(state, targetClass, startLogicalSlot, *run))) return success(); } } - if (!state.materializedSlots.insert(classSlot).second) + if (!state.materializedLogicalSlots.insert({classId, startLogicalSlot}).second) return success(); - FailureOr> peers = getPeerInstances(state, targetClass, slotIt->second); + FailureOr> peers = + getMaterializationRunSlotPeers(state, targetClass, startLogicalSlot); if (failed(peers)) - return instance.op->emitError("failed to collect peer compute instances for equivalence class slot"); + return instance.op->emitError("failed to collect peer compute instances for equivalence class logical slot"); FailureOr> materializedOutputs = cloneInstanceBody(state, targetClass, *peers); if (failed(materializedOutputs)) @@ -4276,7 +4266,9 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns for (auto [resultIndex, zipped] : llvm::enumerate(llvm::zip(*materializedOutputs, originalOutputs))) { Value materializedOutput = std::get<0>(zipped); Value originalOutput = std::get<1>(zipped); - SmallVector keys = getOutputKeysForPeers(*peers, resultIndex); + MaterializationRunSlot slot; + slot.peers = *peers; + SmallVector keys = getMaterializationRunSlotOutputKeys(slot, resultIndex); if (failed(emitOutputFanout(state, targetClass, keys, materializedOutput, originalOutput, instance.op->getLoc()))) return failure(); } @@ -4340,7 +4332,11 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch return success(); MaterializerState state(func, schedule, nextChannelId); - if (failed(buildEquivalenceClasses(state))) + if (failed(buildMaterializationWorkStreams(state))) + return failure(); + if (failed(buildMaterializationClassesFromScheduleEquivalence(state))) + return failure(); + if (failed(verifyScheduleEquivalenceMatchesLogicalStreams(state))) return failure(); if (state.classes.empty()) return success(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp index 95c9c22..9968590 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp @@ -105,6 +105,28 @@ bool isProjectedBatchOffset(OpFoldResult offset, Value laneArg) { && succeeded(evaluateIndexLike(offset, bindings, /*lane=*/1, laneArg)); } +std::optional getConstantExtractLane(tensor::ExtractSliceOp extract) { + if (extract.getMixedOffsets().empty()) + return std::nullopt; + + OpFoldResult offset = extract.getMixedOffsets().front(); + if (auto attr = llvm::dyn_cast(offset)) { + auto intAttr = dyn_cast(attr); + if (!intAttr || intAttr.getInt() < 0) + return std::nullopt; + return static_cast(intAttr.getInt()); + } + + Value offsetValue = llvm::cast(offset); + if (auto constantIndex = offsetValue.getDefiningOp()) { + if (constantIndex.value() < 0) + return std::nullopt; + return static_cast(constantIndex.value()); + } + + return std::nullopt; +} + std::optional getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) { auto inputIt = llvm::find(batch.getInputs(), input); if (inputIt == batch.getInputs().end()) @@ -143,6 +165,102 @@ Cost getInputTransferCost(const ComputeInstance& consumerInstance, Value input) return static_cast(getSizeInBytes(inputType)); } +uint32_t getLaneOverlapCount(const ComputeInstance& lhs, const ComputeInstance& rhs) { + uint32_t lhsEnd = lhs.laneStart + lhs.laneCount; + uint32_t rhsEnd = rhs.laneStart + rhs.laneCount; + return std::max(lhs.laneStart, rhs.laneStart) < std::min(lhsEnd, rhsEnd) + ? std::min(lhsEnd, rhsEnd) - std::max(lhs.laneStart, rhs.laneStart) + : 0; +} + +Cost scaleTransferCostByLaneCount(Cost totalCost, uint32_t totalLaneCount, uint32_t fragmentLaneCount) { + assert(totalLaneCount > 0 && "laneCount must be positive"); + assert(fragmentLaneCount > 0 && "fragmentLaneCount must be positive"); + if (fragmentLaneCount >= totalLaneCount) + return totalCost; + return checkedMultiply(totalCost, static_cast(fragmentLaneCount)) / static_cast(totalLaneCount); +} + +SmallVector collectProducerValueRefs(Value value, const ComputeInstance& consumerInstance) { + SmallVector producers; + Operation* op = value.getDefiningOp(); + if (!op) + return producers; + + while (auto extract = dyn_cast(op)) { + Value source = extract.getSource(); + auto batch = dyn_cast_or_null(source.getDefiningOp()); + if (batch && batch.getNumResults() != 0) { + if (std::optional lane = getConstantExtractLane(extract)) { + ComputeInstance instance = getBatchChunkForLane(batch, *lane); + producers.push_back({instance, 0}); + return producers; + } + + if (isa(consumerInstance.op)) { + for (ComputeInstance instance : + getBatchChunksForRange(batch, consumerInstance.laneStart, consumerInstance.laneCount)) + producers.push_back({instance, 0}); + } + else { + for (ComputeInstance instance : + getBatchChunksForRange(batch, 0, static_cast(batch.getLaneCount()))) + producers.push_back({instance, 0}); + } + return producers; + } + + value = source; + op = value.getDefiningOp(); + if (!op) + return producers; + } + + if (auto compute = dyn_cast(op)) { + producers.push_back({ComputeInstance {compute.getOperation(), 0, 1}, + static_cast(cast(value).getResultNumber())}); + return producers; + } + + if (auto batch = dyn_cast(op)) { + if (batch.getNumResults() != 0) { + uint32_t laneStart = isa(consumerInstance.op) ? consumerInstance.laneStart : 0; + uint32_t laneCount = isa(consumerInstance.op) + ? consumerInstance.laneCount + : static_cast(batch.getLaneCount()); + for (ComputeInstance instance : getBatchChunksForRange(batch, laneStart, laneCount)) + producers.push_back({instance, 0}); + return producers; + } + + uint32_t lane = cast(value).getResultNumber(); + ComputeInstance instance = getBatchChunkForLane(batch, lane); + producers.push_back({instance, lane - instance.laneStart}); + return producers; + } + + return producers; +} + +Cost getProducerTransferCost(Value input, const ComputeInstance& consumerInstance, const ProducerValueRef& producerRef) { + Cost transferCost = getInputTransferCost(consumerInstance, input); + auto producerBatch = dyn_cast(producerRef.instance.op); + if (!producerBatch || producerBatch.getNumResults() == 0) + return transferCost; + + if (auto consumerBatch = dyn_cast(consumerInstance.op)) { + if (std::optional projectedCost = getBatchProjectedInputTransferCost(consumerBatch, input)) { + uint32_t overlapLaneCount = getLaneOverlapCount(consumerInstance, producerRef.instance); + assert(overlapLaneCount > 0 && "projected batch edge must overlap consumer lanes"); + return checkedMultiply(*projectedCost, static_cast(overlapLaneCount)); + } + } + + return scaleTransferCostByLaneCount(transferCost, + static_cast(producerBatch.getLaneCount()), + producerRef.instance.laneCount); +} + static CrossbarWeight getOpaqueCrossbarWeight(Value value, std::optional lane) { CrossbarWeight weight; weight.opaqueValue = value; @@ -458,25 +576,13 @@ ComputeGraph buildComputeGraph(Operation* entryOp) { for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) { llvm::SmallVector inputs = getComputeInstanceInputs(node.instance); for (Value input : inputs) { - Cost transferCost = getInputTransferCost(node.instance, input); - if (auto producerBatch = dyn_cast_or_null(input.getDefiningOp()); - producerBatch && producerBatch.getNumResults() != 0 && !isa(node.instance.op)) { - for (uint32_t lane = 0; lane < static_cast(producerBatch.getLaneCount()); ++lane) { - auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane)); - if (producerIt == graph.instanceToIndex.end()) - continue; - rawEdges.push_back({producerIt->second, targetIndex, transferCost}); - } - continue; + for (const ProducerValueRef& producerRef : collectProducerValueRefs(input, node.instance)) { + auto producerIt = graph.instanceToIndex.find(producerRef.instance); + if (producerIt == graph.instanceToIndex.end()) + continue; + rawEdges.push_back( + {producerIt->second, targetIndex, getProducerTransferCost(input, node.instance, producerRef)}); } - - auto producerInstance = getComputeProducerInstance(input, &node.instance); - if (!producerInstance) - continue; - auto producerIt = graph.instanceToIndex.find(*producerInstance); - if (producerIt == graph.instanceToIndex.end()) - continue; - rawEdges.push_back({producerIt->second, targetIndex, transferCost}); } } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp index 698f689..e0c1db0 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp @@ -20,17 +20,67 @@ size_t getSchedulingCpuBudget() { size_t getBatchChunkTargetCount(int32_t laneCount) { assert(laneCount > 0 && "laneCount must be positive"); - return static_cast(laneCount); + return std::min(static_cast(laneCount), getSchedulingCpuBudget()); +} + +BatchChunkRange getBatchChunkRange(int32_t laneCount, size_t chunkIndex) { + assert(laneCount > 0 && "laneCount must be positive"); + size_t chunkCount = getBatchChunkTargetCount(laneCount); + assert(chunkIndex < chunkCount && "chunkIndex out of range"); + + size_t laneCountSize = static_cast(laneCount); + size_t baseChunkSize = laneCountSize / chunkCount; + size_t remainder = laneCountSize % chunkCount; + size_t extraBefore = std::min(chunkIndex, remainder); + size_t start = chunkIndex * baseChunkSize + extraBefore; + size_t count = baseChunkSize + (chunkIndex < remainder ? 1 : 0); + assert(count > 0 && "chunk size must be positive"); + return {static_cast(start), static_cast(count)}; +} + +size_t getBatchChunkIndexForLane(int32_t laneCount, uint32_t lane) { + assert(laneCount > 0 && "laneCount must be positive"); + assert(lane < static_cast(laneCount) && "lane out of range"); + + size_t chunkCount = getBatchChunkTargetCount(laneCount); + size_t laneCountSize = static_cast(laneCount); + size_t baseChunkSize = laneCountSize / chunkCount; + size_t remainder = laneCountSize % chunkCount; + size_t largeChunkSize = baseChunkSize + 1; + size_t laneIndex = static_cast(lane); + size_t largerChunkLanes = remainder * largeChunkSize; + + if (laneIndex < largerChunkLanes) + return laneIndex / largeChunkSize; + return remainder + ((laneIndex - largerChunkLanes) / baseChunkSize); } ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) { - assert(chunkIndex < static_cast(batch.getLaneCount()) && "chunkIndex out of range"); - return {batch.getOperation(), static_cast(chunkIndex), 1}; + BatchChunkRange chunk = getBatchChunkRange(batch.getLaneCount(), chunkIndex); + return {batch.getOperation(), chunk.laneStart, chunk.laneCount}; } ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) { - assert(lane < static_cast(batch.getLaneCount()) && "lane out of range"); - return {batch.getOperation(), lane, 1}; + return getBatchChunkForIndex(batch, getBatchChunkIndexForLane(batch.getLaneCount(), lane)); +} + +llvm::SmallVector getBatchChunksForRange(SpatComputeBatch batch, + uint32_t laneStart, + uint32_t laneCount) { + llvm::SmallVector chunks; + if (laneCount == 0) + return chunks; + + uint32_t laneEnd = laneStart + laneCount; + assert(laneEnd >= laneStart && "lane range overflow"); + assert(laneEnd <= static_cast(batch.getLaneCount()) && "lane range out of bounds"); + + size_t firstChunk = getBatchChunkIndexForLane(batch.getLaneCount(), laneStart); + size_t lastChunk = getBatchChunkIndexForLane(batch.getLaneCount(), laneEnd - 1); + chunks.reserve(lastChunk - firstChunk + 1); + for (size_t chunkIndex = firstChunk; chunkIndex <= lastChunk; ++chunkIndex) + chunks.push_back(getBatchChunkForIndex(batch, chunkIndex)); + return chunks; } static std::optional getConstantExtractLane(tensor::ExtractSliceOp extract) { diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.hpp index c257a1c..39adb83 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.hpp @@ -21,10 +21,20 @@ struct ProducerValueRef { size_t resultIndex = 0; }; +struct BatchChunkRange { + uint32_t laneStart = 0; + uint32_t laneCount = 0; +}; + size_t getSchedulingCpuBudget(); size_t getBatchChunkTargetCount(int32_t laneCount); +BatchChunkRange getBatchChunkRange(int32_t laneCount, size_t chunkIndex); +size_t getBatchChunkIndexForLane(int32_t laneCount, uint32_t lane); ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex); ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane); +llvm::SmallVector getBatchChunksForRange(SpatComputeBatch batch, + uint32_t laneStart, + uint32_t laneCount); std::optional getProducerValueRef(mlir::Value value, const ComputeInstance* consumerInstance = nullptr); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp index 508de55..6c9f486 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp @@ -1,11 +1,9 @@ #include "mlir/IR/Threading.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" -#include #include #include #include