update cost model of batch lanes to consider only a slice of the shared batch input
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-05-22 22:16:19 +02:00
parent c77ffa9c56
commit 7f3c7464b4
@@ -7,7 +7,7 @@
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <limits>
#include <iterator>
#include <optional>
#include <queue>
#include <utility>
@@ -64,6 +64,49 @@ bool isUsedAsWeightOnly(Operation* producerOp) {
return true;
}
/// Reports whether `offset` refers to the same SSA value as `laneArg`.
///
/// `offset` is converted with `llvm::dyn_cast<Value>` and the outcome is
/// compared against `laneArg` for equality.
// NOTE(review): when `offset` holds an attribute instead of a Value, the cast
// presumably yields a null Value, which only equals a null `laneArg` - confirm.
bool isLaneOffset(OpFoldResult offset, Value laneArg) {
  return llvm::dyn_cast<Value>(offset) == laneArg;
}
/// Computes a transfer cost for `input` of `batch` from the slices extracted
/// from it inside the batch, instead of from the whole value.
///
/// The cost is the sum of the byte sizes of the results of every
/// `tensor::ExtractSliceOp` that uses the block argument matching `input`.
///
/// Returns `std::nullopt` when any of the following holds:
///   - `input` is not among `batch.getInputs()`;
///   - the batch has no block argument for that input, or no lane argument;
///   - some user of the input argument is not an extract-slice whose source is
///     that argument;
///   - a slice has no offsets, or its first offset is not the lane argument;
///   - a slice result is not a statically shaped `ShapedType`;
///   - the accumulated cost is zero (this includes the "no users" case).
std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
  auto batchInputs = batch.getInputs();
  auto position = llvm::find(batchInputs, input);
  if (position == batchInputs.end())
    return std::nullopt;

  const size_t inputPosition = std::distance(batchInputs.begin(), position);
  std::optional<BlockArgument> blockInput = batch.getInputArgument(inputPosition);
  std::optional<BlockArgument> lane = batch.getLaneArgument();
  if (!(blockInput && lane))
    return std::nullopt;

  Weight total = 0;
  for (Operation* consumer : blockInput->getUsers()) {
    // Every use must be an extract-slice that reads directly from the input argument.
    auto slice = dyn_cast<tensor::ExtractSliceOp>(consumer);
    if (!slice)
      return std::nullopt;
    if (slice.getSource() != *blockInput)
      return std::nullopt;

    // The leading offset has to be the lane argument itself.
    const auto offsets = slice.getMixedOffsets();
    if (offsets.empty() || !isLaneOffset(offsets.front(), *lane))
      return std::nullopt;

    // Only statically shaped results have a byte size that can be accumulated.
    auto sliceType = dyn_cast<ShapedType>(slice.getResult().getType());
    if (!sliceType || !sliceType.hasStaticShape())
      return std::nullopt;

    total = checkedAdd(total, static_cast<Weight>(getSizeInBytes(sliceType)));
  }

  // A zero total gives the caller nothing to project from; report "no projection".
  if (total == 0)
    return std::nullopt;
  return total;
}
/// Returns the cost of transferring `input` to `consumerInstance`.
///
/// When the consumer operation is a `SpatComputeBatch` and
/// `getBatchProjectedInputTransferCost` produces a value for `input`, that
/// projected cost is returned. Otherwise the cost is the byte size of the
/// full type of `input`, which must be a `ShapedType`.
Weight getInputTransferCost(const ComputeInstance& consumerInstance, Value input) {
  // Cast up front so a non-shaped input is rejected regardless of the path taken.
  auto shapedInput = cast<ShapedType>(input.getType());

  auto batchConsumer = dyn_cast<SpatComputeBatch>(consumerInstance.op);
  if (batchConsumer) {
    std::optional<Weight> sliceCost = getBatchProjectedInputTransferCost(batchConsumer, input);
    if (sliceCost.has_value())
      return *sliceCost;
  }

  return static_cast<Weight>(getSizeInBytes(shapedInput));
}
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
for (const ComputeGraphEdge& edge : edges) {
@@ -136,15 +179,16 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
for (Value input : getComputeInstanceInputs(node.instance)) {
llvm::SmallVector<Value, 4> inputs = getComputeInstanceInputs(node.instance);
for (Value input : inputs) {
Weight transferCost = getInputTransferCost(node.instance, input);
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane));
if (producerIt == graph.instanceToIndex.end())
continue;
rawEdges.push_back(
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
rawEdges.push_back({producerIt->second, targetIndex, transferCost});
}
continue;
}
@@ -155,8 +199,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
auto producerIt = graph.instanceToIndex.find(*producerInstance);
if (producerIt == graph.instanceToIndex.end())
continue;
rawEdges.push_back(
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
rawEdges.push_back({producerIt->second, targetIndex, transferCost});
}
}