faster scheduling: split batches into numCores tasks before scheduling instead of numLanes tasks
This commit is contained in:
+417
-421
File diff suppressed because it is too large
Load Diff
@@ -105,6 +105,28 @@ bool isProjectedBatchOffset(OpFoldResult offset, Value laneArg) {
|
||||
&& succeeded(evaluateIndexLike(offset, bindings, /*lane=*/1, laneArg));
|
||||
}
|
||||
|
||||
std::optional<uint32_t> getConstantExtractLane(tensor::ExtractSliceOp extract) {
|
||||
if (extract.getMixedOffsets().empty())
|
||||
return std::nullopt;
|
||||
|
||||
OpFoldResult offset = extract.getMixedOffsets().front();
|
||||
if (auto attr = llvm::dyn_cast<Attribute>(offset)) {
|
||||
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||
if (!intAttr || intAttr.getInt() < 0)
|
||||
return std::nullopt;
|
||||
return static_cast<uint32_t>(intAttr.getInt());
|
||||
}
|
||||
|
||||
Value offsetValue = llvm::cast<Value>(offset);
|
||||
if (auto constantIndex = offsetValue.getDefiningOp<arith::ConstantIndexOp>()) {
|
||||
if (constantIndex.value() < 0)
|
||||
return std::nullopt;
|
||||
return static_cast<uint32_t>(constantIndex.value());
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
|
||||
auto inputIt = llvm::find(batch.getInputs(), input);
|
||||
if (inputIt == batch.getInputs().end())
|
||||
@@ -143,6 +165,102 @@ Cost getInputTransferCost(const ComputeInstance& consumerInstance, Value input)
|
||||
return static_cast<Cost>(getSizeInBytes(inputType));
|
||||
}
|
||||
|
||||
uint32_t getLaneOverlapCount(const ComputeInstance& lhs, const ComputeInstance& rhs) {
|
||||
uint32_t lhsEnd = lhs.laneStart + lhs.laneCount;
|
||||
uint32_t rhsEnd = rhs.laneStart + rhs.laneCount;
|
||||
return std::max(lhs.laneStart, rhs.laneStart) < std::min(lhsEnd, rhsEnd)
|
||||
? std::min(lhsEnd, rhsEnd) - std::max(lhs.laneStart, rhs.laneStart)
|
||||
: 0;
|
||||
}
|
||||
|
||||
Cost scaleTransferCostByLaneCount(Cost totalCost, uint32_t totalLaneCount, uint32_t fragmentLaneCount) {
|
||||
assert(totalLaneCount > 0 && "laneCount must be positive");
|
||||
assert(fragmentLaneCount > 0 && "fragmentLaneCount must be positive");
|
||||
if (fragmentLaneCount >= totalLaneCount)
|
||||
return totalCost;
|
||||
return checkedMultiply(totalCost, static_cast<Cost>(fragmentLaneCount)) / static_cast<Cost>(totalLaneCount);
|
||||
}
|
||||
|
||||
SmallVector<ProducerValueRef, 4> collectProducerValueRefs(Value value, const ComputeInstance& consumerInstance) {
|
||||
SmallVector<ProducerValueRef, 4> producers;
|
||||
Operation* op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return producers;
|
||||
|
||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
Value source = extract.getSource();
|
||||
auto batch = dyn_cast_or_null<SpatComputeBatch>(source.getDefiningOp());
|
||||
if (batch && batch.getNumResults() != 0) {
|
||||
if (std::optional<uint32_t> lane = getConstantExtractLane(extract)) {
|
||||
ComputeInstance instance = getBatchChunkForLane(batch, *lane);
|
||||
producers.push_back({instance, 0});
|
||||
return producers;
|
||||
}
|
||||
|
||||
if (isa<SpatComputeBatch>(consumerInstance.op)) {
|
||||
for (ComputeInstance instance :
|
||||
getBatchChunksForRange(batch, consumerInstance.laneStart, consumerInstance.laneCount))
|
||||
producers.push_back({instance, 0});
|
||||
}
|
||||
else {
|
||||
for (ComputeInstance instance :
|
||||
getBatchChunksForRange(batch, 0, static_cast<uint32_t>(batch.getLaneCount())))
|
||||
producers.push_back({instance, 0});
|
||||
}
|
||||
return producers;
|
||||
}
|
||||
|
||||
value = source;
|
||||
op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return producers;
|
||||
}
|
||||
|
||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||
producers.push_back({ComputeInstance {compute.getOperation(), 0, 1},
|
||||
static_cast<size_t>(cast<OpResult>(value).getResultNumber())});
|
||||
return producers;
|
||||
}
|
||||
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
|
||||
if (batch.getNumResults() != 0) {
|
||||
uint32_t laneStart = isa<SpatComputeBatch>(consumerInstance.op) ? consumerInstance.laneStart : 0;
|
||||
uint32_t laneCount = isa<SpatComputeBatch>(consumerInstance.op)
|
||||
? consumerInstance.laneCount
|
||||
: static_cast<uint32_t>(batch.getLaneCount());
|
||||
for (ComputeInstance instance : getBatchChunksForRange(batch, laneStart, laneCount))
|
||||
producers.push_back({instance, 0});
|
||||
return producers;
|
||||
}
|
||||
|
||||
uint32_t lane = cast<OpResult>(value).getResultNumber();
|
||||
ComputeInstance instance = getBatchChunkForLane(batch, lane);
|
||||
producers.push_back({instance, lane - instance.laneStart});
|
||||
return producers;
|
||||
}
|
||||
|
||||
return producers;
|
||||
}
|
||||
|
||||
Cost getProducerTransferCost(Value input, const ComputeInstance& consumerInstance, const ProducerValueRef& producerRef) {
|
||||
Cost transferCost = getInputTransferCost(consumerInstance, input);
|
||||
auto producerBatch = dyn_cast<SpatComputeBatch>(producerRef.instance.op);
|
||||
if (!producerBatch || producerBatch.getNumResults() == 0)
|
||||
return transferCost;
|
||||
|
||||
if (auto consumerBatch = dyn_cast<SpatComputeBatch>(consumerInstance.op)) {
|
||||
if (std::optional<Cost> projectedCost = getBatchProjectedInputTransferCost(consumerBatch, input)) {
|
||||
uint32_t overlapLaneCount = getLaneOverlapCount(consumerInstance, producerRef.instance);
|
||||
assert(overlapLaneCount > 0 && "projected batch edge must overlap consumer lanes");
|
||||
return checkedMultiply(*projectedCost, static_cast<Cost>(overlapLaneCount));
|
||||
}
|
||||
}
|
||||
|
||||
return scaleTransferCostByLaneCount(transferCost,
|
||||
static_cast<uint32_t>(producerBatch.getLaneCount()),
|
||||
producerRef.instance.laneCount);
|
||||
}
|
||||
|
||||
static CrossbarWeight getOpaqueCrossbarWeight(Value value, std::optional<uint32_t> lane) {
|
||||
CrossbarWeight weight;
|
||||
weight.opaqueValue = value;
|
||||
@@ -458,25 +576,13 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
||||
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
||||
llvm::SmallVector<Value, 4> inputs = getComputeInstanceInputs(node.instance);
|
||||
for (Value input : inputs) {
|
||||
Cost transferCost = getInputTransferCost(node.instance, input);
|
||||
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
|
||||
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
|
||||
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
|
||||
auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane));
|
||||
if (producerIt == graph.instanceToIndex.end())
|
||||
continue;
|
||||
rawEdges.push_back({producerIt->second, targetIndex, transferCost});
|
||||
}
|
||||
continue;
|
||||
for (const ProducerValueRef& producerRef : collectProducerValueRefs(input, node.instance)) {
|
||||
auto producerIt = graph.instanceToIndex.find(producerRef.instance);
|
||||
if (producerIt == graph.instanceToIndex.end())
|
||||
continue;
|
||||
rawEdges.push_back(
|
||||
{producerIt->second, targetIndex, getProducerTransferCost(input, node.instance, producerRef)});
|
||||
}
|
||||
|
||||
auto producerInstance = getComputeProducerInstance(input, &node.instance);
|
||||
if (!producerInstance)
|
||||
continue;
|
||||
auto producerIt = graph.instanceToIndex.find(*producerInstance);
|
||||
if (producerIt == graph.instanceToIndex.end())
|
||||
continue;
|
||||
rawEdges.push_back({producerIt->second, targetIndex, transferCost});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+55
-5
@@ -20,17 +20,67 @@ size_t getSchedulingCpuBudget() {
|
||||
|
||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||
assert(laneCount > 0 && "laneCount must be positive");
|
||||
return static_cast<size_t>(laneCount);
|
||||
return std::min(static_cast<size_t>(laneCount), getSchedulingCpuBudget());
|
||||
}
|
||||
|
||||
BatchChunkRange getBatchChunkRange(int32_t laneCount, size_t chunkIndex) {
|
||||
assert(laneCount > 0 && "laneCount must be positive");
|
||||
size_t chunkCount = getBatchChunkTargetCount(laneCount);
|
||||
assert(chunkIndex < chunkCount && "chunkIndex out of range");
|
||||
|
||||
size_t laneCountSize = static_cast<size_t>(laneCount);
|
||||
size_t baseChunkSize = laneCountSize / chunkCount;
|
||||
size_t remainder = laneCountSize % chunkCount;
|
||||
size_t extraBefore = std::min(chunkIndex, remainder);
|
||||
size_t start = chunkIndex * baseChunkSize + extraBefore;
|
||||
size_t count = baseChunkSize + (chunkIndex < remainder ? 1 : 0);
|
||||
assert(count > 0 && "chunk size must be positive");
|
||||
return {static_cast<uint32_t>(start), static_cast<uint32_t>(count)};
|
||||
}
|
||||
|
||||
size_t getBatchChunkIndexForLane(int32_t laneCount, uint32_t lane) {
|
||||
assert(laneCount > 0 && "laneCount must be positive");
|
||||
assert(lane < static_cast<uint32_t>(laneCount) && "lane out of range");
|
||||
|
||||
size_t chunkCount = getBatchChunkTargetCount(laneCount);
|
||||
size_t laneCountSize = static_cast<size_t>(laneCount);
|
||||
size_t baseChunkSize = laneCountSize / chunkCount;
|
||||
size_t remainder = laneCountSize % chunkCount;
|
||||
size_t largeChunkSize = baseChunkSize + 1;
|
||||
size_t laneIndex = static_cast<size_t>(lane);
|
||||
size_t largerChunkLanes = remainder * largeChunkSize;
|
||||
|
||||
if (laneIndex < largerChunkLanes)
|
||||
return laneIndex / largeChunkSize;
|
||||
return remainder + ((laneIndex - largerChunkLanes) / baseChunkSize);
|
||||
}
|
||||
|
||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||
assert(chunkIndex < static_cast<size_t>(batch.getLaneCount()) && "chunkIndex out of range");
|
||||
return {batch.getOperation(), static_cast<uint32_t>(chunkIndex), 1};
|
||||
BatchChunkRange chunk = getBatchChunkRange(batch.getLaneCount(), chunkIndex);
|
||||
return {batch.getOperation(), chunk.laneStart, chunk.laneCount};
|
||||
}
|
||||
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
||||
assert(lane < static_cast<uint32_t>(batch.getLaneCount()) && "lane out of range");
|
||||
return {batch.getOperation(), lane, 1};
|
||||
return getBatchChunkForIndex(batch, getBatchChunkIndexForLane(batch.getLaneCount(), lane));
|
||||
}
|
||||
|
||||
llvm::SmallVector<ComputeInstance, 4> getBatchChunksForRange(SpatComputeBatch batch,
|
||||
uint32_t laneStart,
|
||||
uint32_t laneCount) {
|
||||
llvm::SmallVector<ComputeInstance, 4> chunks;
|
||||
if (laneCount == 0)
|
||||
return chunks;
|
||||
|
||||
uint32_t laneEnd = laneStart + laneCount;
|
||||
assert(laneEnd >= laneStart && "lane range overflow");
|
||||
assert(laneEnd <= static_cast<uint32_t>(batch.getLaneCount()) && "lane range out of bounds");
|
||||
|
||||
size_t firstChunk = getBatchChunkIndexForLane(batch.getLaneCount(), laneStart);
|
||||
size_t lastChunk = getBatchChunkIndexForLane(batch.getLaneCount(), laneEnd - 1);
|
||||
chunks.reserve(lastChunk - firstChunk + 1);
|
||||
for (size_t chunkIndex = firstChunk; chunkIndex <= lastChunk; ++chunkIndex)
|
||||
chunks.push_back(getBatchChunkForIndex(batch, chunkIndex));
|
||||
return chunks;
|
||||
}
|
||||
|
||||
static std::optional<uint32_t> getConstantExtractLane(tensor::ExtractSliceOp extract) {
|
||||
|
||||
+10
@@ -21,10 +21,20 @@ struct ProducerValueRef {
|
||||
size_t resultIndex = 0;
|
||||
};
|
||||
|
||||
struct BatchChunkRange {
|
||||
uint32_t laneStart = 0;
|
||||
uint32_t laneCount = 0;
|
||||
};
|
||||
|
||||
size_t getSchedulingCpuBudget();
|
||||
size_t getBatchChunkTargetCount(int32_t laneCount);
|
||||
BatchChunkRange getBatchChunkRange(int32_t laneCount, size_t chunkIndex);
|
||||
size_t getBatchChunkIndexForLane(int32_t laneCount, uint32_t lane);
|
||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
||||
llvm::SmallVector<ComputeInstance, 4> getBatchChunksForRange(SpatComputeBatch batch,
|
||||
uint32_t laneStart,
|
||||
uint32_t laneCount);
|
||||
|
||||
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
|
||||
const ComputeInstance* consumerInstance = nullptr);
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
#include "mlir/IR/Threading.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
|
||||
Reference in New Issue
Block a user