diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 35ee2be..3996742 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -3767,9 +3767,6 @@ FailureOr buildScalarSourceFanoutPlan(MaterializerState& if (*descriptor) { const ProjectedTransferDescriptor& projectedDescriptor = **descriptor; - if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) - return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type"); - receivePlan.receiveType = projectedDescriptor.payloadType; receivePlan.projectedExtractOp = projectedDescriptor.extractOp; receivePlan.projectedLayout = projectedDescriptor.layout; @@ -7617,6 +7614,16 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns return success(); } +std::optional getStaticCommunicationCore(Value value) { + if (std::optional coreId = mlir::getConstantIntValue(value)) + return coreId; + return std::nullopt; +} + +bool isTopLevelCommunicationOp(Operation* op) { + return isa(op); +} + bool valueMayEvaluateToCore(Value value, int64_t coreId) { if (std::optional constant = getConstantIndexValue(value)) return *constant == coreId; @@ -7645,7 +7652,7 @@ bool valueMayEvaluateToCore(Value value, int64_t coreId) { return false; for (int64_t iteration = *lower; iteration < *upper; iteration += *step) { - FailureOr evaluated = evaluateSingleResultAffineMap(map, ArrayRef {iteration}); + FailureOr evaluated = evaluateSingleResultAffineMap(map, ArrayRef{iteration}); if (succeeded(evaluated) && *evaluated == coreId) return true; } @@ -7653,62 +7660,186 @@ bool valueMayEvaluateToCore(Value value, int64_t coreId) { return false; } -bool operationContainsReceiveFromPeer(Operation& op, int64_t localCore, int64_t peerCore, Type payloadType) { +bool operationContainsCommunication(Operation& op) { bool found = false; - op.walk([&](SpatChannelReceiveOp receive) { - if (receive.getOutput().getType() != payloadType) - return; - if (!valueMayEvaluateToCore(receive.getTargetCoreId(), localCore)) - return; - if (!valueMayEvaluateToCore(receive.getSourceCoreId(), peerCore)) - return; + WalkResult walkResult = op.walk([&](Operation* nestedOp) -> WalkResult { + if (!isa(nestedOp)) + return WalkResult::advance(); found = true; + return WalkResult::interrupt(); }); + (void) walkResult; + return found; +} + +bool operationContainsSend(Operation& op) { + bool found = false; + WalkResult walkResult = op.walk([&](SpatChannelSendOp) -> WalkResult { + found = true; + return WalkResult::interrupt(); + }); + (void) walkResult; + return found; +} + +bool operationContainsReceiveFromPeer(Operation& op, int64_t sourceCoreId, int64_t targetCoreId, Type payloadType) { + bool found = false; + WalkResult walkResult = op.walk([&](SpatChannelReceiveOp receive) -> WalkResult { + if (receive.getType() != payloadType) + return WalkResult::advance(); + if (!valueMayEvaluateToCore(receive.getSourceCoreId(), targetCoreId)) + return WalkResult::advance(); + if (!valueMayEvaluateToCore(receive.getTargetCoreId(), sourceCoreId)) + return WalkResult::advance(); + + found = true; + return WalkResult::interrupt(); + }); + (void) walkResult; + return found; +} + +bool valueDependsOn(Value value, Value dependency) { + if (value == dependency) + return true; + + SmallVector worklist {value}; + DenseSet visited; + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + if (!visited.insert(current).second) + continue; + if (current == dependency) + return true; + + Operation* definingOp = current.getDefiningOp(); + if (!definingOp) + continue; + llvm::append_range(worklist, definingOp->getOperands()); + } + + return false; +} + +bool opDependsOnValue(Operation& op, Value dependency) { + bool found = false; + WalkResult walkResult = op.walk([&](Operation* nestedOp) -> WalkResult { + for (Value operand : nestedOp->getOperands()) { + if (!valueDependsOn(operand, dependency)) + continue; + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + (void) walkResult; return found; } LogicalResult orderLowerCoreScalarSendsAfterMatchingReceives(MaterializerState& state) { for (MaterializedClass& materializedClass : state.classes) { - if (materializedClass.isBatch || materializedClass.cpus.empty()) + if (materializedClass.isBatch) continue; - int64_t localCore = static_cast(materializedClass.cpus.front()); - Block* body = materializedClass.body; - if (!body) - continue; + Block& body = *materializedClass.body; + SmallVector topLevelOps; + for (Operation& op : body.without_terminator()) + topLevelOps.push_back(&op); - bool changed = true; - while (changed) { - changed = false; - for (Operation& op : llvm::make_early_inc_range(*body)) { - if (&op == body->getTerminator()) - break; + for (Operation* op : topLevelOps) { + auto send = dyn_cast(op); + if (!send || send->getBlock() != &body) + continue; - auto send = dyn_cast(&op); - if (!send) - continue; + std::optional sourceCoreId = getStaticCommunicationCore(send.getSourceCoreId()); + std::optional targetCoreId = getStaticCommunicationCore(send.getTargetCoreId()); + if (!sourceCoreId || !targetCoreId || *sourceCoreId >= *targetCoreId) + continue; - std::optional sourceCore = getConstantIndexValue(send.getSourceCoreId()); - std::optional targetCore = getConstantIndexValue(send.getTargetCoreId()); - if (!sourceCore || !targetCore || *sourceCore != localCore || *sourceCore >= *targetCore) - continue; - - Operation* matchingReceiveContainer = nullptr; - for (Operation* candidate = op.getNextNode(); candidate && candidate != body->getTerminator(); - candidate = candidate->getNextNode()) { - if (operationContainsReceiveFromPeer(*candidate, localCore, *targetCore, send.getInput().getType())) { - matchingReceiveContainer = candidate; + Operation* anchor = nullptr; + for (Operation* cursor = send->getNextNode(); cursor && cursor != body.getTerminator(); cursor = cursor->getNextNode()) { + if (isTopLevelCommunicationOp(cursor)) { + auto receive = dyn_cast(cursor); + if (!receive) break; - } + + std::optional receiveSourceCoreId = getStaticCommunicationCore(receive.getSourceCoreId()); + std::optional receiveTargetCoreId = getStaticCommunicationCore(receive.getTargetCoreId()); + if (!receiveSourceCoreId || !receiveTargetCoreId) + break; + if (*receiveSourceCoreId != *targetCoreId || *receiveTargetCoreId != *sourceCoreId) + break; + if (receive.getType() != send.getInput().getType()) + break; + + anchor = receive; + break; } - if (!matchingReceiveContainer) + if (!operationContainsCommunication(*cursor)) continue; + if (operationContainsSend(*cursor)) + break; + if (!operationContainsReceiveFromPeer(*cursor, *sourceCoreId, *targetCoreId, send.getInput().getType())) + break; - op.moveAfter(matchingReceiveContainer); - changed = true; + anchor = cursor; break; } + + if (!anchor) + continue; + send->moveAfter(anchor); + } + + for (Operation* op : topLevelOps) { + auto receive = dyn_cast(op); + if (!receive || receive->getBlock() != &body) + continue; + + std::optional sourceCoreId = getStaticCommunicationCore(receive.getSourceCoreId()); + std::optional targetCoreId = getStaticCommunicationCore(receive.getTargetCoreId()); + if (!sourceCoreId || !targetCoreId || *targetCoreId >= *sourceCoreId) + continue; + + Operation* anchor = nullptr; + for (Operation* cursor = receive->getNextNode(); cursor && cursor != body.getTerminator(); cursor = cursor->getNextNode()) { + if (!isTopLevelCommunicationOp(cursor)) { + if (opDependsOnValue(*cursor, receive.getOutput())) + break; + continue; + } + + auto send = dyn_cast(cursor); + if (!send || send.getInput().getType() != receive.getType()) + break; + + std::optional sendSourceCoreId = getStaticCommunicationCore(send.getSourceCoreId()); + std::optional sendTargetCoreId = getStaticCommunicationCore(send.getTargetCoreId()); + if (!sendSourceCoreId || !sendTargetCoreId) + break; + if (*sendSourceCoreId != *targetCoreId || *sendTargetCoreId != *sourceCoreId) + break; + + bool hasInterveningUse = false; + for (Operation* between = receive->getNextNode(); between; between = between->getNextNode()) { + if (opDependsOnValue(*between, receive.getOutput())) { + hasInterveningUse = true; + break; + } + if (between == cursor) + break; + } + if (hasInterveningUse) + continue; + + anchor = cursor; + break; + } + + if (!anchor) + continue; + receive->moveAfter(anchor); } }