diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 59b1ada..fb2ea0f 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -124,6 +124,49 @@ struct BatchRunSendPlan { SmallVector targetCoreIds; }; +struct ProjectedBatchInputKey { + Operation* consumerOp = nullptr; + unsigned inputIndex = 0; + + bool operator==(const ProjectedBatchInputKey& other) const { + return consumerOp == other.consumerOp && inputIndex == other.inputIndex; + } +}; + +struct ProjectedBatchInputKeyInfo { + static ProjectedBatchInputKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; + } + + static ProjectedBatchInputKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; + } + + static unsigned getHashValue(const ProjectedBatchInputKey& key) { + return llvm::hash_combine(key.consumerOp, key.inputIndex); + } + + static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { + return lhs == rhs; + } +}; + +struct ProjectedTransferDescriptor { + ProjectedBatchInputKey inputKey; + Operation* extractOp = nullptr; + + RankedTensorType fragmentType; + RankedTensorType payloadType; + unsigned fragmentsPerLane = 1; + SmallVector laneMajorSourceDim0Offsets; +}; + +struct ProjectedExtractReplacement { + Value payload; + RankedTensorType fragmentType; + unsigned fragmentsPerLane = 1; +}; + struct MaterializerState; class AvailableValueStore { @@ -158,6 +201,8 @@ struct MaterializerState { DenseSet materializedSlots; DenseMap, ProducerKeyInfo> producerDestClasses; + DenseMap, ProducerKeyInfo> projectedTransfers; + DenseMap> projectedExtractReplacements; AvailableValueStore availableValues; DenseMap hostReplacements; DenseSet oldComputeOps; @@ -1153,6 +1198,197 @@ LogicalResult collectProducerDestinations(MaterializerState& state) { return success(); } +bool isValueOffset(OpFoldResult offset, Value expected) { + auto value = dyn_cast(offset); + return value && value == expected; +} + +bool isStaticIndexAttr(OpFoldResult value, int64_t expected) { + auto attr = dyn_cast(value); + if (!attr) + return false; + + auto intAttr = dyn_cast(attr); + return intAttr && intAttr.getInt() == expected; +} + +std::optional matchSimpleLaneProjectedInput(SpatComputeBatch batch, unsigned inputIndex) { + std::optional inputArg = batch.getInputArgument(inputIndex); + std::optional laneArg = batch.getLaneArgument(); + if (!inputArg || !laneArg) + return std::nullopt; + + if (!inputArg->hasOneUse()) + return std::nullopt; + + Operation* user = *inputArg->getUsers().begin(); + auto extract = dyn_cast(user); + if (!extract || extract.getSource() != *inputArg) + return std::nullopt; + + auto inputType = dyn_cast(inputArg->getType()); + auto fragmentType = dyn_cast(extract.getResult().getType()); + if (!inputType || !fragmentType || !inputType.hasStaticShape() || !fragmentType.hasStaticShape()) + return std::nullopt; + + if (inputType.getRank() == 0 || inputType.getRank() != fragmentType.getRank()) + return std::nullopt; + + SmallVector offsets = extract.getMixedOffsets(); + SmallVector sizes = extract.getMixedSizes(); + SmallVector strides = extract.getMixedStrides(); + + if (offsets.size() != static_cast(inputType.getRank()) + || sizes.size() != static_cast(inputType.getRank()) + || strides.size() != static_cast(inputType.getRank())) + return std::nullopt; + + if (!isValueOffset(offsets.front(), *laneArg)) + return std::nullopt; + if (!isStaticIndexAttr(sizes.front(), 1) || !isStaticIndexAttr(strides.front(), 1)) + return std::nullopt; + + for (int64_t dim = 1; dim < inputType.getRank(); ++dim) { + if (!isStaticIndexAttr(offsets[dim], 0)) + return std::nullopt; + if (!isStaticIndexAttr(sizes[dim], inputType.getDimSize(dim))) + return std::nullopt; + if (!isStaticIndexAttr(strides[dim], 1)) + return std::nullopt; + } + + if (fragmentType.getDimSize(0) != 1) + return std::nullopt; + for (int64_t dim = 1; dim < inputType.getRank(); ++dim) + if (fragmentType.getDimSize(dim) != inputType.getDimSize(dim)) + return std::nullopt; + + return extract; +} + +LogicalResult collectProjectedTransfers(MaterializerState& state) { + struct PendingProjectedTransferDescriptor { + ProjectedBatchInputKey inputKey; + Operation* extractOp = nullptr; + RankedTensorType fragmentType; + SmallVector, 8> offsetsByLane; + bool invalid = false; + }; + + DenseMap, ProducerKeyInfo> pending; + + for (const ComputeInstance& consumer : state.schedule.dominanceOrderCompute) { + auto batch = dyn_cast(consumer.op); + if (!batch || consumer.laneCount != 1) + continue; + + auto cpuIt = state.schedule.computeToCpuMap.find(consumer); + if (cpuIt == state.schedule.computeToCpuMap.end()) + return consumer.op->emitError("projected transfer collection expected scheduled consumer"); + + ClassId targetClassId = state.cpuToClass.lookup(cpuIt->second); + MaterializedClass& targetClass = state.classes[targetClassId]; + if (!targetClass.isBatch) + continue; + + auto targetLaneIt = targetClass.cpuToLane.find(cpuIt->second); + if (targetLaneIt == targetClass.cpuToLane.end()) + return consumer.op->emitError("projected transfer collection could not recover target lane"); + + unsigned targetLane = targetLaneIt->second; + + for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) { + std::optional producer = getProducerKey(input, &consumer); + if (!producer) + continue; + + auto producerCpuIt = state.schedule.computeToCpuMap.find(producer->instance); + if (producerCpuIt == state.schedule.computeToCpuMap.end()) + continue; + + ClassId sourceClassId = state.cpuToClass.lookup(producerCpuIt->second); + if (sourceClassId == targetClassId) + continue; + + std::optional extract = + matchSimpleLaneProjectedInput(batch, static_cast(inputIndex)); + if (!extract) + continue; + + auto fragmentType = cast((*extract).getResult().getType()); + + PendingProjectedTransferDescriptor& descriptor = pending[*producer][targetClassId]; + if (descriptor.offsetsByLane.empty()) { + descriptor.inputKey = {batch.getOperation(), static_cast(inputIndex)}; + descriptor.extractOp = extract->getOperation(); + descriptor.fragmentType = fragmentType; + descriptor.offsetsByLane.resize(targetClass.cpus.size()); + } + + ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast(inputIndex)}; + if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != extract->getOperation() + || descriptor.fragmentType != fragmentType) { + descriptor.invalid = true; + continue; + } + + if (targetLane >= descriptor.offsetsByLane.size()) { + descriptor.invalid = true; + continue; + } + + descriptor.offsetsByLane[targetLane].push_back(static_cast(consumer.laneStart)); + } + } + + for (auto& producerEntry : pending) { + ProducerKey producer = producerEntry.first; + for (auto& classEntry : producerEntry.second) { + ClassId targetClassId = classEntry.first; + PendingProjectedTransferDescriptor& pendingDescriptor = classEntry.second; + + if (pendingDescriptor.invalid) + continue; + if (pendingDescriptor.offsetsByLane.empty()) + continue; + + unsigned fragmentsPerLane = pendingDescriptor.offsetsByLane.front().size(); + if (fragmentsPerLane == 0) + continue; + + bool uniform = true; + for (ArrayRef laneOffsets : pendingDescriptor.offsetsByLane) { + if (laneOffsets.size() != fragmentsPerLane) { + uniform = false; + break; + } + } + if (!uniform) + continue; + + SmallVector payloadShape(pendingDescriptor.fragmentType.getShape()); + payloadShape[0] *= static_cast(fragmentsPerLane); + RankedTensorType payloadType = RankedTensorType::get(payloadShape, + pendingDescriptor.fragmentType.getElementType(), + pendingDescriptor.fragmentType.getEncoding()); + + ProjectedTransferDescriptor descriptor; + descriptor.inputKey = pendingDescriptor.inputKey; + descriptor.extractOp = pendingDescriptor.extractOp; + descriptor.fragmentType = pendingDescriptor.fragmentType; + descriptor.payloadType = payloadType; + descriptor.fragmentsPerLane = fragmentsPerLane; + descriptor.laneMajorSourceDim0Offsets.reserve(pendingDescriptor.offsetsByLane.size() * fragmentsPerLane); + for (ArrayRef laneOffsets : pendingDescriptor.offsetsByLane) + llvm::append_range(descriptor.laneMajorSourceDim0Offsets, laneOffsets); + + state.projectedTransfers[producer][targetClassId] = std::move(descriptor); + } + } + + return success(); +} + SmallVector getOutputKeysForPeers(ArrayRef peers, size_t resultIndex) { SmallVector keys; keys.reserve(peers.size()); @@ -1237,6 +1473,111 @@ void appendScalarSendLoop(MaterializerState& state, SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); } +Value buildProjectedPackedPayload(MaterializerState& state, + Operation* anchor, + Value fullPayload, + const ProjectedTransferDescriptor& descriptor, + Value laneIndex, + Location loc) { + assert(descriptor.fragmentsPerLane > 1 && "use direct fragment path for single-fragment projection"); + + Value init = tensor::EmptyOp::create( + state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType()) + .getResult(); + + Value lowerBound = createIndexConstant(state, anchor, 0); + Value upperBound = createIndexConstant(state, anchor, descriptor.fragmentsPerLane); + Value step = createIndexConstant(state, anchor, 1); + + auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); + + Block* body = loop.getBody(); + if (!body->empty()) + if (auto yield = dyn_cast(body->back())) + state.rewriter.eraseOp(yield); + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPointToEnd(body); + + Value fragmentIndex = loop.getInductionVar(); + Value acc = body->getArgument(1); + + Value fragmentsPerLane = createIndexConstant(state, anchor, descriptor.fragmentsPerLane); + Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult(); + Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); + + Value sourceOffset = + createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc); + + Value fragment = + createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0)); + + Value next = createDim0InsertSlice(state, loc, fragment, acc, fragmentIndex); + scf::YieldOp::create(state.rewriter, loc, next); + + return loop.getResult(0); +} + +void appendProjectedScalarSendLoop(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const ProjectedTransferDescriptor& descriptor, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { + assert(!sourceClass.isBatch && "projected scalar send expects scalar source class"); + assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); + assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); + assert(channelIds.size() * descriptor.fragmentsPerLane == descriptor.laneMajorSourceDim0Offsets.size() + && "projected send lane count mismatch"); + + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + + if (channelIds.size() == 1) { + Value channelId = createIndexConstant(state, sourceClass.op, channelIds.front()); + Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds.front()); + Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds.front()); + Value laneIndex = createIndexConstant(state, sourceClass.op, 0); + Value sendPayload; + if (descriptor.fragmentsPerLane == 1) { + Value offset = createIndexConstant(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets.front()); + sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0)); + } + else { + sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, laneIndex, loc); + } + + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload); + return; + } + + Value lowerBound = createIndexConstant(state, sourceClass.op, 0); + Value upperBound = createIndexConstant(state, sourceClass.op, static_cast(channelIds.size())); + Value step = createIndexConstant(state, sourceClass.op, 1); + + auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPointToStart(loop.getBody()); + + Value index = loop.getInductionVar(); + Value channelId = createIndexedIndexValue(state, sourceClass.op, channelIds, index, loc); + Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, sourceCoreIds, index, loc); + Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, targetCoreIds, index, loc); + + Value sendPayload; + if (descriptor.fragmentsPerLane == 1) { + Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets, index, loc); + sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0)); + } + else { + sendPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, index, loc); + } + + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload); +} + void appendSend(MaterializerState& state, MaterializedClass& sourceClass, Value payload, @@ -1394,6 +1735,7 @@ SmallVector collectDestinationClassesForKeys(MaterializerState& stat SmallVector emitScalarSourceSends(MaterializerState& state, MaterializedClass& sourceClass, + ArrayRef keys, ArrayRef destinationClasses, Value payload, Location loc) { @@ -1401,25 +1743,44 @@ SmallVector emitScalarSourceSends(MaterializerState& int32_t sourceCpu = static_cast(sourceClass.cpus.front()); - size_t messageCount = 0; - for (ClassId destinationClass : destinationClasses) { - if (destinationClass == sourceClass.id) - continue; - - MaterializedClass& targetClass = state.classes[destinationClass]; - messageCount += targetClass.isBatch ? targetClass.cpus.size() : 1; - } - - SmallVector allChannelIds; - SmallVector allSourceCoreIds; - SmallVector allTargetCoreIds; - allChannelIds.reserve(messageCount); - allSourceCoreIds.reserve(messageCount); - allTargetCoreIds.reserve(messageCount); - SmallVector receivePlans; receivePlans.reserve(destinationClasses.size()); + const auto tryEmitProjected = [&](ClassId destinationClass, + const SmallVector& channelIds, + const SmallVector& sourceCoreIds, + const SmallVector& targetCoreIds) -> bool { + if (keys.size() != 1) + return false; + + MaterializedClass& targetClass = state.classes[destinationClass]; + if (!targetClass.isBatch) + return false; + + auto producerIt = state.projectedTransfers.find(keys.front()); + if (producerIt == state.projectedTransfers.end()) + return false; + + auto descriptorIt = producerIt->second.find(destinationClass); + if (descriptorIt == producerIt->second.end()) + return false; + + const ProjectedTransferDescriptor& descriptor = descriptorIt->second; + if (descriptor.laneMajorSourceDim0Offsets.size() + != targetClass.cpus.size() * static_cast(descriptor.fragmentsPerLane)) + return false; + + appendProjectedScalarSendLoop( + state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc); + + Value received = appendReceive( + state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc); + + state.projectedExtractReplacements[descriptor.extractOp][destinationClass] = + ProjectedExtractReplacement {received, descriptor.fragmentType, descriptor.fragmentsPerLane}; + return true; + }; + for (ClassId destinationClass : destinationClasses) { if (destinationClass == sourceClass.id) continue; @@ -1435,10 +1796,6 @@ SmallVector emitScalarSourceSends(MaterializerState& plan.channelIds.push_back(channelId); plan.sourceCoreIds.push_back(sourceCpu); plan.targetCoreIds.push_back(targetCpu); - - allChannelIds.push_back(channelId); - allSourceCoreIds.push_back(sourceCpu); - allTargetCoreIds.push_back(targetCpu); }; if (!targetClass.isBatch) @@ -1447,12 +1804,13 @@ SmallVector emitScalarSourceSends(MaterializerState& for (CpuId targetCpu : targetClass.cpus) appendMessage(static_cast(targetCpu)); + if (tryEmitProjected(destinationClass, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds)) + continue; + + appendSend(state, sourceClass, payload, plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc); receivePlans.push_back(std::move(plan)); } - if (!allChannelIds.empty()) - appendSend(state, sourceClass, payload, allChannelIds, allSourceCoreIds, allTargetCoreIds, loc); - return receivePlans; } @@ -1465,7 +1823,7 @@ LogicalResult emitScalarSourceCommunication( SmallVector destinationClasses = collectDestinationClassesForKeys(state, keys); SmallVector receivePlans = - emitScalarSourceSends(state, sourceClass, destinationClasses, payload, loc); + emitScalarSourceSends(state, sourceClass, keys, destinationClasses, payload, loc); for (const ScalarSourceReceivePlan& plan : receivePlans) { MaterializedClass& targetClass = state.classes[plan.targetClass]; @@ -1716,7 +2074,8 @@ FailureOr> cloneBatchBodyForLane(MaterializerState& state, MaterializedClass& targetClass, const ComputeInstance& instance, Value laneValue, - ArrayRef resultIndices); + ArrayRef resultIndices, + std::optional projectionSlotIndex = std::nullopt); FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, MaterializedClass& targetClass, @@ -1766,7 +2125,7 @@ FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); FailureOr> produced = - cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices); + cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex); if (failed(produced)) return failure(); @@ -1833,7 +2192,7 @@ FailureOr insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerSt Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); FailureOr> produced = - cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices); + cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices, loopIndex); if (failed(produced)) return failure(); @@ -2153,6 +2512,19 @@ FailureOr resolveInputValue(MaterializerState& state, return appendInput(state, targetClass, input); } +bool hasProjectedInputReplacement( + MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex, ClassId classId) { + std::optional extract = matchSimpleLaneProjectedInput(batch, inputIndex); + if (!extract) + return false; + + auto replacementIt = state.projectedExtractReplacements.find(extract->getOperation()); + if (replacementIt == state.projectedExtractReplacements.end()) + return false; + + return replacementIt->second.find(classId) != replacementIt->second.end(); +} + void mapWeights(MaterializerState& state, MaterializedClass& targetClass, const ComputeInstance& instance, @@ -2195,6 +2567,9 @@ LogicalResult mapInputs(MaterializerState& state, auto batch = cast(op); for (auto [index, input] : llvm::enumerate(batch.getInputs())) { + if (hasProjectedInputReplacement(state, batch, static_cast(index), targetClass.id)) + continue; + FailureOr mapped = resolveInputValue(state, targetClass, input, instance); if (failed(mapped)) return batch.emitOpError("failed to resolve materialized compute_batch input"); @@ -2262,6 +2637,35 @@ SmallVector collectBatchOutputFragmentTypes(SpatComputeBatch batch) { return types; } +std::optional lookupProjectedExtractReplacement(MaterializerState& state, + MaterializedClass& targetClass, + tensor::ExtractSliceOp extract) { + auto replacementIt = state.projectedExtractReplacements.find(extract.getOperation()); + if (replacementIt == state.projectedExtractReplacements.end()) + return std::nullopt; + + auto classIt = replacementIt->second.find(targetClass.id); + if (classIt == replacementIt->second.end()) + return std::nullopt; + + return classIt->second; +} + +FailureOr materializeProjectedExtractReplacement(MaterializerState& state, + MaterializedClass& targetClass, + tensor::ExtractSliceOp extract, + const ProjectedExtractReplacement& replacement, + std::optional projectionSlotIndex) { + if (replacement.fragmentsPerLane == 1) + return replacement.payload; + + if (!projectionSlotIndex) + return targetClass.op->emitError("packed projected extract replacement requires a projection slot index"); + + return createDim0ExtractSlice( + state, extract.getLoc(), replacement.payload, *projectionSlotIndex, replacement.fragmentType.getDimSize(0)); +} + FailureOr> cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef peers) { assert(!peers.empty() && "expected at least one peer instance"); @@ -2301,6 +2705,19 @@ cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, Arra Block& sourceBlock = getComputeInstanceTemplateBlock(instance); for (Operation& op : sourceBlock.without_terminator()) { + if (auto extract = dyn_cast(&op)) { + if (std::optional replacement = + lookupProjectedExtractReplacement(state, targetClass, extract)) { + FailureOr projected = + materializeProjectedExtractReplacement(state, targetClass, extract, *replacement, std::nullopt); + if (failed(projected)) + return failure(); + + mapper.map(extract.getResult(), *projected); + continue; + } + } + Operation* cloned = state.rewriter.clone(op, mapper); for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) mapper.map(oldResult, newResult); @@ -2503,7 +2920,7 @@ LogicalResult emitPackedRunFanout(MaterializerState& state, assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class"); SmallVector receivePlans = - emitScalarSourceSends(state, sourceClass, destinationClasses, packed, loc); + emitScalarSourceSends(state, sourceClass, keys, destinationClasses, packed, loc); for (const ScalarSourceReceivePlan& plan : receivePlans) { MaterializedClass& targetClass = state.classes[plan.targetClass]; @@ -2522,7 +2939,8 @@ FailureOr> cloneBatchBodyForLane(MaterializerState& state, MaterializedClass& targetClass, const ComputeInstance& instance, Value laneValue, - ArrayRef resultIndices) { + ArrayRef resultIndices, + std::optional projectionSlotIndex) { auto batch = dyn_cast(instance.op); if (!batch) return failure(); @@ -2544,6 +2962,19 @@ FailureOr> cloneBatchBodyForLane(MaterializerState& state, Block& sourceBlock = getComputeInstanceTemplateBlock(instance); for (Operation& op : sourceBlock.without_terminator()) { + if (auto extract = dyn_cast(&op)) { + if (std::optional replacement = + lookupProjectedExtractReplacement(state, targetClass, extract)) { + FailureOr projected = materializeProjectedExtractReplacement( + state, targetClass, extract, *replacement, projectionSlotIndex); + if (failed(projected)) + return failure(); + + mapper.map(extract.getResult(), *projected); + continue; + } + } + Operation* cloned = state.rewriter.clone(op, mapper); for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) mapper.map(oldResult, newResult); @@ -2637,7 +3068,7 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta Value sourceLane = createIndexedIndexValue(state, targetClass.op, laneStarts, loopIndex, loc); FailureOr> produced = - cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices); + cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, loopIndex); if (failed(produced)) return failure(); @@ -3127,7 +3558,7 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); FailureOr> produced = - cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices); + cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices, slotIndex); if (failed(produced)) return failure(); @@ -3286,6 +3717,8 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch createEmptyMaterializedOps(state); if (failed(collectProducerDestinations(state))) return failure(); + if (failed(collectProjectedTransfers(state))) + return failure(); for (const ComputeInstance& instance : schedule.dominanceOrderCompute) if (failed(materializeInstanceSlot(state, instance)))