update cost model of batch lanes to consider only a slice of the shared batch input
Validate Operations / validate-operations (push) Waiting to run
Validate Operations / validate-operations (push) Waiting to run
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <iterator>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
@@ -64,6 +64,49 @@ bool isUsedAsWeightOnly(Operation* producerOp) {
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns true when `offset` is a dynamic (SSA value) offset that is exactly
/// `laneArg`.
///
/// \param offset  One entry of an op's mixed static/dynamic offset list.
/// \param laneArg The value to compare the offset against.
/// \return false for attribute-backed (static) offsets and for any dynamic
///         offset that is a different value.
bool isLaneOffset(OpFoldResult offset, Value laneArg) {
  // dyn_cast yields a null Value when `offset` holds an attribute. Reject that
  // case explicitly so a null `laneArg` can never compare equal to it and
  // report a spurious match.
  auto offsetValue = llvm::dyn_cast<Value>(offset);
  return offsetValue && offsetValue == laneArg;
}
|
||||
|
||||
/// Computes a per-lane transfer cost for `input` of `batch`, based on the
/// slices of that input that the batch body actually extracts.
///
/// The cost is the checked sum of the byte sizes of every
/// `tensor.extract_slice` result taken from the block argument bound to
/// `input`. The projection is only produced when every user of that argument
/// is such a slice, the slice's leading offset is the batch lane argument, and
/// the slice result has a static shape.
///
/// \param batch The batch op whose input is being costed.
/// \param input A value expected to appear among `batch.getInputs()`.
/// \return The summed slice size in bytes, or std::nullopt when `input` is not
///         an input of `batch`, the input/lane block arguments are not
///         available, any user breaks the pattern above, or the sum is zero
///         (which includes the case of an argument with no users).
std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
  auto batchInputs = batch.getInputs();
  auto matchIt = llvm::find(batchInputs, input);
  if (matchIt == batchInputs.end())
    return std::nullopt;

  // Map the operand position onto the matching block argument of the body.
  size_t operandPos = std::distance(batchInputs.begin(), matchIt);
  std::optional<BlockArgument> bodyInputArg = batch.getInputArgument(operandPos);
  std::optional<BlockArgument> bodyLaneArg = batch.getLaneArgument();
  if (!(bodyInputArg && bodyLaneArg))
    return std::nullopt;

  Weight sliceBytes = 0;
  for (Operation* consumer : bodyInputArg->getUsers()) {
    // Any user that is not a slice of the argument itself defeats the projection.
    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(consumer);
    if (!sliceOp)
      return std::nullopt;
    if (sliceOp.getSource() != *bodyInputArg)
      return std::nullopt;

    // The slice must be addressed by the lane argument in its leading offset.
    auto sliceOffsets = sliceOp.getMixedOffsets();
    if (sliceOffsets.empty())
      return std::nullopt;
    if (!isLaneOffset(sliceOffsets.front(), *bodyLaneArg))
      return std::nullopt;

    // A byte size can only be computed for a statically shaped result.
    auto sliceType = dyn_cast<ShapedType>(sliceOp.getResult().getType());
    if (!sliceType)
      return std::nullopt;
    if (!sliceType.hasStaticShape())
      return std::nullopt;

    sliceBytes = checkedAdd(sliceBytes, static_cast<Weight>(getSizeInBytes(sliceType)));
  }

  // A zero total carries no usable projection; let the caller fall back.
  if (sliceBytes != 0)
    return sliceBytes;
  return std::nullopt;
}
|
||||
|
||||
/// Returns the cost, in bytes, of transferring `input` to `consumerInstance`.
///
/// When the consumer op is a `SpatComputeBatch` and a projected per-lane cost
/// is available for `input`, that projected cost is returned. Otherwise the
/// cost is the full byte size of `input`'s shaped type.
///
/// \param consumerInstance The compute instance that consumes `input`.
/// \param input            The consumed value; its type must be a ShapedType.
/// \return The transfer cost as a Weight.
Weight getInputTransferCost(const ComputeInstance& consumerInstance, Value input) {
  // The cast is performed up front, before either path is chosen.
  auto shapedInputType = cast<ShapedType>(input.getType());

  auto consumerBatch = dyn_cast<SpatComputeBatch>(consumerInstance.op);
  if (consumerBatch) {
    std::optional<Weight> sliceCost = getBatchProjectedInputTransferCost(consumerBatch, input);
    if (sliceCost.has_value())
      return *sliceCost;
  }

  // Fallback: charge for the whole input tensor.
  return static_cast<Weight>(getSizeInBytes(shapedInputType));
}
|
||||
|
||||
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||
for (const ComputeGraphEdge& edge : edges) {
|
||||
@@ -136,15 +179,16 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
||||
|
||||
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
|
||||
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
||||
for (Value input : getComputeInstanceInputs(node.instance)) {
|
||||
llvm::SmallVector<Value, 4> inputs = getComputeInstanceInputs(node.instance);
|
||||
for (Value input : inputs) {
|
||||
Weight transferCost = getInputTransferCost(node.instance, input);
|
||||
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
|
||||
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
|
||||
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
|
||||
auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane));
|
||||
if (producerIt == graph.instanceToIndex.end())
|
||||
continue;
|
||||
rawEdges.push_back(
|
||||
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
||||
rawEdges.push_back({producerIt->second, targetIndex, transferCost});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -155,8 +199,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
||||
auto producerIt = graph.instanceToIndex.find(*producerInstance);
|
||||
if (producerIt == graph.instanceToIndex.end())
|
||||
continue;
|
||||
rawEdges.push_back(
|
||||
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
||||
rawEdges.push_back({producerIt->second, targetIndex, transferCost});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user