fix scheduling cost model
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-27 17:14:19 +02:00
parent 874a2f53e6
commit 783dffe553
2 changed files with 43 additions and 5 deletions
@@ -1220,9 +1220,38 @@ LogicalResult collectProducerDestinations(MaterializerState& state) {
return success();
}
static bool isLaneProjectedOffsetValue(Value value, Value expected, bool& usesExpected) {
if (value == expected) {
usesExpected = true;
return true;
}
if (matchPattern(value, m_Constant()))
return true;
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
if (!affineApply || affineApply.getAffineMap().getNumResults() != 1)
return false;
bool nestedUsesExpected = false;
for (Value operand : affineApply.getMapOperands()) {
bool operandUsesExpected = false;
if (!isLaneProjectedOffsetValue(operand, expected, operandUsesExpected))
return false;
nestedUsesExpected = nestedUsesExpected || operandUsesExpected;
}
usesExpected = usesExpected || nestedUsesExpected;
return nestedUsesExpected;
}
bool isValueOffset(OpFoldResult offset, Value expected) {
auto value = dyn_cast<Value>(offset);
return value && value == expected;
if (!value)
return false;
bool usesExpected = false;
return isLaneProjectedOffsetValue(value, expected, usesExpected) && usesExpected;
}
bool isStaticIndexAttr(OpFoldResult value, int64_t expected) {
@@ -89,9 +89,18 @@ bool isUsedAsWeightOnly(Operation* producerOp) {
return true;
}
bool isLaneOffset(OpFoldResult offset, Value laneArg) {
auto offsetValue = llvm::dyn_cast<Value>(offset);
return offsetValue == laneArg;
static FailureOr<int64_t>
evaluateIndexLike(Value value, const DenseMap<Value, int64_t>& bindings, std::optional<uint32_t> lane, Value laneArg);
static FailureOr<int64_t> evaluateIndexLike(OpFoldResult value,
const DenseMap<Value, int64_t>& bindings,
std::optional<uint32_t> lane,
Value laneArg);
bool isProjectedBatchOffset(OpFoldResult offset, Value laneArg) {
DenseMap<Value, int64_t> bindings;
return succeeded(evaluateIndexLike(offset, bindings, /*lane=*/0, laneArg))
&& succeeded(evaluateIndexLike(offset, bindings, /*lane=*/1, laneArg));
}
std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
@@ -110,7 +119,7 @@ std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, V
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
if (!extract || extract.getSource() != *inputArg)
return std::nullopt;
if (extract.getMixedOffsets().empty() || !isLaneOffset(extract.getMixedOffsets().front(), *laneArg))
if (extract.getMixedOffsets().empty() || !isProjectedBatchOffset(extract.getMixedOffsets().front(), *laneArg))
return std::nullopt;
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());