diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index fb2ea0f..565d514 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -146,9 +146,7 @@ struct ProjectedBatchInputKeyInfo { return llvm::hash_combine(key.consumerOp, key.inputIndex); } - static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { - return lhs == rhs; - } + static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; } }; struct ProjectedTransferDescriptor { @@ -201,6 +199,7 @@ struct MaterializerState { DenseSet materializedSlots; DenseMap, ProducerKeyInfo> producerDestClasses; + DenseMap, ProducerKeyInfo> sameClassConsumers; DenseMap, ProducerKeyInfo> projectedTransfers; DenseMap> projectedExtractReplacements; AvailableValueStore availableValues; @@ -1187,8 +1186,10 @@ LogicalResult collectProducerDestinations(MaterializerState& state) { return consumer.op->emitError("schedule materialization found an input produced by an unscheduled compute"); ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second); - if (sourceClass == targetClass) + if (sourceClass == targetClass) { + state.sameClassConsumers[producerKey].insert(targetClass); continue; + } appendDestinationClass(state, producerKey, targetClass); } @@ -1342,7 +1343,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { } for (auto& producerEntry : pending) { - ProducerKey producer = producerEntry.first; + ProducerKey producer = producerEntry.first; for (auto& classEntry : producerEntry.second) { ClassId targetClassId = classEntry.first; PendingProjectedTransferDescriptor& pendingDescriptor = classEntry.second; @@ -1368,9 +1369,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { SmallVector payloadShape(pendingDescriptor.fragmentType.getShape()); payloadShape[0] *= static_cast(fragmentsPerLane); - RankedTensorType payloadType = RankedTensorType::get(payloadShape, - pendingDescriptor.fragmentType.getElementType(), - pendingDescriptor.fragmentType.getEncoding()); + RankedTensorType payloadType = RankedTensorType::get( + payloadShape, pendingDescriptor.fragmentType.getElementType(), pendingDescriptor.fragmentType.getEncoding()); ProjectedTransferDescriptor descriptor; descriptor.inputKey = pendingDescriptor.inputKey; @@ -1506,11 +1506,9 @@ Value buildProjectedPackedPayload(MaterializerState& state, Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult(); Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); - Value sourceOffset = - createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc); + Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorSourceDim0Offsets, flatIndex, loc); - Value fragment = - createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0)); + Value fragment = createDim0ExtractSlice(state, loc, fullPayload, sourceOffset, descriptor.fragmentType.getDimSize(0)); Value next = createDim0InsertSlice(state, loc, fragment, acc, fragmentIndex); scf::YieldOp::create(state.rewriter, loc, next); @@ -1773,8 +1771,8 @@ SmallVector emitScalarSourceSends(MaterializerState& appendProjectedScalarSendLoop( state, sourceClass, payload, descriptor, channelIds, sourceCoreIds, targetCoreIds, loc); - Value received = appendReceive( - state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = + appendReceive(state, targetClass, descriptor.payloadType, channelIds, sourceCoreIds, targetCoreIds, loc); state.projectedExtractReplacements[descriptor.extractOp][destinationClass] = ProjectedExtractReplacement {received, descriptor.fragmentType, descriptor.fragmentsPerLane}; @@ -2512,8 +2510,10 @@ FailureOr resolveInputValue(MaterializerState& state, return appendInput(state, targetClass, input); } -bool hasProjectedInputReplacement( - MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex, ClassId classId) { +bool hasProjectedInputReplacement(MaterializerState& state, + SpatComputeBatch batch, + unsigned inputIndex, + ClassId classId) { std::optional extract = matchSimpleLaneProjectedInput(batch, inputIndex); if (!extract) return false; @@ -2965,8 +2965,8 @@ FailureOr> cloneBatchBodyForLane(MaterializerState& state, if (auto extract = dyn_cast(&op)) { if (std::optional replacement = lookupProjectedExtractReplacement(state, targetClass, extract)) { - FailureOr projected = materializeProjectedExtractReplacement( - state, targetClass, extract, *replacement, projectionSlotIndex); + FailureOr projected = + materializeProjectedExtractReplacement(state, targetClass, extract, *replacement, projectionSlotIndex); if (failed(projected)) return failure(); @@ -3288,26 +3288,8 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, } bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) { - for (const ComputeInstance& consumer : state.schedule.dominanceOrderCompute) { - auto cpuIt = state.schedule.computeToCpuMap.find(consumer); - if (cpuIt == state.schedule.computeToCpuMap.end()) - continue; - - if (state.cpuToClass.lookup(cpuIt->second) != classId) - continue; - - for (Value input : getComputeInstanceInputs(consumer)) { - std::optional producer = getProducerKey(input, &consumer); - if (!producer) - continue; - - for (ProducerKey expanded : expandWholeBatchProducerKey(*producer)) - if (expanded == producerKey) - return true; - } - } - - return false; + auto it = state.sameClassConsumers.find(producerKey); + return it != state.sameClassConsumers.end() && it->second.contains(classId); } bool canCompactBatchClassRun(MaterializerState& state,