diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 565d514..ed34d6b 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -2350,20 +2350,29 @@ LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state, if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) continue; - SmallVector runRanges; - runRanges.reserve(run.slots.size()); + SmallVector runRanges; for (const PackedScalarRunSlot& slot : run.slots) { - std::optional slotKey = getContiguousProducerKeyForKeys(slot.keys); - if (!slotKey) - return failure(); + for (ProducerKey fragmentKey : slot.keys) { + if (fragmentKey.instance.op != key.instance.op || fragmentKey.resultIndex != key.resultIndex) + return failure(); - if (wholeBatchRangeOverlaps(plan.coveredRanges, slotKey->instance.laneStart, slotKey->instance.laneCount)) - return failure(); + if (fragmentKey.instance.laneCount == 0) + return failure(); - runRanges.push_back({slotKey->instance.laneStart, slotKey->instance.laneCount}); + if (wholeBatchRangeOverlaps(plan.coveredRanges, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) + return failure(); + + if (wholeBatchRangeOverlaps(runRanges, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) + return failure(); + + runRanges.push_back({fragmentKey.instance.laneStart, fragmentKey.instance.laneCount}); + } } + if (runRanges.empty()) + continue; + plan.packedRuns.push_back(&run); for (WholeBatchAssemblyRange range : runRanges)