From 783dffe5534662111f63bcee0eb4e73162004d48 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Wed, 27 May 2026 17:14:19 +0200 Subject: [PATCH] fix scheduling cost model --- .../MaterializeMergeSchedule.cpp | 31 ++++++++++++++++++- .../Scheduling/ComputeGraph.cpp | 17 +++++++--- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index d9441e1..32e5ffe 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -1220,9 +1220,38 @@ LogicalResult collectProducerDestinations(MaterializerState& state) { return success(); } +static bool isLaneProjectedOffsetValue(Value value, Value expected, bool& usesExpected) { + if (value == expected) { + usesExpected = true; + return true; + } + + if (matchPattern(value, m_Constant())) + return true; + + auto affineApply = value.getDefiningOp(); + if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) + return false; + + bool nestedUsesExpected = false; + for (Value operand : affineApply.getMapOperands()) { + bool operandUsesExpected = false; + if (!isLaneProjectedOffsetValue(operand, expected, operandUsesExpected)) + return false; + nestedUsesExpected = nestedUsesExpected || operandUsesExpected; + } + + usesExpected = usesExpected || nestedUsesExpected; + return nestedUsesExpected; +} + bool isValueOffset(OpFoldResult offset, Value expected) { auto value = dyn_cast(offset); - return value && value == expected; + if (!value) + return false; + + bool usesExpected = false; + return isLaneProjectedOffsetValue(value, expected, usesExpected) && usesExpected; } bool isStaticIndexAttr(OpFoldResult value, int64_t expected) { diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp index e193168..128eb13 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp @@ -89,9 +89,18 @@ bool isUsedAsWeightOnly(Operation* producerOp) { return true; } -bool isLaneOffset(OpFoldResult offset, Value laneArg) { - auto offsetValue = llvm::dyn_cast(offset); - return offsetValue == laneArg; +static FailureOr +evaluateIndexLike(Value value, const DenseMap& bindings, std::optional lane, Value laneArg); + +static FailureOr evaluateIndexLike(OpFoldResult value, + const DenseMap& bindings, + std::optional lane, + Value laneArg); + +bool isProjectedBatchOffset(OpFoldResult offset, Value laneArg) { + DenseMap bindings; + return succeeded(evaluateIndexLike(offset, bindings, /*lane=*/0, laneArg)) + && succeeded(evaluateIndexLike(offset, bindings, /*lane=*/1, laneArg)); } std::optional getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) { @@ -110,7 +119,7 @@ std::optional getBatchProjectedInputTransferCost(SpatComputeBatch batch, V auto extract = dyn_cast(user); if (!extract || extract.getSource() != *inputArg) return std::nullopt; - if (extract.getMixedOffsets().empty() || !isLaneOffset(extract.getMixedOffsets().front(), *laneArg)) + if (extract.getMixedOffsets().empty() || !isProjectedBatchOffset(extract.getMixedOffsets().front(), *laneArg)) return std::nullopt; auto resultType = dyn_cast(extract.getResult().getType());