automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-03 19:43:56 +02:00
parent dc5edd032c
commit 69021d56aa
12 changed files with 187 additions and 195 deletions
@@ -448,7 +448,10 @@ collectProducerKeysForDestinations(Value value, std::optional<ComputeInstance> l
auto result = dyn_cast<OpResult>(value);
if (!result)
return {};
keys.push_back({{compute.getOperation(), 0, 1}, result.getResultNumber()});
keys.push_back({
{compute.getOperation(), 0, 1},
result.getResultNumber()
});
return keys;
}
@@ -476,8 +479,8 @@ collectProducerKeysForDestinations(Value value, std::optional<ComputeInstance> l
return keys;
}
std::optional<ProducerKey>
getInputRequestProducerKey(Value value, std::optional<ComputeInstance> logicalConsumer = std::nullopt) {
std::optional<ProducerKey> getInputRequestProducerKey(Value value,
std::optional<ComputeInstance> logicalConsumer = std::nullopt) {
// Input resolution may request a whole-batch key for scalar consumers that read
// a complete resultful compute_batch value.
Operation* definingOp = value.getDefiningOp();
@@ -511,7 +514,10 @@ getInputRequestProducerKey(Value value, std::optional<ComputeInstance> logicalCo
auto result = dyn_cast<OpResult>(value);
if (!result)
return std::nullopt;
return ProducerKey {{compute.getOperation(), 0, 1}, result.getResultNumber()};
return ProducerKey {
{compute.getOperation(), 0, 1},
result.getResultNumber()
};
}
if (auto batch = dyn_cast<SpatComputeBatch>(definingOp)) {
@@ -675,7 +681,7 @@ LogicalResult verifyScheduleEquivalenceMatchesLogicalStreams(MaterializerState&
return state.func.emitError("materialized class logical slot source op mismatch");
if (isa<SpatComputeBatch>(referenceInstance.op) != isa<SpatComputeBatch>(currentInstance.op))
return state.func.emitError("materialized class logical slot batch/scalar mismatch");
(void)slot;
(void) slot;
}
}
}
@@ -1707,7 +1713,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
continue;
}
(void)logicalSlot;
(void) logicalSlot;
descriptor.offsetsByLane[targetLane].push_back(*offset);
}
@@ -3268,9 +3274,7 @@ FailureOr<Value> materializeIndexedBatchRunReceive(MaterializerState& state,
}
FailureOr<SmallVector<Value, 4>>
cloneInstanceBody(MaterializerState& state,
MaterializedClass& targetClass,
ArrayRef<ComputeInstance> peers) {
cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef<ComputeInstance> peers) {
assert(!peers.empty() && "expected at least one peer instance");
const ComputeInstance& instance = peers.front();
Operation* sourceOp = instance.op;
@@ -3620,8 +3624,7 @@ FailureOr<SmallVector<Value, 4>> materializeBatchOutputGroupLoop(MaterializerSta
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
Value laneValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, item.laneStart);
return cloneBatchBodyForLane(
state, targetClass, item, laneValue, group.resultIndices, {});
return cloneBatchBodyForLane(state, targetClass, item, laneValue, group.resultIndices, {});
}
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
@@ -3733,8 +3736,7 @@ FailureOr<MaterializationRun> collectBatchMaterializationRun(MaterializerState&
if (state.materializedLogicalSlots.contains(classSlot))
break;
FailureOr<SmallVector<ComputeInstance, 8>> peers =
getMaterializationRunSlotPeers(state, targetClass, slot);
FailureOr<SmallVector<ComputeInstance, 8>> peers = getMaterializationRunSlotPeers(state, targetClass, slot);
if (failed(peers) || peers->empty())
break;
@@ -3818,10 +3820,9 @@ bool hasMaterializationRunGroupSameClassConsumer(MaterializerState& state,
const OutputDestinationGroup& group) {
for (size_t resultIndex : group.resultIndices) {
for (const MaterializationRunSlot& slot : run) {
for (const ComputeInstance& peer : slot.peers) {
for (const ComputeInstance& peer : slot.peers)
if (hasSameClassConsumer(state, {peer, resultIndex}, classId))
return true;
}
}
}
@@ -4020,7 +4021,7 @@ LogicalResult buildBatchRunSendPlans(MaterializerState& state,
plan.messages.targetCoreIds.reserve(messageCount);
for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) {
(void)slotIndex;
(void) slotIndex;
for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) {
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id");
if (failed(checkedSourceCpu))
@@ -4236,7 +4237,8 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns
return success();
if (isa<SpatComputeBatch>(instance.op)) {
FailureOr<MaterializationRun> run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
FailureOr<MaterializationRun> run =
collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op);
if (succeeded(run)) {
if (!targetClass.isBatch)
@@ -169,8 +169,8 @@ uint32_t getLaneOverlapCount(const ComputeInstance& lhs, const ComputeInstance&
uint32_t lhsEnd = lhs.laneStart + lhs.laneCount;
uint32_t rhsEnd = rhs.laneStart + rhs.laneCount;
return std::max(lhs.laneStart, rhs.laneStart) < std::min(lhsEnd, rhsEnd)
? std::min(lhsEnd, rhsEnd) - std::max(lhs.laneStart, rhs.laneStart)
: 0;
? std::min(lhsEnd, rhsEnd) - std::max(lhs.laneStart, rhs.laneStart)
: 0;
}
Cost scaleTransferCostByLaneCount(Cost totalCost, uint32_t totalLaneCount, uint32_t fragmentLaneCount) {
@@ -197,16 +197,13 @@ SmallVector<ProducerValueRef, 4> collectProducerValueRefs(Value value, const Com
return producers;
}
if (isa<SpatComputeBatch>(consumerInstance.op)) {
if (isa<SpatComputeBatch>(consumerInstance.op))
for (ComputeInstance instance :
getBatchChunksForRange(batch, consumerInstance.laneStart, consumerInstance.laneCount))
producers.push_back({instance, 0});
}
else {
for (ComputeInstance instance :
getBatchChunksForRange(batch, 0, static_cast<uint32_t>(batch.getLaneCount())))
else
for (ComputeInstance instance : getBatchChunksForRange(batch, 0, static_cast<uint32_t>(batch.getLaneCount())))
producers.push_back({instance, 0});
}
return producers;
}
@@ -217,17 +214,18 @@ SmallVector<ProducerValueRef, 4> collectProducerValueRefs(Value value, const Com
}
if (auto compute = dyn_cast<SpatCompute>(op)) {
producers.push_back({ComputeInstance {compute.getOperation(), 0, 1},
static_cast<size_t>(cast<OpResult>(value).getResultNumber())});
producers.push_back({
ComputeInstance {compute.getOperation(), 0, 1},
static_cast<size_t>(cast<OpResult>(value).getResultNumber())
});
return producers;
}
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
if (batch.getNumResults() != 0) {
uint32_t laneStart = isa<SpatComputeBatch>(consumerInstance.op) ? consumerInstance.laneStart : 0;
uint32_t laneCount = isa<SpatComputeBatch>(consumerInstance.op)
? consumerInstance.laneCount
: static_cast<uint32_t>(batch.getLaneCount());
uint32_t laneCount = isa<SpatComputeBatch>(consumerInstance.op) ? consumerInstance.laneCount
: static_cast<uint32_t>(batch.getLaneCount());
for (ComputeInstance instance : getBatchChunksForRange(batch, laneStart, laneCount))
producers.push_back({instance, 0});
return producers;
@@ -242,7 +240,9 @@ SmallVector<ProducerValueRef, 4> collectProducerValueRefs(Value value, const Com
return producers;
}
Cost getProducerTransferCost(Value input, const ComputeInstance& consumerInstance, const ProducerValueRef& producerRef) {
Cost getProducerTransferCost(Value input,
const ComputeInstance& consumerInstance,
const ProducerValueRef& producerRef) {
Cost transferCost = getInputTransferCost(consumerInstance, input);
auto producerBatch = dyn_cast<SpatComputeBatch>(producerRef.instance.op);
if (!producerBatch || producerBatch.getNumResults() == 0)
@@ -256,9 +256,8 @@ Cost getProducerTransferCost(Value input, const ComputeInstance& consumerInstanc
}
}
return scaleTransferCostByLaneCount(transferCost,
static_cast<uint32_t>(producerBatch.getLaneCount()),
producerRef.instance.laneCount);
return scaleTransferCostByLaneCount(
transferCost, static_cast<uint32_t>(producerBatch.getLaneCount()), producerRef.instance.laneCount);
}
static CrossbarWeight getOpaqueCrossbarWeight(Value value, std::optional<uint32_t> lane) {
@@ -64,9 +64,8 @@ ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
return getBatchChunkForIndex(batch, getBatchChunkIndexForLane(batch.getLaneCount(), lane));
}
llvm::SmallVector<ComputeInstance, 4> getBatchChunksForRange(SpatComputeBatch batch,
uint32_t laneStart,
uint32_t laneCount) {
llvm::SmallVector<ComputeInstance, 4>
getBatchChunksForRange(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount) {
llvm::SmallVector<ComputeInstance, 4> chunks;
if (laneCount == 0)
return chunks;
@@ -32,9 +32,8 @@ BatchChunkRange getBatchChunkRange(int32_t laneCount, size_t chunkIndex);
size_t getBatchChunkIndexForLane(int32_t laneCount, uint32_t lane);
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
llvm::SmallVector<ComputeInstance, 4> getBatchChunksForRange(SpatComputeBatch batch,
uint32_t laneStart,
uint32_t laneCount);
llvm::SmallVector<ComputeInstance, 4>
getBatchChunksForRange(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount);
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
const ComputeInstance* consumerInstance = nullptr);