From 7f3c7464b47448e609adfddf3787857c626194c2 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Fri, 22 May 2026 22:16:19 +0200 Subject: [PATCH] update cost model of batch lanes to consider only a slice of the shared batch input --- .../Scheduling/ComputeGraph.cpp | 55 +++++++++++++++++-- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp index b3831f6..6ca07f3 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp @@ -7,7 +7,7 @@ #include "llvm/Support/Casting.h" #include -#include +#include #include #include #include @@ -64,6 +64,49 @@ bool isUsedAsWeightOnly(Operation* producerOp) { return true; } +bool isLaneOffset(OpFoldResult offset, Value laneArg) { + auto offsetValue = llvm::dyn_cast(offset); + return offsetValue == laneArg; +} + +std::optional getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) { + auto inputIt = llvm::find(batch.getInputs(), input); + if (inputIt == batch.getInputs().end()) + return std::nullopt; + + size_t inputIndex = std::distance(batch.getInputs().begin(), inputIt); + std::optional inputArg = batch.getInputArgument(inputIndex); + std::optional laneArg = batch.getLaneArgument(); + if (!inputArg || !laneArg) + return std::nullopt; + + Weight projectedCost = 0; + for (Operation* user : inputArg->getUsers()) { + auto extract = dyn_cast(user); + if (!extract || extract.getSource() != *inputArg) + return std::nullopt; + if (extract.getMixedOffsets().empty() || !isLaneOffset(extract.getMixedOffsets().front(), *laneArg)) + return std::nullopt; + + auto resultType = dyn_cast(extract.getResult().getType()); + if (!resultType || !resultType.hasStaticShape()) + return std::nullopt; + projectedCost = checkedAdd(projectedCost, static_cast(getSizeInBytes(resultType))); + } + + if (projectedCost == 0) + return std::nullopt; + return projectedCost; +} + +Weight getInputTransferCost(const ComputeInstance& consumerInstance, Value input) { + auto inputType = cast(input.getType()); + if (auto batch = dyn_cast(consumerInstance.op)) + if (std::optional projectedCost = getBatchProjectedInputTransferCost(batch, input)) + return *projectedCost; + return static_cast(getSizeInBytes(inputType)); +} + std::vector aggregateEdges(llvm::ArrayRef edges) { llvm::DenseMap, Weight> edgeWeights; for (const ComputeGraphEdge& edge : edges) { @@ -136,15 +179,16 @@ ComputeGraph buildComputeGraph(Operation* entryOp) { llvm::SmallVector rawEdges; for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) { - for (Value input : getComputeInstanceInputs(node.instance)) { + llvm::SmallVector inputs = getComputeInstanceInputs(node.instance); + for (Value input : inputs) { + Weight transferCost = getInputTransferCost(node.instance, input); if (auto producerBatch = dyn_cast_or_null(input.getDefiningOp()); producerBatch && producerBatch.getNumResults() != 0 && !isa(node.instance.op)) { for (uint32_t lane = 0; lane < static_cast(producerBatch.getLaneCount()); ++lane) { auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane)); if (producerIt == graph.instanceToIndex.end()) continue; - rawEdges.push_back( - {producerIt->second, targetIndex, static_cast(getSizeInBytes(cast(input.getType())))}); + rawEdges.push_back({producerIt->second, targetIndex, transferCost}); } continue; } @@ -155,8 +199,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) { auto producerIt = graph.instanceToIndex.find(*producerInstance); if (producerIt == graph.instanceToIndex.end()) continue; - rawEdges.push_back( - {producerIt->second, targetIndex, static_cast(getSizeInBytes(cast(input.getType())))}); + rawEdges.push_back({producerIt->second, targetIndex, transferCost}); } }