Fix vgg16_depth05 bug
This commit is contained in:
+171
-40
@@ -3767,9 +3767,6 @@ FailureOr<ScalarSourceFanoutPlan> buildScalarSourceFanoutPlan(MaterializerState&
|
||||
if (*descriptor) {
|
||||
const ProjectedTransferDescriptor& projectedDescriptor = **descriptor;
|
||||
|
||||
if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType())
|
||||
return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type");
|
||||
|
||||
receivePlan.receiveType = projectedDescriptor.payloadType;
|
||||
receivePlan.projectedExtractOp = projectedDescriptor.extractOp;
|
||||
receivePlan.projectedLayout = projectedDescriptor.layout;
|
||||
@@ -7617,6 +7614,16 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Resolves a communication core operand to a compile-time integer.
///
/// Returns whatever mlir::getConstantIntValue reports for `value`: the constant
/// when one is found, std::nullopt otherwise.
std::optional<int64_t> getStaticCommunicationCore(Value value) {
  return mlir::getConstantIntValue(value);
}
|
||||
|
||||
/// Returns true when `op` itself is a channel send or a channel receive.
///
/// Only the kind of `op` is inspected; nested operations are not visited and the
/// position of `op` in its parent is not checked here.
bool isTopLevelCommunicationOp(Operation* op) {
  return isa<SpatChannelSendOp>(op) || isa<SpatChannelReceiveOp>(op);
}
|
||||
|
||||
bool valueMayEvaluateToCore(Value value, int64_t coreId) {
|
||||
if (std::optional<int64_t> constant = getConstantIndexValue(value))
|
||||
return *constant == coreId;
|
||||
@@ -7645,7 +7652,7 @@ bool valueMayEvaluateToCore(Value value, int64_t coreId) {
|
||||
return false;
|
||||
|
||||
for (int64_t iteration = *lower; iteration < *upper; iteration += *step) {
|
||||
FailureOr<int64_t> evaluated = evaluateSingleResultAffineMap(map, ArrayRef<int64_t> {iteration});
|
||||
FailureOr<int64_t> evaluated = evaluateSingleResultAffineMap(map, ArrayRef<int64_t>{iteration});
|
||||
if (succeeded(evaluated) && *evaluated == coreId)
|
||||
return true;
|
||||
}
|
||||
@@ -7653,62 +7660,186 @@ bool valueMayEvaluateToCore(Value value, int64_t coreId) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool operationContainsReceiveFromPeer(Operation& op, int64_t localCore, int64_t peerCore, Type payloadType) {
|
||||
bool operationContainsCommunication(Operation& op) {
|
||||
bool found = false;
|
||||
op.walk([&](SpatChannelReceiveOp receive) {
|
||||
if (receive.getOutput().getType() != payloadType)
|
||||
return;
|
||||
if (!valueMayEvaluateToCore(receive.getTargetCoreId(), localCore))
|
||||
return;
|
||||
if (!valueMayEvaluateToCore(receive.getSourceCoreId(), peerCore))
|
||||
return;
|
||||
WalkResult walkResult = op.walk([&](Operation* nestedOp) -> WalkResult {
|
||||
if (!isa<SpatChannelSendOp, SpatChannelReceiveOp>(nestedOp))
|
||||
return WalkResult::advance();
|
||||
found = true;
|
||||
return WalkResult::interrupt();
|
||||
});
|
||||
(void) walkResult;
|
||||
return found;
|
||||
}
|
||||
|
||||
/// Returns true when the walk rooted at `op` reaches at least one
/// SpatChannelSendOp.
///
/// The walk is interrupted on the first send found, so the interruption state of
/// the walk is exactly the answer.
bool operationContainsSend(Operation& op) {
  return op.walk([](SpatChannelSendOp) -> WalkResult { return WalkResult::interrupt(); }).wasInterrupted();
}
|
||||
|
||||
/// Returns true when the walk rooted at `op` reaches a SpatChannelReceiveOp that
/// mirrors a transfer going from `sourceCoreId` to `targetCoreId`.
///
/// A receive matches when all of the following hold, checked in this order:
///   1. `receive.getType()` equals `payloadType`;
///   2. its source core may evaluate to `targetCoreId`;
///   3. its target core may evaluate to `sourceCoreId`.
/// Note the deliberate swap in (2) and (3): the receive being looked for travels in
/// the opposite direction of the (source, target) pair passed in.
bool operationContainsReceiveFromPeer(Operation& op, int64_t sourceCoreId, int64_t targetCoreId, Type payloadType) {
  // Short-circuit evaluation keeps the type comparison ahead of the core checks.
  auto mirrorsTransfer = [&](SpatChannelReceiveOp receive) {
    return receive.getType() == payloadType
        && valueMayEvaluateToCore(receive.getSourceCoreId(), targetCoreId)
        && valueMayEvaluateToCore(receive.getTargetCoreId(), sourceCoreId);
  };

  // Stop at the first matching receive; an interrupted walk means one was found.
  return op
      .walk([&](SpatChannelReceiveOp receive) -> WalkResult {
        return mirrorsTransfer(receive) ? WalkResult::interrupt() : WalkResult::advance();
      })
      .wasInterrupted();
}
|
||||
|
||||
/// Returns true when `dependency` is `value` itself or is reachable from `value`
/// by repeatedly following the operands of defining operations.
///
/// Values without a defining op end the search along that path. Each value is
/// expanded at most once, so shared sub-expressions are not revisited.
bool valueDependsOn(Value value, Value dependency) {
  DenseSet<Value> expanded;
  SmallVector<Value, 8> pending;
  pending.push_back(value);

  // `pending` starts non-empty, so the first iteration also covers the
  // `value == dependency` case.
  do {
    Value candidate = pending.pop_back_val();
    if (candidate == dependency)
      return true;

    // Skip values whose producer operands were already queued.
    if (!expanded.insert(candidate).second)
      continue;

    if (Operation* producer = candidate.getDefiningOp())
      llvm::append_range(pending, producer->getOperands());
  } while (!pending.empty());

  return false;
}
|
||||
|
||||
/// Returns true when any operation reached by the walk rooted at `op` has an
/// operand for which valueDependsOn(operand, dependency) holds.
///
/// The walk stops at the first such operation.
bool opDependsOnValue(Operation& op, Value dependency) {
  auto reachesDependency = [&](Value operand) { return valueDependsOn(operand, dependency); };

  return op
      .walk([&](Operation* visitedOp) -> WalkResult {
        if (llvm::any_of(visitedOp->getOperands(), reachesDependency))
          return WalkResult::interrupt();
        return WalkResult::advance();
      })
      .wasInterrupted();
}
|
||||
|
||||
LogicalResult orderLowerCoreScalarSendsAfterMatchingReceives(MaterializerState& state) {
|
||||
for (MaterializedClass& materializedClass : state.classes) {
|
||||
if (materializedClass.isBatch || materializedClass.cpus.empty())
|
||||
if (materializedClass.isBatch)
|
||||
continue;
|
||||
|
||||
int64_t localCore = static_cast<int64_t>(materializedClass.cpus.front());
|
||||
Block* body = materializedClass.body;
|
||||
if (!body)
|
||||
continue;
|
||||
Block& body = *materializedClass.body;
|
||||
SmallVector<Operation*, 16> topLevelOps;
|
||||
for (Operation& op : body.without_terminator())
|
||||
topLevelOps.push_back(&op);
|
||||
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
for (Operation& op : llvm::make_early_inc_range(*body)) {
|
||||
if (&op == body->getTerminator())
|
||||
break;
|
||||
for (Operation* op : topLevelOps) {
|
||||
auto send = dyn_cast<SpatChannelSendOp>(op);
|
||||
if (!send || send->getBlock() != &body)
|
||||
continue;
|
||||
|
||||
auto send = dyn_cast<SpatChannelSendOp>(&op);
|
||||
if (!send)
|
||||
continue;
|
||||
std::optional<int64_t> sourceCoreId = getStaticCommunicationCore(send.getSourceCoreId());
|
||||
std::optional<int64_t> targetCoreId = getStaticCommunicationCore(send.getTargetCoreId());
|
||||
if (!sourceCoreId || !targetCoreId || *sourceCoreId >= *targetCoreId)
|
||||
continue;
|
||||
|
||||
std::optional<int64_t> sourceCore = getConstantIndexValue(send.getSourceCoreId());
|
||||
std::optional<int64_t> targetCore = getConstantIndexValue(send.getTargetCoreId());
|
||||
if (!sourceCore || !targetCore || *sourceCore != localCore || *sourceCore >= *targetCore)
|
||||
continue;
|
||||
|
||||
Operation* matchingReceiveContainer = nullptr;
|
||||
for (Operation* candidate = op.getNextNode(); candidate && candidate != body->getTerminator();
|
||||
candidate = candidate->getNextNode()) {
|
||||
if (operationContainsReceiveFromPeer(*candidate, localCore, *targetCore, send.getInput().getType())) {
|
||||
matchingReceiveContainer = candidate;
|
||||
Operation* anchor = nullptr;
|
||||
for (Operation* cursor = send->getNextNode(); cursor && cursor != body.getTerminator(); cursor = cursor->getNextNode()) {
|
||||
if (isTopLevelCommunicationOp(cursor)) {
|
||||
auto receive = dyn_cast<SpatChannelReceiveOp>(cursor);
|
||||
if (!receive)
|
||||
break;
|
||||
}
|
||||
|
||||
std::optional<int64_t> receiveSourceCoreId = getStaticCommunicationCore(receive.getSourceCoreId());
|
||||
std::optional<int64_t> receiveTargetCoreId = getStaticCommunicationCore(receive.getTargetCoreId());
|
||||
if (!receiveSourceCoreId || !receiveTargetCoreId)
|
||||
break;
|
||||
if (*receiveSourceCoreId != *targetCoreId || *receiveTargetCoreId != *sourceCoreId)
|
||||
break;
|
||||
if (receive.getType() != send.getInput().getType())
|
||||
break;
|
||||
|
||||
anchor = receive;
|
||||
break;
|
||||
}
|
||||
|
||||
if (!matchingReceiveContainer)
|
||||
if (!operationContainsCommunication(*cursor))
|
||||
continue;
|
||||
if (operationContainsSend(*cursor))
|
||||
break;
|
||||
if (!operationContainsReceiveFromPeer(*cursor, *sourceCoreId, *targetCoreId, send.getInput().getType()))
|
||||
break;
|
||||
|
||||
op.moveAfter(matchingReceiveContainer);
|
||||
changed = true;
|
||||
anchor = cursor;
|
||||
break;
|
||||
}
|
||||
|
||||
if (!anchor)
|
||||
continue;
|
||||
send->moveAfter(anchor);
|
||||
}
|
||||
|
||||
for (Operation* op : topLevelOps) {
|
||||
auto receive = dyn_cast<SpatChannelReceiveOp>(op);
|
||||
if (!receive || receive->getBlock() != &body)
|
||||
continue;
|
||||
|
||||
std::optional<int64_t> sourceCoreId = getStaticCommunicationCore(receive.getSourceCoreId());
|
||||
std::optional<int64_t> targetCoreId = getStaticCommunicationCore(receive.getTargetCoreId());
|
||||
if (!sourceCoreId || !targetCoreId || *targetCoreId >= *sourceCoreId)
|
||||
continue;
|
||||
|
||||
Operation* anchor = nullptr;
|
||||
for (Operation* cursor = receive->getNextNode(); cursor && cursor != body.getTerminator(); cursor = cursor->getNextNode()) {
|
||||
if (!isTopLevelCommunicationOp(cursor)) {
|
||||
if (opDependsOnValue(*cursor, receive.getOutput()))
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto send = dyn_cast<SpatChannelSendOp>(cursor);
|
||||
if (!send || send.getInput().getType() != receive.getType())
|
||||
break;
|
||||
|
||||
std::optional<int64_t> sendSourceCoreId = getStaticCommunicationCore(send.getSourceCoreId());
|
||||
std::optional<int64_t> sendTargetCoreId = getStaticCommunicationCore(send.getTargetCoreId());
|
||||
if (!sendSourceCoreId || !sendTargetCoreId)
|
||||
break;
|
||||
if (*sendSourceCoreId != *targetCoreId || *sendTargetCoreId != *sourceCoreId)
|
||||
break;
|
||||
|
||||
bool hasInterveningUse = false;
|
||||
for (Operation* between = receive->getNextNode(); between; between = between->getNextNode()) {
|
||||
if (opDependsOnValue(*between, receive.getOutput())) {
|
||||
hasInterveningUse = true;
|
||||
break;
|
||||
}
|
||||
if (between == cursor)
|
||||
break;
|
||||
}
|
||||
if (hasInterveningUse)
|
||||
continue;
|
||||
|
||||
anchor = cursor;
|
||||
break;
|
||||
}
|
||||
|
||||
if (!anchor)
|
||||
continue;
|
||||
receive->moveAfter(anchor);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user