faster scheduling: split batches into numCores tasks before scheduling instead of numLanes tasks

This commit is contained in:
NiccoloN
2026-06-03 19:40:34 +02:00
parent 37a59054a5
commit e33f517221
5 changed files with 606 additions and 446 deletions
File diff suppressed because it is too large Load Diff
@@ -105,6 +105,28 @@ bool isProjectedBatchOffset(OpFoldResult offset, Value laneArg) {
&& succeeded(evaluateIndexLike(offset, bindings, /*lane=*/1, laneArg));
}
std::optional<uint32_t> getConstantExtractLane(tensor::ExtractSliceOp extract) {
if (extract.getMixedOffsets().empty())
return std::nullopt;
OpFoldResult offset = extract.getMixedOffsets().front();
if (auto attr = llvm::dyn_cast<Attribute>(offset)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
if (!intAttr || intAttr.getInt() < 0)
return std::nullopt;
return static_cast<uint32_t>(intAttr.getInt());
}
Value offsetValue = llvm::cast<Value>(offset);
if (auto constantIndex = offsetValue.getDefiningOp<arith::ConstantIndexOp>()) {
if (constantIndex.value() < 0)
return std::nullopt;
return static_cast<uint32_t>(constantIndex.value());
}
return std::nullopt;
}
std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
auto inputIt = llvm::find(batch.getInputs(), input);
if (inputIt == batch.getInputs().end())
@@ -143,6 +165,102 @@ Cost getInputTransferCost(const ComputeInstance& consumerInstance, Value input)
return static_cast<Cost>(getSizeInBytes(inputType));
}
uint32_t getLaneOverlapCount(const ComputeInstance& lhs, const ComputeInstance& rhs) {
uint32_t lhsEnd = lhs.laneStart + lhs.laneCount;
uint32_t rhsEnd = rhs.laneStart + rhs.laneCount;
return std::max(lhs.laneStart, rhs.laneStart) < std::min(lhsEnd, rhsEnd)
? std::min(lhsEnd, rhsEnd) - std::max(lhs.laneStart, rhs.laneStart)
: 0;
}
Cost scaleTransferCostByLaneCount(Cost totalCost, uint32_t totalLaneCount, uint32_t fragmentLaneCount) {
assert(totalLaneCount > 0 && "laneCount must be positive");
assert(fragmentLaneCount > 0 && "fragmentLaneCount must be positive");
if (fragmentLaneCount >= totalLaneCount)
return totalCost;
return checkedMultiply(totalCost, static_cast<Cost>(fragmentLaneCount)) / static_cast<Cost>(totalLaneCount);
}
SmallVector<ProducerValueRef, 4> collectProducerValueRefs(Value value, const ComputeInstance& consumerInstance) {
SmallVector<ProducerValueRef, 4> producers;
Operation* op = value.getDefiningOp();
if (!op)
return producers;
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
Value source = extract.getSource();
auto batch = dyn_cast_or_null<SpatComputeBatch>(source.getDefiningOp());
if (batch && batch.getNumResults() != 0) {
if (std::optional<uint32_t> lane = getConstantExtractLane(extract)) {
ComputeInstance instance = getBatchChunkForLane(batch, *lane);
producers.push_back({instance, 0});
return producers;
}
if (isa<SpatComputeBatch>(consumerInstance.op)) {
for (ComputeInstance instance :
getBatchChunksForRange(batch, consumerInstance.laneStart, consumerInstance.laneCount))
producers.push_back({instance, 0});
}
else {
for (ComputeInstance instance :
getBatchChunksForRange(batch, 0, static_cast<uint32_t>(batch.getLaneCount())))
producers.push_back({instance, 0});
}
return producers;
}
value = source;
op = value.getDefiningOp();
if (!op)
return producers;
}
if (auto compute = dyn_cast<SpatCompute>(op)) {
producers.push_back({ComputeInstance {compute.getOperation(), 0, 1},
static_cast<size_t>(cast<OpResult>(value).getResultNumber())});
return producers;
}
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
if (batch.getNumResults() != 0) {
uint32_t laneStart = isa<SpatComputeBatch>(consumerInstance.op) ? consumerInstance.laneStart : 0;
uint32_t laneCount = isa<SpatComputeBatch>(consumerInstance.op)
? consumerInstance.laneCount
: static_cast<uint32_t>(batch.getLaneCount());
for (ComputeInstance instance : getBatchChunksForRange(batch, laneStart, laneCount))
producers.push_back({instance, 0});
return producers;
}
uint32_t lane = cast<OpResult>(value).getResultNumber();
ComputeInstance instance = getBatchChunkForLane(batch, lane);
producers.push_back({instance, lane - instance.laneStart});
return producers;
}
return producers;
}
Cost getProducerTransferCost(Value input, const ComputeInstance& consumerInstance, const ProducerValueRef& producerRef) {
Cost transferCost = getInputTransferCost(consumerInstance, input);
auto producerBatch = dyn_cast<SpatComputeBatch>(producerRef.instance.op);
if (!producerBatch || producerBatch.getNumResults() == 0)
return transferCost;
if (auto consumerBatch = dyn_cast<SpatComputeBatch>(consumerInstance.op)) {
if (std::optional<Cost> projectedCost = getBatchProjectedInputTransferCost(consumerBatch, input)) {
uint32_t overlapLaneCount = getLaneOverlapCount(consumerInstance, producerRef.instance);
assert(overlapLaneCount > 0 && "projected batch edge must overlap consumer lanes");
return checkedMultiply(*projectedCost, static_cast<Cost>(overlapLaneCount));
}
}
return scaleTransferCostByLaneCount(transferCost,
static_cast<uint32_t>(producerBatch.getLaneCount()),
producerRef.instance.laneCount);
}
static CrossbarWeight getOpaqueCrossbarWeight(Value value, std::optional<uint32_t> lane) {
CrossbarWeight weight;
weight.opaqueValue = value;
@@ -458,25 +576,13 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
llvm::SmallVector<Value, 4> inputs = getComputeInstanceInputs(node.instance);
for (Value input : inputs) {
Cost transferCost = getInputTransferCost(node.instance, input);
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane));
if (producerIt == graph.instanceToIndex.end())
continue;
rawEdges.push_back({producerIt->second, targetIndex, transferCost});
}
continue;
for (const ProducerValueRef& producerRef : collectProducerValueRefs(input, node.instance)) {
auto producerIt = graph.instanceToIndex.find(producerRef.instance);
if (producerIt == graph.instanceToIndex.end())
continue;
rawEdges.push_back(
{producerIt->second, targetIndex, getProducerTransferCost(input, node.instance, producerRef)});
}
auto producerInstance = getComputeProducerInstance(input, &node.instance);
if (!producerInstance)
continue;
auto producerIt = graph.instanceToIndex.find(*producerInstance);
if (producerIt == graph.instanceToIndex.end())
continue;
rawEdges.push_back({producerIt->second, targetIndex, transferCost});
}
}
@@ -20,17 +20,67 @@ size_t getSchedulingCpuBudget() {
size_t getBatchChunkTargetCount(int32_t laneCount) {
assert(laneCount > 0 && "laneCount must be positive");
return static_cast<size_t>(laneCount);
return std::min(static_cast<size_t>(laneCount), getSchedulingCpuBudget());
}
BatchChunkRange getBatchChunkRange(int32_t laneCount, size_t chunkIndex) {
assert(laneCount > 0 && "laneCount must be positive");
size_t chunkCount = getBatchChunkTargetCount(laneCount);
assert(chunkIndex < chunkCount && "chunkIndex out of range");
size_t laneCountSize = static_cast<size_t>(laneCount);
size_t baseChunkSize = laneCountSize / chunkCount;
size_t remainder = laneCountSize % chunkCount;
size_t extraBefore = std::min(chunkIndex, remainder);
size_t start = chunkIndex * baseChunkSize + extraBefore;
size_t count = baseChunkSize + (chunkIndex < remainder ? 1 : 0);
assert(count > 0 && "chunk size must be positive");
return {static_cast<uint32_t>(start), static_cast<uint32_t>(count)};
}
size_t getBatchChunkIndexForLane(int32_t laneCount, uint32_t lane) {
assert(laneCount > 0 && "laneCount must be positive");
assert(lane < static_cast<uint32_t>(laneCount) && "lane out of range");
size_t chunkCount = getBatchChunkTargetCount(laneCount);
size_t laneCountSize = static_cast<size_t>(laneCount);
size_t baseChunkSize = laneCountSize / chunkCount;
size_t remainder = laneCountSize % chunkCount;
size_t largeChunkSize = baseChunkSize + 1;
size_t laneIndex = static_cast<size_t>(lane);
size_t largerChunkLanes = remainder * largeChunkSize;
if (laneIndex < largerChunkLanes)
return laneIndex / largeChunkSize;
return remainder + ((laneIndex - largerChunkLanes) / baseChunkSize);
}
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
assert(chunkIndex < static_cast<size_t>(batch.getLaneCount()) && "chunkIndex out of range");
return {batch.getOperation(), static_cast<uint32_t>(chunkIndex), 1};
BatchChunkRange chunk = getBatchChunkRange(batch.getLaneCount(), chunkIndex);
return {batch.getOperation(), chunk.laneStart, chunk.laneCount};
}
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
assert(lane < static_cast<uint32_t>(batch.getLaneCount()) && "lane out of range");
return {batch.getOperation(), lane, 1};
return getBatchChunkForIndex(batch, getBatchChunkIndexForLane(batch.getLaneCount(), lane));
}
llvm::SmallVector<ComputeInstance, 4> getBatchChunksForRange(SpatComputeBatch batch,
uint32_t laneStart,
uint32_t laneCount) {
llvm::SmallVector<ComputeInstance, 4> chunks;
if (laneCount == 0)
return chunks;
uint32_t laneEnd = laneStart + laneCount;
assert(laneEnd >= laneStart && "lane range overflow");
assert(laneEnd <= static_cast<uint32_t>(batch.getLaneCount()) && "lane range out of bounds");
size_t firstChunk = getBatchChunkIndexForLane(batch.getLaneCount(), laneStart);
size_t lastChunk = getBatchChunkIndexForLane(batch.getLaneCount(), laneEnd - 1);
chunks.reserve(lastChunk - firstChunk + 1);
for (size_t chunkIndex = firstChunk; chunkIndex <= lastChunk; ++chunkIndex)
chunks.push_back(getBatchChunkForIndex(batch, chunkIndex));
return chunks;
}
static std::optional<uint32_t> getConstantExtractLane(tensor::ExtractSliceOp extract) {
@@ -21,10 +21,20 @@ struct ProducerValueRef {
size_t resultIndex = 0;
};
struct BatchChunkRange {
uint32_t laneStart = 0;
uint32_t laneCount = 0;
};
size_t getSchedulingCpuBudget();
size_t getBatchChunkTargetCount(int32_t laneCount);
BatchChunkRange getBatchChunkRange(int32_t laneCount, size_t chunkIndex);
size_t getBatchChunkIndexForLane(int32_t laneCount, uint32_t lane);
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
llvm::SmallVector<ComputeInstance, 4> getBatchChunksForRange(SpatComputeBatch batch,
uint32_t laneStart,
uint32_t laneCount);
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
const ComputeInstance* consumerInstance = nullptr);
@@ -1,11 +1,9 @@
#include "mlir/IR/Threading.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <functional>
#include <limits>
#include <queue>
#include <vector>