simplify affine maps to constants where possible
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-27 16:39:27 +02:00
parent 1a5d7d2a3f
commit 4bdaa57656
2 changed files with 65 additions and 5 deletions
@@ -1019,6 +1019,27 @@ std::optional<IndexedIndexPattern> getIndexedIndexPattern(ArrayRef<int64_t> valu
return std::nullopt;
}
Value createAffineApplyOrConstant(
MaterializerState& state, Location loc, AffineMap map, ValueRange operands, Operation* anchor) {
SmallVector<Attribute> operandConstants;
operandConstants.reserve(operands.size());
for (Value operand : operands) {
auto constantValue = getConstantIntValue(operand);
if (!constantValue)
return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult();
operandConstants.push_back(state.rewriter.getIndexAttr(*constantValue));
}
SmallVector<Attribute> foldedResults;
if (succeeded(map.constantFold(operandConstants, foldedResults))) {
auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front());
if (constantResult)
return createIndexConstant(state, anchor, constantResult.getInt());
}
return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult();
}
Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) {
MLIRContext* context = state.func.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
@@ -1033,7 +1054,7 @@ Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern
}
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
return affine::AffineApplyOp::create(state.rewriter, loc, map, ValueRange {index}).getResult();
return createAffineApplyOrConstant(state, loc, map, ValueRange {index}, state.func);
}
Value createIndexedIndexValue(
@@ -3346,7 +3367,7 @@ Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targe
int64_t laneCount = static_cast<int64_t>(targetClass.cpus.size());
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1);
return affine::AffineApplyOp::create(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}).getResult();
return createAffineApplyOrConstant(state, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
}
Value createBatchClassRunSourceLane(MaterializerState& state,