From 37a59054a52bc564873f349e539f94076a82d94c Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Wed, 3 Jun 2026 16:01:19 +0200 Subject: [PATCH] better loop compaction in MaterializeMergeSchedule.cpp --- .../MaterializeMergeSchedule.cpp | 1033 ++++++++++++----- 1 file changed, 739 insertions(+), 294 deletions(-) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 06df444..2253a75 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -55,6 +55,43 @@ getCheckedCoreIds(Operation* anchor, ArrayRef cpus, StringRef fieldName) return coreIds; } +struct MessageVector { + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + size_t size() const { return channelIds.size(); } + bool empty() const { return channelIds.empty(); } + + LogicalResult verify(Operation* anchor) const { + if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) + return anchor->emitError("message metadata is inconsistent"); + return success(); + } + + void append(int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId) { + channelIds.push_back(channelId); + sourceCoreIds.push_back(sourceCoreId); + targetCoreIds.push_back(targetCoreId); + } + + void append(ArrayRef channels, ArrayRef sources, ArrayRef targets) { + assert(channels.size() == sources.size() && "channel/source count mismatch"); + assert(channels.size() == targets.size() && "channel/target count mismatch"); + llvm::append_range(channelIds, channels); + llvm::append_range(sourceCoreIds, sources); + llvm::append_range(targetCoreIds, targets); + } + + MessageVector slice(size_t offset, size_t count) const { + MessageVector result; + result.append(ArrayRef(channelIds).slice(offset, count), + ArrayRef(sourceCoreIds).slice(offset, count), + ArrayRef(targetCoreIds).slice(offset, count)); + return result; + } +}; + struct ProducerKey { ComputeInstance instance; size_t resultIndex = 0; @@ -119,10 +156,16 @@ struct PackedScalarRunValue { RankedTensorType fragmentType; SmallVector slots; + MessageVector messages; +}; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; +struct IndexedBatchRunValue { + ClassId targetClass = 0; + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + RankedTensorType fragmentType; + SmallVector slots; + MessageVector messages; }; struct MaterializationRunSlot { @@ -139,9 +182,7 @@ struct OutputDestinationGroup { struct BatchRunSendPlan { size_t resultIndex = 0; ClassId destinationClass = 0; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; + MessageVector messages; }; struct ProjectedBatchInputKey { @@ -186,6 +227,11 @@ struct ProjectedExtractReplacement { unsigned fragmentsPerLane = 1; }; +struct CloneIndexingContext { + std::optional runSlotIndex; + std::optional projectionSlotIndex; +}; + struct MaterializerState; class AvailableValueStore { @@ -193,10 +239,12 @@ public: void record(ProducerKey key, ClassId classId, Value value) { exactValues[key][classId] = value; } void recordPackedRun(PackedScalarRunValue run) { packedScalarRuns.push_back(std::move(run)); } + void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); } std::optional lookupExact(ProducerKey key, ClassId classId) const; std::optional lookup(MaterializerState& state, ProducerKey key, ClassId classId); + IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId); SmallVectorImpl& getPackedScalarRuns() { return packedScalarRuns; } @@ -205,6 +253,7 @@ private: DenseMap, ProducerKeyInfo> exactValues; SmallVector packedScalarRuns; + SmallVector indexedBatchRuns; }; struct MaterializerState { @@ -801,24 +850,17 @@ Value getPackedSliceForRunIndex(MaterializerState& state, return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); } -SmallVector widenToI64(ArrayRef values) { - SmallVector widened; - widened.reserve(values.size()); - for (int32_t value : values) - widened.push_back(value); - return widened; -} - FailureOr createReceiveConcatLoop(MaterializerState& state, Operation* anchor, Operation* insertionPoint, RankedTensorType concatType, RankedTensorType fragmentType, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, + const MessageVector& messages, Location loc); +using IndexedFragmentBuilder = llvm::function_ref(Value flatIndex)>; +using IndexedInsertOffsetBuilder = llvm::function_ref(Value flatIndex)>; + FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, MaterializedClass& targetClass, PackedScalarRunValue& run, @@ -844,9 +886,11 @@ LogicalResult validatePackedScalarRunMetadata(Operation* anchor, const PackedSca if (receiveCount == 0) return anchor->emitError("packed scalar run has no receives"); - if (run.channelIds.size() != receiveCount || run.sourceCoreIds.size() != receiveCount - || run.targetCoreIds.size() != receiveCount) - return anchor->emitError("packed scalar run receive metadata is inconsistent"); + if (failed(run.messages.verify(anchor))) + return failure(); + + if (run.messages.size() != receiveCount) + return anchor->emitError("packed scalar run receive metadata count is inconsistent"); return success(); } @@ -872,18 +916,8 @@ FailureOr materializePackedScalarRunValue(MaterializerState& state, if (failed(fullPackedType)) return targetClass.op->emitError("cannot create lazy packed scalar run receive type"); - SmallVector sourceCoreIds = widenToI64(run.sourceCoreIds); - SmallVector targetCoreIds = widenToI64(run.targetCoreIds); - - auto packed = createReceiveConcatLoop(state, - targetClass.op, - targetClass.body->getTerminator(), - *fullPackedType, - run.fragmentType, - run.channelIds, - sourceCoreIds, - targetCoreIds, - loc); + auto packed = createReceiveConcatLoop( + state, targetClass.op, targetClass.body->getTerminator(), *fullPackedType, run.fragmentType, run.messages, loc); if (failed(packed)) return failure(); run.packed = *packed; @@ -934,6 +968,21 @@ std::optional AvailableValueStore::lookupPackedRun(MaterializerState& sta return std::nullopt; } +IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key, ClassId classId) { + for (IndexedBatchRunValue& run : indexedBatchRuns) { + if (run.targetClass != classId) + continue; + if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) + continue; + + for (const PackedScalarRunSlot& slot : run.slots) + if (llvm::is_contained(slot.keys, key)) + return &run; + } + + return nullptr; +} + std::optional AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) { if (std::optional exact = lookupExact(key, classId)) return exact; @@ -1108,6 +1157,21 @@ Value createIndexedIndexValue( return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc); } +Value createIndexedChannelId( + MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { + return createIndexedIndexValue(state, anchor, ArrayRef(messages.channelIds), index, loc); +} + +Value createIndexedSourceCoreId( + MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { + return createIndexedIndexValue(state, anchor, ArrayRef(messages.sourceCoreIds), index, loc); +} + +Value createIndexedTargetCoreId( + MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { + return createIndexedIndexValue(state, anchor, ArrayRef(messages.targetCoreIds), index, loc); +} + Value createLaneIndexedIndexValue(MaterializerState& state, MaterializedClass& materializedClass, ArrayRef values, @@ -1587,20 +1651,17 @@ void appendScalarSend(MaterializerState& state, LogicalResult appendScalarSendLoop(MaterializerState& state, MaterializedClass& sourceClass, Value payload, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, + const MessageVector& messages, Location loc) { assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class"); - assert(channelIds.size() > 1 && "send loop is only useful for multiple sends"); - assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); - assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); + assert(messages.size() > 1 && "send loop is only useful for multiple sends"); + assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); Value upperBound = - getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(channelIds.size())); + getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); auto sendLoop = buildNormalizedScfFor( @@ -1611,9 +1672,9 @@ LogicalResult appendScalarSendLoop(MaterializerState& state, step, ValueRange {}, [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { - Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc); - Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc); - Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc); + Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); return success(); }); @@ -1670,22 +1731,19 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, MaterializedClass& sourceClass, Value payload, const ProjectedTransferDescriptor& descriptor, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, + const MessageVector& messages, Location loc) { assert(!sourceClass.isBatch && "projected scalar send expects scalar source class"); - assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); - assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); - assert(channelIds.size() * descriptor.fragmentsPerLane == descriptor.laneMajorProjectedOffsets.size() + assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); + assert(messages.size() * descriptor.fragmentsPerLane == descriptor.laneMajorProjectedOffsets.size() && "projected send lane count mismatch"); state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - if (channelIds.size() == 1) { - Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channelIds.front()); - Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCoreIds.front()); - Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCoreIds.front()); + if (messages.size() == 1) { + Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.channelIds.front()); + Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.sourceCoreIds.front()); + Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.targetCoreIds.front()); Value laneIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); Value sendPayload; if (descriptor.fragmentsPerLane == 1) { @@ -1707,7 +1765,7 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); Value upperBound = - getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(channelIds.size())); + getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); auto projectedSendLoop = buildNormalizedScfFor( @@ -1718,9 +1776,9 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, step, ValueRange {}, [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { - Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc); - Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc); - Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc); + Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); Value sendPayload; if (descriptor.fragmentsPerLane == 1) { @@ -1746,31 +1804,33 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, LogicalResult appendSend(MaterializerState& state, MaterializedClass& sourceClass, Value payload, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, + const MessageVector& messages, Location loc) { - assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); - assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); - assert(!channelIds.empty() && "expected at least one send"); + assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); + assert(!messages.empty() && "expected at least one send"); if (sourceClass.isBatch) { state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - Value channelId = createLaneIndexedIndexValue(state, sourceClass, channelIds, loc); - Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc); - Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc); + Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); + Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc); + Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.targetCoreIds, loc); SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); return success(); } - if (channelIds.size() == 1) { - appendScalarSend( - state, sourceClass, payload, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc); + if (messages.size() == 1) { + appendScalarSend(state, + sourceClass, + payload, + messages.channelIds.front(), + messages.sourceCoreIds.front(), + messages.targetCoreIds.front(), + loc); return success(); } - return appendScalarSendLoop(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); + return appendScalarSendLoop(state, sourceClass, payload, messages, loc); } Value appendScalarReceive(MaterializerState& state, @@ -1790,29 +1850,28 @@ Value appendScalarReceive(MaterializerState& state, .getOutput(); } -Value appendReceive(MaterializerState& state, - MaterializedClass& targetClass, - Type type, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc) { - assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); - assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); - assert(!channelIds.empty() && "expected at least one receive"); +Value appendReceive( + MaterializerState& state, MaterializedClass& targetClass, Type type, const MessageVector& messages, Location loc) { + assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); + assert(!messages.empty() && "expected at least one receive"); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); if (targetClass.isBatch) { - Value channelId = createLaneIndexedIndexValue(state, targetClass, channelIds, loc); - Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, sourceCoreIds, loc); - Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, targetCoreIds, loc); + Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc); + Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc); + Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc); return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId).getOutput(); } - assert(channelIds.size() == 1 && "scalar target class can only receive one message at a time"); - return appendScalarReceive( - state, targetClass, type, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc); + assert(messages.size() == 1 && "scalar target class can only receive one message at a time"); + return appendScalarReceive(state, + targetClass, + type, + messages.channelIds.front(), + messages.sourceCoreIds.front(), + messages.targetCoreIds.front(), + loc); } LogicalResult registerLazyPackedScalarReceives(MaterializerState& state, @@ -1835,7 +1894,12 @@ LogicalResult registerLazyPackedScalarReceives(MaterializerState& state, if (keys.size() != sourceClass.cpus.size()) return sourceClass.op->emitError("lazy packed scalar receive expects one producer key per source lane"); - if (keys.size() != channelIds.size() || keys.size() != sourceCoreIds.size() || keys.size() != targetCoreIds.size()) + MessageVector messages; + messages.append(channelIds, sourceCoreIds, targetCoreIds); + if (failed(messages.verify(targetClass.op))) + return failure(); + + if (keys.size() != messages.size()) return targetClass.op->emitError("lazy packed scalar receive metadata is inconsistent"); auto rankedFragmentType = dyn_cast(fragmentType); @@ -1864,9 +1928,7 @@ LogicalResult registerLazyPackedScalarReceives(MaterializerState& state, packedRun.kind = PackedScalarRunKind::DeferredReceive; packedRun.fragmentType = rankedFragmentType; - llvm::append_range(packedRun.channelIds, channelIds); - llvm::append_range(packedRun.sourceCoreIds, sourceCoreIds); - llvm::append_range(packedRun.targetCoreIds, targetCoreIds); + packedRun.messages = std::move(messages); PackedScalarRunSlot slot; llvm::append_range(slot.keys, keys); @@ -1881,9 +1943,35 @@ LogicalResult registerLazyPackedScalarReceives(MaterializerState& state, struct ScalarSourceReceivePlan { ClassId targetClass = 0; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; + MessageVector messages; + Type receiveType; + Operation* projectedExtractOp = nullptr; + RankedTensorType projectedFragmentType; + unsigned projectedFragmentsPerLane = 1; +}; + +struct ProjectedTransferCompatibilityKey { + RankedTensorType fragmentType; + RankedTensorType payloadType; + unsigned sourceProjectedDim = 0; + unsigned fragmentsPerLane = 1; + + bool operator==(const ProjectedTransferCompatibilityKey& other) const { + return fragmentType == other.fragmentType && payloadType == other.payloadType + && sourceProjectedDim == other.sourceProjectedDim && fragmentsPerLane == other.fragmentsPerLane; + } +}; + +struct ScalarSourceSendGroup { + MessageVector messages; + std::optional projectedKey; + SmallVector projectedOffsets; +}; + +struct ScalarSourceFanoutPlan { + SmallVector receivePlans; + std::optional ordinarySendGroup; + SmallVector projectedSendGroups; }; SmallVector collectDestinationClassesForKeys(MaterializerState& state, ArrayRef keys) { @@ -1898,55 +1986,42 @@ SmallVector collectDestinationClassesForKeys(MaterializerState& stat return destinations; } -FailureOr> emitScalarSourceSends(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - ArrayRef destinationClasses, - Value payload, - Location loc) { +FailureOr buildScalarSourceFanoutPlan(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + ArrayRef destinationClasses, + Value payload) { assert(!sourceClass.isBatch && "scalar-source send planning expects a scalar source class"); auto sourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "scalar source core id"); if (failed(sourceCpu)) return failure(); - SmallVector receivePlans; - receivePlans.reserve(destinationClasses.size()); + ScalarSourceFanoutPlan fanoutPlan; + fanoutPlan.receivePlans.reserve(destinationClasses.size()); - const auto tryEmitProjected = [&](ClassId destinationClass, - const SmallVector& channelIds, - const SmallVector& sourceCoreIds, - const SmallVector& targetCoreIds) -> FailureOr { + const auto getProjectedDescriptor = [&](ClassId destinationClass) -> const ProjectedTransferDescriptor* { if (keys.size() != 1) - return false; + return nullptr; MaterializedClass& targetClass = state.classes[destinationClass]; if (!targetClass.isBatch) - return false; + return nullptr; auto producerIt = state.projectedTransfers.find(keys.front()); if (producerIt == state.projectedTransfers.end()) - return false; + return nullptr; auto descriptorIt = producerIt->second.find(destinationClass); if (descriptorIt == producerIt->second.end()) - return false; + return nullptr; const ProjectedTransferDescriptor& descriptor = descriptorIt->second; if (descriptor.laneMajorProjectedOffsets.size() != targetClass.cpus.size() * static_cast(descriptor.fragmentsPerLane)) - return false; + return nullptr; - if (failed(appendProjectedScalarSendLoop( - state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc))) - return failure(); - - Value received = - appendReceive(state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc); - - state.projectedExtractReplacements[descriptor.extractOp][destinationClass] = - ProjectedExtractReplacement {received, descriptor.fragmentType, descriptor.fragmentsPerLane}; - return true; + return &descriptor; }; for (ClassId destinationClass : destinationClasses) { @@ -1955,8 +2030,9 @@ FailureOr> emitScalarSourceSends(Materia MaterializedClass& targetClass = state.classes[destinationClass]; - ScalarSourceReceivePlan plan; - plan.targetClass = destinationClass; + ScalarSourceReceivePlan receivePlan; + receivePlan.targetClass = destinationClass; + receivePlan.receiveType = payload.getType(); auto appendMessage = [&](CpuId targetCpu) -> LogicalResult { auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetCpu, "scalar target core id"); @@ -1964,9 +2040,7 @@ FailureOr> emitScalarSourceSends(Materia return failure(); int64_t channelId = state.nextChannelId++; - plan.channelIds.push_back(channelId); - plan.sourceCoreIds.push_back(*sourceCpu); - plan.targetCoreIds.push_back(*checkedTargetCpu); + receivePlan.messages.append(channelId, *sourceCpu, *checkedTargetCpu); return success(); }; @@ -1980,18 +2054,68 @@ FailureOr> emitScalarSourceSends(Materia return failure(); } - auto emittedProjected = tryEmitProjected(destinationClass, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds); - if (failed(emittedProjected)) - return failure(); - if (*emittedProjected) - continue; + if (const ProjectedTransferDescriptor* descriptor = getProjectedDescriptor(destinationClass)) { + receivePlan.receiveType = descriptor->payloadType; + receivePlan.projectedExtractOp = descriptor->extractOp; + receivePlan.projectedFragmentType = descriptor->fragmentType; + receivePlan.projectedFragmentsPerLane = descriptor->fragmentsPerLane; - if (failed(appendSend(state, sourceClass, payload, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc))) - return failure(); - receivePlans.push_back(std::move(plan)); + ProjectedTransferCompatibilityKey key {descriptor->fragmentType, + descriptor->payloadType, + descriptor->sourceProjectedDim, + descriptor->fragmentsPerLane}; + + auto groupIt = llvm::find_if(fanoutPlan.projectedSendGroups, [&](const ScalarSourceSendGroup& group) { + return group.projectedKey && *group.projectedKey == key; + }); + if (groupIt == fanoutPlan.projectedSendGroups.end()) { + ScalarSourceSendGroup group; + group.projectedKey = key; + fanoutPlan.projectedSendGroups.push_back(std::move(group)); + groupIt = std::prev(fanoutPlan.projectedSendGroups.end()); + } + + groupIt->messages.append( + receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); + llvm::append_range(groupIt->projectedOffsets, descriptor->laneMajorProjectedOffsets); + } + else { + if (!fanoutPlan.ordinarySendGroup) + fanoutPlan.ordinarySendGroup = ScalarSourceSendGroup {}; + fanoutPlan.ordinarySendGroup->messages.append( + receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); + } + + fanoutPlan.receivePlans.push_back(std::move(receivePlan)); } - return receivePlans; + return fanoutPlan; +} + +LogicalResult emitScalarSourceFanoutSends(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const ScalarSourceFanoutPlan& plan, + Location loc) { + if (plan.ordinarySendGroup && failed(appendSend(state, sourceClass, payload, plan.ordinarySendGroup->messages, loc))) + return failure(); + + for (const ScalarSourceSendGroup& group : plan.projectedSendGroups) { + if (!group.projectedKey) + return sourceClass.op->emitError("projected scalar send group is missing a compatibility key"); + + ProjectedTransferDescriptor descriptor; + descriptor.fragmentType = group.projectedKey->fragmentType; + descriptor.payloadType = group.projectedKey->payloadType; + descriptor.sourceProjectedDim = group.projectedKey->sourceProjectedDim; + descriptor.fragmentsPerLane = group.projectedKey->fragmentsPerLane; + descriptor.laneMajorProjectedOffsets = group.projectedOffsets; + + if (failed(appendProjectedScalarSendLoop(state, sourceClass, payload, descriptor, group.messages, loc))) + return failure(); + } + + return success(); } LogicalResult emitScalarSourceCommunication( @@ -2002,15 +2126,22 @@ LogicalResult emitScalarSourceCommunication( state.availableValues.record(key, sourceClass.id, payload); SmallVector destinationClasses = collectDestinationClassesForKeys(state, keys); - auto receivePlans = emitScalarSourceSends(state, sourceClass, keys, destinationClasses, payload, loc); - if (failed(receivePlans)) + auto fanoutPlan = buildScalarSourceFanoutPlan(state, sourceClass, keys, destinationClasses, payload); + if (failed(fanoutPlan)) + return failure(); + if (failed(emitScalarSourceFanoutSends(state, sourceClass, payload, *fanoutPlan, loc))) return failure(); - for (const ScalarSourceReceivePlan& plan : *receivePlans) { + for (const ScalarSourceReceivePlan& plan : fanoutPlan->receivePlans) { MaterializedClass& targetClass = state.classes[plan.targetClass]; - Value received = appendReceive( - state, targetClass, payload.getType(), plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc); + Value received = appendReceive(state, targetClass, plan.receiveType, plan.messages, loc); + + if (plan.projectedExtractOp) { + state.projectedExtractReplacements[plan.projectedExtractOp][plan.targetClass] = + ProjectedExtractReplacement {received, plan.projectedFragmentType, plan.projectedFragmentsPerLane}; + continue; + } for (ProducerKey key : keys) state.availableValues.record(key, targetClass.id, received); @@ -2040,12 +2171,10 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, return sourceClass.op->emitError( "cannot materialize batch-to-scalar communication because source lanes are not contiguous"); - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - channelIds.reserve(sourceClass.cpus.size()); - sourceCoreIds.reserve(sourceClass.cpus.size()); - targetCoreIds.reserve(sourceClass.cpus.size()); + MessageVector messages; + messages.channelIds.reserve(sourceClass.cpus.size()); + messages.sourceCoreIds.reserve(sourceClass.cpus.size()); + messages.targetCoreIds.reserve(sourceClass.cpus.size()); auto targetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus.front(), "batch-to-scalar target core id"); if (failed(targetCpu)) @@ -2054,27 +2183,29 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch-to-scalar source core id"); if (failed(checkedSourceCpu)) return failure(); - channelIds.push_back(state.nextChannelId++); - sourceCoreIds.push_back(*checkedSourceCpu); - targetCoreIds.push_back(*targetCpu); + messages.append(state.nextChannelId++, *checkedSourceCpu, *targetCpu); } - if (failed(appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc))) + if (failed(appendSend(state, sourceClass, payload, messages, loc))) return failure(); - return registerLazyPackedScalarReceives( - state, sourceClass, targetClass, keys, payload.getType(), channelIds, sourceCoreIds, targetCoreIds); + return registerLazyPackedScalarReceives(state, + sourceClass, + targetClass, + keys, + payload.getType(), + messages.channelIds, + messages.sourceCoreIds, + messages.targetCoreIds); } if (sourceClass.cpus.size() != targetClass.cpus.size()) return sourceClass.op->emitError( "cannot materialize batch communication between equivalence classes of different sizes"); - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - channelIds.reserve(sourceClass.cpus.size()); - sourceCoreIds.reserve(sourceClass.cpus.size()); - targetCoreIds.reserve(targetClass.cpus.size()); + MessageVector messages; + messages.channelIds.reserve(sourceClass.cpus.size()); + messages.sourceCoreIds.reserve(sourceClass.cpus.size()); + messages.targetCoreIds.reserve(targetClass.cpus.size()); for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch source core id"); @@ -2083,14 +2214,12 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus[lane], "batch target core id"); if (failed(checkedTargetCpu)) return failure(); - channelIds.push_back(state.nextChannelId++); - sourceCoreIds.push_back(*checkedSourceCpu); - targetCoreIds.push_back(*checkedTargetCpu); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); } - if (failed(appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc))) + if (failed(appendSend(state, sourceClass, payload, messages, loc))) return failure(); - Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = appendReceive(state, targetClass, payload.getType(), messages, loc); for (ProducerKey key : keys) state.availableValues.record(key, targetClass.id, received); @@ -2192,6 +2321,27 @@ struct DirectWholeBatchFragment { Value fragment; }; +enum class WholeBatchFragmentSourceKind { + DeferredReceive, + DeferredLocalCompute, + PackedValue, + DirectValue +}; + +struct WholeBatchFragmentGroup { + WholeBatchFragmentSourceKind kind = WholeBatchFragmentSourceKind::DirectValue; + RankedTensorType fragmentType; + SmallVector outputOffsets; + MessageVector messages; + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + SmallVector sourceLanes; + Value packed; + RankedTensorType slotPackedType; + SmallVector slotIndices; + SmallVector, 16> directFragments; +}; + struct WholeBatchAssemblyPlan { RankedTensorType resultType; int64_t rowsPerLane = 0; @@ -2264,12 +2414,54 @@ SmallVector flattenPackedScalarRunKeys(const PackedScalarRunVal return keys; } +FailureOr emitIndexedFragmentInsertLoop(MaterializerState& state, + Operation* anchor, + Operation* insertionPoint, + Value destination, + int64_t itemCount, + IndexedFragmentBuilder buildFragment, + IndexedInsertOffsetBuilder buildOffset, + Location loc) { + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, anchor, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, itemCount); + Value step = getOrCreateIndexConstant(state.constantFolder, anchor, 1); + + state.rewriter.setInsertionPoint(insertionPoint); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {destination}, + [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + FailureOr fragment = buildFragment(flatIndex); + if (failed(fragment)) + return failure(); + FailureOr offset = buildOffset(flatIndex); + if (failed(offset)) + return failure(); + yielded.push_back(insertFragmentIntoWholeBatch(state, *fragment, iterArgs.front(), *offset, loc)); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); +} + FailureOr> cloneBatchBodyForLane(MaterializerState& state, MaterializedClass& targetClass, const ComputeInstance& instance, Value laneValue, ArrayRef resultIndices, - std::optional projectionSlotIndex = std::nullopt); + CloneIndexingContext indexing = {}); + +Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc); +FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, + MaterializedClass& targetClass, + IndexedBatchRunValue& run, + Value runSlotIndex, + Location loc); FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, MaterializedClass& targetClass, @@ -2315,7 +2507,12 @@ FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); FailureOr> produced = - cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex); + cloneBatchBodyForLane(state, + targetClass, + keys.front().instance, + sourceLane, + resultIndices, + CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = loopIndex}); if (failed(produced) || produced->size() != 1) return failure(); @@ -2330,12 +2527,12 @@ FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& return run.packed; } -FailureOr insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - const WholeBatchAssemblyPlan& plan, - PackedScalarRunValue& run, - Location loc) { +[[maybe_unused]] FailureOr insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const WholeBatchAssemblyPlan& plan, + PackedScalarRunValue& run, + Location loc) { assert(isDeferredLocalPackedScalarRun(run) && "expected deferred local packed scalar run"); if (run.fragmentType.getDimSize(0) != plan.rowsPerLane) @@ -2377,7 +2574,12 @@ FailureOr insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerSt Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); FailureOr> produced = - cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex); + cloneBatchBodyForLane(state, + targetClass, + keys.front().instance, + sourceLane, + resultIndices, + CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = loopIndex}); if (failed(produced) || produced->size() != 1) return failure(); @@ -2391,12 +2593,12 @@ FailureOr insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerSt return loop->results.front(); } -FailureOr insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - const WholeBatchAssemblyPlan& plan, - PackedScalarRunValue& run, - Location loc) { +[[maybe_unused]] FailureOr insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const WholeBatchAssemblyPlan& plan, + PackedScalarRunValue& run, + Location loc) { assert(run.kind == PackedScalarRunKind::DeferredReceive && "expected deferred receive packed scalar run"); if (failed(validatePackedScalarRunMetadata(targetClass.op, run))) @@ -2417,12 +2619,12 @@ FailureOr insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState& } } - if (outputOffsets.size() != run.channelIds.size()) + if (outputOffsets.size() != run.messages.size()) return failure(); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); Value upperBound = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.channelIds.size())); + getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.messages.size())); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); @@ -2435,9 +2637,9 @@ FailureOr insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState& ValueRange {destination}, [&](OpBuilder&, Location, Value index, ValueRange iterArgs, SmallVectorImpl& yielded) { Value acc = iterArgs.front(); - Value channelId = createIndexedIndexValue(state, targetClass.op, run.channelIds, index, loc); - Value sourceCoreId = createIndexedIndexValue(state, targetClass.op, run.sourceCoreIds, index, loc); - Value targetCoreId = createIndexedIndexValue(state, targetClass.op, run.targetCoreIds, index, loc); + Value channelId = createIndexedChannelId(state, targetClass.op, run.messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, run.messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, run.messages, index, loc); Value received = SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId) @@ -2452,12 +2654,12 @@ FailureOr insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState& return loop->results.front(); } -FailureOr insertPackedScalarRunIntoWholeBatch(MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - const WholeBatchAssemblyPlan& plan, - PackedScalarRunValue& run, - Location loc) { +[[maybe_unused]] FailureOr insertPackedScalarRunIntoWholeBatch(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const WholeBatchAssemblyPlan& plan, + PackedScalarRunValue& run, + Location loc) { if (run.slots.empty()) return destination; @@ -2616,6 +2818,195 @@ LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state, return success(); } +LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state, + MaterializedClass& targetClass, + const WholeBatchAssemblyPlan& plan, + SmallVectorImpl& groups) { + for (PackedScalarRunValue* run : plan.packedRuns) { + if (!run || run->slots.empty()) + continue; + if (run->fragmentType.getDimSize(0) != plan.rowsPerLane) + return failure(); + + if (run->kind == PackedScalarRunKind::DeferredReceive) { + if (failed(validatePackedScalarRunMetadata(targetClass.op, *run))) + return failure(); + + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == run->fragmentType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::DeferredReceive; + group.fragmentType = run->fragmentType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + + groupIt->messages.append(run->messages.channelIds, run->messages.sourceCoreIds, run->messages.targetCoreIds); + for (const PackedScalarRunSlot& slot : run->slots) + for (ProducerKey fragmentKey : slot.keys) + groupIt->outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); + continue; + } + + if (run->kind == PackedScalarRunKind::DeferredLocalCompute) { + SmallVector keys = flattenPackedScalarRunKeys(*run); + if (keys.empty()) + return failure(); + + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::DeferredLocalCompute + && group.fragmentType == run->fragmentType && group.sourceOp == run->sourceOp + && group.resultIndex == run->resultIndex; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::DeferredLocalCompute; + group.fragmentType = run->fragmentType; + group.sourceOp = run->sourceOp; + group.resultIndex = run->resultIndex; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + + for (ProducerKey fragmentKey : keys) { + if (fragmentKey.instance.laneCount != 1) + return failure(); + groupIt->sourceLanes.push_back(fragmentKey.instance.laneStart); + groupIt->outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); + } + continue; + } + + auto sourceBatch = dyn_cast_or_null(run->sourceOp); + if (!sourceBatch || !run->packed) + return failure(); + + size_t slotLaneCount = run->slots.front().keys.size(); + if (slotLaneCount == 0) + return failure(); + FailureOr slotPackedType = getPackedBatchTensorType(run->fragmentType, slotLaneCount); + if (failed(slotPackedType)) + return failure(); + + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::PackedValue && group.fragmentType == run->fragmentType + && group.packed == run->packed && group.slotPackedType == *slotPackedType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::PackedValue; + group.fragmentType = run->fragmentType; + group.packed = run->packed; + group.slotPackedType = *slotPackedType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + + for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) { + std::optional slotKey = getContiguousProducerKeyForKeys(slot.keys); + if (!slotKey) + return failure(); + groupIt->slotIndices.push_back(slotIndex); + groupIt->outputOffsets.push_back(static_cast(slotKey->instance.laneStart) * plan.rowsPerLane); + } + } + + for (const DirectWholeBatchFragment& fragment : plan.directFragments) { + WholeBatchFragmentGroup group; + auto fragmentType = dyn_cast(fragment.fragment.getType()); + if (!fragmentType) + return failure(); + group.kind = WholeBatchFragmentSourceKind::DirectValue; + group.fragmentType = fragmentType; + group.directFragments.push_back( + {fragment.fragment, static_cast(fragment.key.instance.laneStart) * plan.rowsPerLane}); + groups.push_back(std::move(group)); + } + + return success(); +} + +FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const WholeBatchFragmentGroup& group, + Location loc) { + switch (group.kind) { + case WholeBatchFragmentSourceKind::DeferredReceive: + return emitIndexedFragmentInsertLoop( + state, + targetClass.op, + targetClass.body->getTerminator(), + destination, + static_cast(group.outputOffsets.size()), + [&](Value flatIndex) -> FailureOr { + Value channelId = createIndexedChannelId(state, targetClass.op, group.messages, flatIndex, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, group.messages, flatIndex, loc); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, group.messages, flatIndex, loc); + return SpatChannelReceiveOp::create( + state.rewriter, loc, group.fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); + }, + [&](Value flatIndex) -> FailureOr { + return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); + }, + loc); + case WholeBatchFragmentSourceKind::DeferredLocalCompute: { + SmallVector resultIndices {group.resultIndex}; + return emitIndexedFragmentInsertLoop( + state, + targetClass.op, + targetClass.body->getTerminator(), + destination, + static_cast(group.outputOffsets.size()), + [&](Value flatIndex) -> FailureOr { + Value sourceLane = createIndexedIndexValue(state, targetClass.op, group.sourceLanes, flatIndex, loc); + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + ComputeInstance {group.sourceOp, 0, 1}, + sourceLane, + resultIndices, + CloneIndexingContext {.runSlotIndex = flatIndex, .projectionSlotIndex = flatIndex}); + if (failed(produced) || produced->size() != 1) + return failure(); + return produced->front(); + }, + [&](Value flatIndex) -> FailureOr { + return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); + }, + loc); + } + case WholeBatchFragmentSourceKind::PackedValue: + return emitIndexedFragmentInsertLoop( + state, + targetClass.op, + targetClass.body->getTerminator(), + destination, + static_cast(group.slotIndices.size()), + [&](Value flatIndex) -> FailureOr { + Value packedSlotIndex = createIndexedIndexValue(state, targetClass.op, group.slotIndices, flatIndex, loc); + return extractPackedSlotForIndex( + state, targetClass.op, group.packed, group.slotPackedType, packedSlotIndex, loc); + }, + [&](Value flatIndex) -> FailureOr { + return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); + }, + loc); + case WholeBatchFragmentSourceKind::DirectValue: + for (const auto& [fragment, offset] : group.directFragments) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + destination = insertFragmentIntoWholeBatch( + state, fragment, destination, getOrCreateIndexConstant(state.constantFolder, targetClass.op, offset), loc); + } + return destination; + } + + return failure(); +} + FailureOr buildWholeBatchAssemblyPlan(MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, @@ -2652,22 +3043,17 @@ FailureOr emitWholeBatchAssemblyPlan(MaterializerState& state, tensor::EmptyOp::create(state.rewriter, loc, plan.resultType.getShape(), plan.resultType.getElementType()) .getResult(); - for (PackedScalarRunValue* run : plan.packedRuns) { - FailureOr updated = insertPackedScalarRunIntoWholeBatch(state, targetClass, result, plan, *run, loc); + SmallVector groups; + if (failed(collectWholeBatchFragmentGroups(state, targetClass, plan, groups))) + return failure(); + + for (const WholeBatchFragmentGroup& group : groups) { + FailureOr updated = emitWholeBatchFragmentGroup(state, targetClass, result, group, loc); if (failed(updated)) return failure(); - result = *updated; } - for (const DirectWholeBatchFragment& fragment : plan.directFragments) { - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - - int64_t rowOffset = static_cast(fragment.key.instance.laneStart) * plan.rowsPerLane; - Value outputOffset = getOrCreateIndexConstant(state.constantFolder, targetClass.op, rowOffset); - result = insertFragmentIntoWholeBatch(state, fragment.fragment, result, outputOffset, loc); - } - state.availableValues.record(key, targetClass.id, result); return result; } @@ -2688,14 +3074,36 @@ FailureOr materializeWholeBatchInput( FailureOr resolveInputValue(MaterializerState& state, MaterializedClass& targetClass, Value input, - const ComputeInstance& consumerInstance) { + const ComputeInstance& consumerInstance, + CloneIndexingContext indexing) { if (isConstantLike(input)) return input; if (std::optional producer = getProducerKey(input, &consumerInstance)) { + if (indexing.runSlotIndex) { + if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) + return materializeIndexedBatchRunReceive( + state, targetClass, *indexedRun, *indexing.runSlotIndex, consumerInstance.op->getLoc()); + } + if (std::optional value = state.availableValues.lookup(state, *producer, targetClass.id)) return *value; + if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) { + size_t laneCount = targetClass.cpus.size(); + for (auto [slotIndex, slot] : llvm::enumerate(indexedRun->slots)) { + if (!llvm::is_contained(slot.keys, *producer)) + continue; + + MessageVector messages = indexedRun->messages.slice(slotIndex * laneCount, laneCount); + Value received = + appendReceive(state, targetClass, indexedRun->fragmentType, messages, consumerInstance.op->getLoc()); + for (ProducerKey slotKey : slot.keys) + state.availableValues.record(slotKey, targetClass.id, received); + return received; + } + } + if (isWholeBatchProducerKey(*producer)) return materializeWholeBatchInput(state, targetClass, *producer, input.getType(), consumerInstance.op->getLoc()); @@ -2745,11 +3153,12 @@ void mapWeights(MaterializerState& state, LogicalResult mapInputs(MaterializerState& state, MaterializedClass& targetClass, const ComputeInstance& instance, - IRMapping& mapper) { + IRMapping& mapper, + CloneIndexingContext indexing) { Operation* op = instance.op; if (auto compute = dyn_cast(op)) { for (auto [index, input] : llvm::enumerate(compute.getInputs())) { - FailureOr mapped = resolveInputValue(state, targetClass, input, instance); + FailureOr mapped = resolveInputValue(state, targetClass, input, instance, indexing); if (failed(mapped)) return compute.emitOpError("failed to resolve materialized compute input"); auto inputArg = compute.getInputArgument(index); @@ -2765,7 +3174,7 @@ LogicalResult mapInputs(MaterializerState& state, if (hasProjectedInputReplacement(state, batch, static_cast(index), targetClass.id)) continue; - FailureOr mapped = resolveInputValue(state, targetClass, input, instance); + FailureOr mapped = resolveInputValue(state, targetClass, input, instance, indexing); if (failed(mapped)) return batch.emitOpError("failed to resolve materialized compute_batch input"); auto inputArg = batch.getInputArgument(index); @@ -2863,6 +3272,26 @@ FailureOr materializeProjectedExtractReplacement(MaterializerState& state state, extract.getLoc(), replacement.payload, packedOffset, replacement.fragmentType.getDimSize(0)); } +FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, + MaterializedClass& targetClass, + IndexedBatchRunValue& run, + Value runSlotIndex, + Location loc) { + if (!targetClass.isBatch) + return targetClass.op->emitError("indexed batch run receive requires a batch target class"); + if (failed(run.messages.verify(targetClass.op))) + return failure(); + + Value flatIndex = createBatchRunFlatIndex(state, targetClass, runSlotIndex, loc); + Value channelId = createIndexedChannelId(state, targetClass.op, run.messages, flatIndex, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, run.messages, flatIndex, loc); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, run.messages, flatIndex, loc); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + return SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); +} + FailureOr> cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef peers) { assert(!peers.empty() && "expected at least one peer instance"); @@ -2895,7 +3324,7 @@ cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, Arra OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); mapWeights(state, targetClass, instance, mapper); - if (failed(mapInputs(state, targetClass, instance, mapper))) + if (failed(mapInputs(state, targetClass, instance, mapper, {}))) return failure(); state.rewriter.restoreInsertionPoint(cloneInsertionPoint); @@ -3116,15 +3545,22 @@ LogicalResult emitPackedRunFanout(MaterializerState& state, Location loc) { assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class"); - auto receivePlans = emitScalarSourceSends(state, sourceClass, keys, destinationClasses, packed, loc); - if (failed(receivePlans)) + auto fanoutPlan = buildScalarSourceFanoutPlan(state, sourceClass, keys, destinationClasses, packed); + if (failed(fanoutPlan)) + return failure(); + if (failed(emitScalarSourceFanoutSends(state, sourceClass, packed, *fanoutPlan, loc))) return failure(); - for (const ScalarSourceReceivePlan& plan : *receivePlans) { + for (const ScalarSourceReceivePlan& plan : fanoutPlan->receivePlans) { MaterializedClass& targetClass = state.classes[plan.targetClass]; - Value received = - appendReceive(state, targetClass, packed.getType(), plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc); + Value received = appendReceive(state, targetClass, plan.receiveType, plan.messages, loc); + + if (plan.projectedExtractOp) { + state.projectedExtractReplacements[plan.projectedExtractOp][plan.targetClass] = + ProjectedExtractReplacement {received, plan.projectedFragmentType, plan.projectedFragmentsPerLane}; + continue; + } if (failed(registerPackedRunValue(state, targetClass, keys, received, fragmentType, loc))) return failure(); @@ -3138,7 +3574,7 @@ FailureOr> cloneBatchBodyForLane(MaterializerState& state, const ComputeInstance& instance, Value laneValue, ArrayRef resultIndices, - std::optional projectionSlotIndex) { + CloneIndexingContext indexing) { auto batch = dyn_cast(instance.op); if (!batch) return failure(); @@ -3153,7 +3589,7 @@ FailureOr> cloneBatchBodyForLane(MaterializerState& state, OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); mapWeights(state, targetClass, instance, mapper); - if (failed(mapInputs(state, targetClass, instance, mapper))) + if (failed(mapInputs(state, targetClass, instance, mapper, indexing))) return failure(); state.rewriter.restoreInsertionPoint(cloneInsertionPoint); @@ -3163,8 +3599,8 @@ FailureOr> cloneBatchBodyForLane(MaterializerState& state, if (auto extract = dyn_cast(&op)) { if (std::optional replacement = lookupProjectedExtractReplacement(state, targetClass, extract)) { - FailureOr projected = - materializeProjectedExtractReplacement(state, targetClass, extract, *replacement, projectionSlotIndex); + FailureOr projected = materializeProjectedExtractReplacement( + state, targetClass, extract, *replacement, indexing.projectionSlotIndex); if (failed(projected)) return failure(); @@ -3211,7 +3647,7 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); Value laneValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, instance.laneStart); - return cloneBatchBodyForLane(state, targetClass, instance, laneValue, group.resultIndices); + return cloneBatchBodyForLane(state, targetClass, instance, laneValue, group.resultIndices, {}); } state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); @@ -3262,8 +3698,13 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { Value sourceLane = createIndexedIndexValue(state, targetClass.op, laneStarts, loopIndex, loc); - FailureOr> produced = cloneBatchBodyForLane( - state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, loopIndex); + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + run.front().peers.front(), + sourceLane, + group.resultIndices, + CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = loopIndex}); if (failed(produced)) return failure(); @@ -3576,9 +4017,9 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state, plan.destinationClass = destinationClass; size_t messageCount = run.size() * sourceClass.cpus.size(); - plan.channelIds.reserve(messageCount); - plan.sourceCoreIds.reserve(messageCount); - plan.targetCoreIds.reserve(messageCount); + plan.messages.channelIds.reserve(messageCount); + plan.messages.sourceCoreIds.reserve(messageCount); + plan.messages.targetCoreIds.reserve(messageCount); for ([[maybe_unused]] const MaterializationRunSlot& slot : run) { for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { @@ -3591,9 +4032,7 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state, "batch run target core id"); if (failed(checkedTargetCpu)) return failure(); - plan.channelIds.push_back(state.nextChannelId++); - plan.sourceCoreIds.push_back(*checkedSourceCpu); - plan.targetCoreIds.push_back(*checkedTargetCpu); + plan.messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); } } @@ -3612,23 +4051,26 @@ void appendBatchRunSend(MaterializerState& state, Location loc) { assert(sourceClass.isBatch && "batch run send expects a materialized batch source"); - Value channelId = createIndexedIndexValue(state, sourceClass.op, plan.channelIds, flatIndex, loc); - Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, plan.sourceCoreIds, flatIndex, loc); - Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, plan.targetCoreIds, flatIndex, loc); + Value channelId = createIndexedChannelId(state, sourceClass.op, plan.messages, flatIndex, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, plan.messages, flatIndex, loc); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, plan.messages, flatIndex, loc); SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); } -ArrayRef sliceChannelsForRunSlot(const BatchRunSendPlan& plan, size_t slotIndex, size_t laneCount) { - return ArrayRef(plan.channelIds).slice(slotIndex * laneCount, laneCount); +[[maybe_unused]] ArrayRef +sliceChannelsForRunSlot(const BatchRunSendPlan& plan, size_t slotIndex, size_t laneCount) { + return ArrayRef(plan.messages.channelIds).slice(slotIndex * laneCount, laneCount); } -ArrayRef sliceSourcesForRunSlot(const BatchRunSendPlan& plan, size_t slotIndex, size_t laneCount) { - return ArrayRef(plan.sourceCoreIds).slice(slotIndex * laneCount, laneCount); +[[maybe_unused]] ArrayRef +sliceSourcesForRunSlot(const BatchRunSendPlan& plan, size_t slotIndex, size_t laneCount) { + return ArrayRef(plan.messages.sourceCoreIds).slice(slotIndex * laneCount, laneCount); } -ArrayRef sliceTargetsForRunSlot(const BatchRunSendPlan& plan, size_t slotIndex, size_t laneCount) { - return ArrayRef(plan.targetCoreIds).slice(slotIndex * laneCount, laneCount); +[[maybe_unused]] ArrayRef +sliceTargetsForRunSlot(const BatchRunSendPlan& plan, size_t slotIndex, size_t laneCount) { + return ArrayRef(plan.messages.targetCoreIds).slice(slotIndex * laneCount, laneCount); } LogicalResult appendPackedScalarRunReceives(MaterializerState& state, @@ -3643,8 +4085,10 @@ LogicalResult appendPackedScalarRunReceives(MaterializerState& state, size_t laneCount = sourceClass.cpus.size(); size_t receiveCount = run.size() * laneCount; - if (receiveCount != plan.channelIds.size() || receiveCount != plan.sourceCoreIds.size() - || receiveCount != plan.targetCoreIds.size()) + if (failed(plan.messages.verify(targetClass.op))) + return failure(); + + if (receiveCount != plan.messages.size()) return targetClass.op->emitError("inconsistent flattened batch run receive plan"); auto rankedFragmentType = dyn_cast(fragmentType); @@ -3658,9 +4102,7 @@ LogicalResult appendPackedScalarRunReceives(MaterializerState& state, packedRun.kind = PackedScalarRunKind::DeferredReceive; packedRun.fragmentType = rankedFragmentType; - packedRun.channelIds = plan.channelIds; - packedRun.sourceCoreIds = plan.sourceCoreIds; - packedRun.targetCoreIds = plan.targetCoreIds; + packedRun.messages = plan.messages; packedRun.slots.reserve(run.size()); for (const MaterializationRunSlot& slot : run) { @@ -3676,6 +4118,32 @@ LogicalResult appendPackedScalarRunReceives(MaterializerState& state, return success(); } +LogicalResult recordIndexedBatchRunReceives(MaterializerState& state, + ArrayRef run, + const BatchRunSendPlan& plan, + Type fragmentType) { + MaterializedClass& targetClass = state.classes[plan.destinationClass]; + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return targetClass.op->emitError("indexed batch run receive expects static ranked fragment type"); + + IndexedBatchRunValue indexedRun; + indexedRun.targetClass = targetClass.id; + indexedRun.sourceOp = run.front().peers.front().op; + indexedRun.resultIndex = plan.resultIndex; + indexedRun.fragmentType = rankedFragmentType; + indexedRun.messages = plan.messages; + indexedRun.slots.reserve(run.size()); + for (const MaterializationRunSlot& slot : run) { + PackedScalarRunSlot indexedSlot; + indexedSlot.keys = getMaterializationRunSlotOutputKeys(slot, plan.resultIndex); + indexedRun.slots.push_back(std::move(indexedSlot)); + } + + state.availableValues.recordIndexedBatchRun(std::move(indexedRun)); + return success(); +} + LogicalResult appendBatchRunReceives(MaterializerState& state, MaterializedClass& sourceClass, ArrayRef run, @@ -3683,24 +4151,10 @@ LogicalResult appendBatchRunReceives(MaterializerState& state, Type fragmentType, Location loc) { MaterializedClass& targetClass = state.classes[plan.destinationClass]; - size_t laneCount = sourceClass.cpus.size(); if (!targetClass.isBatch) return appendPackedScalarRunReceives(state, sourceClass, run, plan, fragmentType, loc); - - for (auto [slotIndex, slot] : llvm::enumerate(run)) { - SmallVector keys = getMaterializationRunSlotOutputKeys(slot, plan.resultIndex); - - ArrayRef channelIds = sliceChannelsForRunSlot(plan, slotIndex, laneCount); - ArrayRef sourceCoreIds = sliceSourcesForRunSlot(plan, slotIndex, laneCount); - ArrayRef targetCoreIds = sliceTargetsForRunSlot(plan, slotIndex, laneCount); - - Value received = appendReceive(state, targetClass, fragmentType, channelIds, sourceCoreIds, targetCoreIds, loc); - for (ProducerKey key : keys) - state.availableValues.record(key, targetClass.id, received); - } - - return success(); + return recordIndexedBatchRunReceives(state, run, plan, fragmentType); } LogicalResult materializeBatchClassRun(MaterializerState& state, @@ -3742,8 +4196,13 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); - FailureOr> produced = cloneBatchBodyForLane( - state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, slotIndex); + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + run.front().peers.front(), + sourceLane, + group.resultIndices, + CloneIndexingContext {.runSlotIndex = slotIndex, .projectionSlotIndex = slotIndex}); if (failed(produced)) return failure(); @@ -3755,18 +4214,18 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); } - - for (const BatchRunSendPlan& plan : sendPlans) { - if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) - return failure(); - - if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) - return failure(); - } return success(); }); if (failed(loop)) return failure(); + + for (const BatchRunSendPlan& plan : sendPlans) { + if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) + return failure(); + + if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) + return failure(); + } } return success(); @@ -3830,45 +4289,31 @@ FailureOr createReceiveConcatLoop(MaterializerState& state, Operation* insertionPoint, RankedTensorType concatType, RankedTensorType fragmentType, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, + const MessageVector& messages, Location loc) { - assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); - assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); - assert(!channelIds.empty() && "expected at least one receive"); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, anchor, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, static_cast(channelIds.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, anchor, 1); + assert(succeeded(messages.verify(anchor)) && "message metadata is inconsistent"); + assert(!messages.empty() && "expected at least one receive"); state.rewriter.setInsertionPoint(insertionPoint); Value init = tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult(); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {init}, - [&](OpBuilder&, Location, Value index, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value acc = iterArgs.front(); - Value channelId = createIndexedIndexValue(state, anchor, channelIds, index, loc); - Value sourceCoreId = createIndexedIndexValue(state, anchor, sourceCoreIds, index, loc); - Value targetCoreId = createIndexedIndexValue(state, anchor, targetCoreIds, index, loc); - - Value received = - SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); - Value firstOffset = scaleIndexByDim0Size(state, anchor, index, fragmentType.getDimSize(0), loc); - Value next = createDim0InsertSlice(state, loc, received, acc, firstOffset); - yielded.push_back(next); - return success(); - }); - if (failed(loop)) - return failure(); - return loop->results.front(); + return emitIndexedFragmentInsertLoop( + state, + anchor, + insertionPoint, + init, + static_cast(messages.size()), + [&](Value index) -> FailureOr { + Value channelId = createIndexedChannelId(state, anchor, messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, anchor, messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, anchor, messages, index, loc); + return SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); + }, + [&](Value index) -> FailureOr { + return scaleIndexByDim0Size(state, anchor, index, fragmentType.getDimSize(0), loc); + }, + loc); } void replaceHostUses(MaterializerState& state) {