Fix vgg16_depth05 bug

This commit is contained in:
ilgeco
2026-06-30 14:54:33 +02:00
parent 94c96195b9
commit f5e1c2e706
@@ -3767,9 +3767,6 @@ FailureOr<ScalarSourceFanoutPlan> buildScalarSourceFanoutPlan(MaterializerState&
if (*descriptor) {
const ProjectedTransferDescriptor& projectedDescriptor = **descriptor;
if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType())
return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type");
receivePlan.receiveType = projectedDescriptor.payloadType;
receivePlan.projectedExtractOp = projectedDescriptor.extractOp;
receivePlan.projectedLayout = projectedDescriptor.layout;
@@ -7617,6 +7614,16 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns
return success();
}
/// Returns the core id carried by `value` when `mlir::getConstantIntValue`
/// can resolve it to a constant integer, and `std::nullopt` otherwise.
std::optional<int64_t> getStaticCommunicationCore(Value value) {
  // The helper already yields an empty optional for non-constant values, so
  // its result can be forwarded unchanged.
  return mlir::getConstantIntValue(value);
}
/// Returns true when `op` itself is a channel send or a channel receive.
/// Only the operation passed in is inspected; nested regions are not walked.
bool isTopLevelCommunicationOp(Operation* op) {
  if (isa<SpatChannelSendOp>(op))
    return true;
  return isa<SpatChannelReceiveOp>(op);
}
bool valueMayEvaluateToCore(Value value, int64_t coreId) {
if (std::optional<int64_t> constant = getConstantIndexValue(value))
return *constant == coreId;
@@ -7645,7 +7652,7 @@ bool valueMayEvaluateToCore(Value value, int64_t coreId) {
return false;
for (int64_t iteration = *lower; iteration < *upper; iteration += *step) {
FailureOr<int64_t> evaluated = evaluateSingleResultAffineMap(map, ArrayRef<int64_t> {iteration});
FailureOr<int64_t> evaluated = evaluateSingleResultAffineMap(map, ArrayRef<int64_t>{iteration});
if (succeeded(evaluated) && *evaluated == coreId)
return true;
}
@@ -7653,62 +7660,186 @@ bool valueMayEvaluateToCore(Value value, int64_t coreId) {
return false;
}
bool operationContainsReceiveFromPeer(Operation& op, int64_t localCore, int64_t peerCore, Type payloadType) {
bool operationContainsCommunication(Operation& op) {
bool found = false;
op.walk([&](SpatChannelReceiveOp receive) {
if (receive.getOutput().getType() != payloadType)
return;
if (!valueMayEvaluateToCore(receive.getTargetCoreId(), localCore))
return;
if (!valueMayEvaluateToCore(receive.getSourceCoreId(), peerCore))
return;
WalkResult walkResult = op.walk([&](Operation* nestedOp) -> WalkResult {
if (!isa<SpatChannelSendOp, SpatChannelReceiveOp>(nestedOp))
return WalkResult::advance();
found = true;
return WalkResult::interrupt();
});
(void) walkResult;
return found;
}
/// Returns true when `op`, or any operation nested inside it, is a
/// `SpatChannelSendOp`.
bool operationContainsSend(Operation& op) {
  // The walk is interrupted on the first send encountered, so an interrupted
  // walk is exactly the "found one" condition.
  return op.walk([](SpatChannelSendOp) -> WalkResult { return WalkResult::interrupt(); }).wasInterrupted();
}
/// Returns true when `op`, or any operation nested inside it, is a
/// `SpatChannelReceiveOp` that may be the reverse-direction counterpart of a
/// transfer going from `sourceCoreId` to `targetCoreId`.
///
/// A receive qualifies when all of the following hold:
///   - its result type equals `payloadType`;
///   - its source core may evaluate to `targetCoreId`;
///   - its target core may evaluate to `sourceCoreId`.
///
/// Note the deliberate swap: the core ids are named from the point of view of
/// the send being matched, so the receive of interest flows the opposite way.
bool operationContainsReceiveFromPeer(Operation& op, int64_t sourceCoreId, int64_t targetCoreId, Type payloadType) {
  // Checks are short-circuited in the order: type, source core, target core.
  auto isReverseReceive = [&](SpatChannelReceiveOp receive) {
    return receive.getType() == payloadType
        && valueMayEvaluateToCore(receive.getSourceCoreId(), targetCoreId)
        && valueMayEvaluateToCore(receive.getTargetCoreId(), sourceCoreId);
  };

  // Stop walking as soon as the first qualifying receive is seen.
  return op
      .walk([&](SpatChannelReceiveOp receive) -> WalkResult {
        return isReverseReceive(receive) ? WalkResult::interrupt() : WalkResult::advance();
      })
      .wasInterrupted();
}
/// Returns true when `dependency` is `value` itself or is reachable from
/// `value` by repeatedly stepping from a value to the operands of its defining
/// operation.
///
/// Values without a defining operation (e.g. block arguments) end the search
/// along that path. Operations nested in the regions of a defining op are not
/// inspected here; only its direct operands are followed.
bool valueDependsOn(Value value, Value dependency) {
  DenseSet<Value> seen;
  SmallVector<Value, 8> pending;
  pending.push_back(value);

  while (!pending.empty()) {
    Value candidate = pending.pop_back_val();
    if (candidate == dependency)
      return true;

    // Skip values whose producers have already been expanded.
    if (!seen.insert(candidate).second)
      continue;

    if (Operation* producer = candidate.getDefiningOp())
      llvm::append_range(pending, producer->getOperands());
  }

  return false;
}
/// Returns true when `op`, or any operation nested inside it, has at least one
/// operand for which `valueDependsOn(operand, dependency)` holds.
bool opDependsOnValue(Operation& op, Value dependency) {
  auto reachesDependency = [&](Value operand) { return valueDependsOn(operand, dependency); };

  // Interrupt the walk at the first operation with a dependent operand.
  return op
      .walk([&](Operation* nestedOp) -> WalkResult {
        return llvm::any_of(nestedOp->getOperands(), reachesDependency) ? WalkResult::interrupt()
                                                                        : WalkResult::advance();
      })
      .wasInterrupted();
}
LogicalResult orderLowerCoreScalarSendsAfterMatchingReceives(MaterializerState& state) {
for (MaterializedClass& materializedClass : state.classes) {
if (materializedClass.isBatch || materializedClass.cpus.empty())
if (materializedClass.isBatch)
continue;
int64_t localCore = static_cast<int64_t>(materializedClass.cpus.front());
Block* body = materializedClass.body;
if (!body)
continue;
Block& body = *materializedClass.body;
SmallVector<Operation*, 16> topLevelOps;
for (Operation& op : body.without_terminator())
topLevelOps.push_back(&op);
bool changed = true;
while (changed) {
changed = false;
for (Operation& op : llvm::make_early_inc_range(*body)) {
if (&op == body->getTerminator())
break;
for (Operation* op : topLevelOps) {
auto send = dyn_cast<SpatChannelSendOp>(op);
if (!send || send->getBlock() != &body)
continue;
auto send = dyn_cast<SpatChannelSendOp>(&op);
if (!send)
continue;
std::optional<int64_t> sourceCoreId = getStaticCommunicationCore(send.getSourceCoreId());
std::optional<int64_t> targetCoreId = getStaticCommunicationCore(send.getTargetCoreId());
if (!sourceCoreId || !targetCoreId || *sourceCoreId >= *targetCoreId)
continue;
std::optional<int64_t> sourceCore = getConstantIndexValue(send.getSourceCoreId());
std::optional<int64_t> targetCore = getConstantIndexValue(send.getTargetCoreId());
if (!sourceCore || !targetCore || *sourceCore != localCore || *sourceCore >= *targetCore)
continue;
Operation* matchingReceiveContainer = nullptr;
for (Operation* candidate = op.getNextNode(); candidate && candidate != body->getTerminator();
candidate = candidate->getNextNode()) {
if (operationContainsReceiveFromPeer(*candidate, localCore, *targetCore, send.getInput().getType())) {
matchingReceiveContainer = candidate;
Operation* anchor = nullptr;
for (Operation* cursor = send->getNextNode(); cursor && cursor != body.getTerminator(); cursor = cursor->getNextNode()) {
if (isTopLevelCommunicationOp(cursor)) {
auto receive = dyn_cast<SpatChannelReceiveOp>(cursor);
if (!receive)
break;
}
std::optional<int64_t> receiveSourceCoreId = getStaticCommunicationCore(receive.getSourceCoreId());
std::optional<int64_t> receiveTargetCoreId = getStaticCommunicationCore(receive.getTargetCoreId());
if (!receiveSourceCoreId || !receiveTargetCoreId)
break;
if (*receiveSourceCoreId != *targetCoreId || *receiveTargetCoreId != *sourceCoreId)
break;
if (receive.getType() != send.getInput().getType())
break;
anchor = receive;
break;
}
if (!matchingReceiveContainer)
if (!operationContainsCommunication(*cursor))
continue;
if (operationContainsSend(*cursor))
break;
if (!operationContainsReceiveFromPeer(*cursor, *sourceCoreId, *targetCoreId, send.getInput().getType()))
break;
op.moveAfter(matchingReceiveContainer);
changed = true;
anchor = cursor;
break;
}
if (!anchor)
continue;
send->moveAfter(anchor);
}
for (Operation* op : topLevelOps) {
auto receive = dyn_cast<SpatChannelReceiveOp>(op);
if (!receive || receive->getBlock() != &body)
continue;
std::optional<int64_t> sourceCoreId = getStaticCommunicationCore(receive.getSourceCoreId());
std::optional<int64_t> targetCoreId = getStaticCommunicationCore(receive.getTargetCoreId());
if (!sourceCoreId || !targetCoreId || *targetCoreId >= *sourceCoreId)
continue;
Operation* anchor = nullptr;
for (Operation* cursor = receive->getNextNode(); cursor && cursor != body.getTerminator(); cursor = cursor->getNextNode()) {
if (!isTopLevelCommunicationOp(cursor)) {
if (opDependsOnValue(*cursor, receive.getOutput()))
break;
continue;
}
auto send = dyn_cast<SpatChannelSendOp>(cursor);
if (!send || send.getInput().getType() != receive.getType())
break;
std::optional<int64_t> sendSourceCoreId = getStaticCommunicationCore(send.getSourceCoreId());
std::optional<int64_t> sendTargetCoreId = getStaticCommunicationCore(send.getTargetCoreId());
if (!sendSourceCoreId || !sendTargetCoreId)
break;
if (*sendSourceCoreId != *targetCoreId || *sendTargetCoreId != *sourceCoreId)
break;
bool hasInterveningUse = false;
for (Operation* between = receive->getNextNode(); between; between = between->getNextNode()) {
if (opDependsOnValue(*between, receive.getOutput())) {
hasInterveningUse = true;
break;
}
if (between == cursor)
break;
}
if (hasInterveningUse)
continue;
anchor = cursor;
break;
}
if (!anchor)
continue;
receive->moveAfter(anchor);
}
}