update cost model of batch lanes to consider only a slice of the shared batch input
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -7,7 +7,7 @@
|
|||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <limits>
|
#include <iterator>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@@ -64,6 +64,49 @@ bool isUsedAsWeightOnly(Operation* producerOp) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isLaneOffset(OpFoldResult offset, Value laneArg) {
|
||||||
|
auto offsetValue = llvm::dyn_cast<Value>(offset);
|
||||||
|
return offsetValue == laneArg;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
|
||||||
|
auto inputIt = llvm::find(batch.getInputs(), input);
|
||||||
|
if (inputIt == batch.getInputs().end())
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
size_t inputIndex = std::distance(batch.getInputs().begin(), inputIt);
|
||||||
|
std::optional<BlockArgument> inputArg = batch.getInputArgument(inputIndex);
|
||||||
|
std::optional<BlockArgument> laneArg = batch.getLaneArgument();
|
||||||
|
if (!inputArg || !laneArg)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
Weight projectedCost = 0;
|
||||||
|
for (Operation* user : inputArg->getUsers()) {
|
||||||
|
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
|
||||||
|
if (!extract || extract.getSource() != *inputArg)
|
||||||
|
return std::nullopt;
|
||||||
|
if (extract.getMixedOffsets().empty() || !isLaneOffset(extract.getMixedOffsets().front(), *laneArg))
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
|
||||||
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
|
return std::nullopt;
|
||||||
|
projectedCost = checkedAdd(projectedCost, static_cast<Weight>(getSizeInBytes(resultType)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (projectedCost == 0)
|
||||||
|
return std::nullopt;
|
||||||
|
return projectedCost;
|
||||||
|
}
|
||||||
|
|
||||||
|
Weight getInputTransferCost(const ComputeInstance& consumerInstance, Value input) {
|
||||||
|
auto inputType = cast<ShapedType>(input.getType());
|
||||||
|
if (auto batch = dyn_cast<SpatComputeBatch>(consumerInstance.op))
|
||||||
|
if (std::optional<Weight> projectedCost = getBatchProjectedInputTransferCost(batch, input))
|
||||||
|
return *projectedCost;
|
||||||
|
return static_cast<Weight>(getSizeInBytes(inputType));
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
||||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||||
for (const ComputeGraphEdge& edge : edges) {
|
for (const ComputeGraphEdge& edge : edges) {
|
||||||
@@ -136,15 +179,16 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
|||||||
|
|
||||||
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
|
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
|
||||||
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
||||||
for (Value input : getComputeInstanceInputs(node.instance)) {
|
llvm::SmallVector<Value, 4> inputs = getComputeInstanceInputs(node.instance);
|
||||||
|
for (Value input : inputs) {
|
||||||
|
Weight transferCost = getInputTransferCost(node.instance, input);
|
||||||
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
|
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
|
||||||
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
|
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
|
||||||
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
|
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
|
||||||
auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane));
|
auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane));
|
||||||
if (producerIt == graph.instanceToIndex.end())
|
if (producerIt == graph.instanceToIndex.end())
|
||||||
continue;
|
continue;
|
||||||
rawEdges.push_back(
|
rawEdges.push_back({producerIt->second, targetIndex, transferCost});
|
||||||
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -155,8 +199,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
|||||||
auto producerIt = graph.instanceToIndex.find(*producerInstance);
|
auto producerIt = graph.instanceToIndex.find(*producerInstance);
|
||||||
if (producerIt == graph.instanceToIndex.end())
|
if (producerIt == graph.instanceToIndex.end())
|
||||||
continue;
|
continue;
|
||||||
rawEdges.push_back(
|
rawEdges.push_back({producerIt->second, targetIndex, transferCost});
|
||||||
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user