use affine dialect to express simple constant progressions
Validate Operations / validate-operations (push) Has been cancelled

run dce at the end of MaterializeMergeSchedule to get rid of unused constants
This commit is contained in:
NiccoloN
2026-05-23 14:25:34 +02:00
parent 76a37e198f
commit b79ce8eeaa
3 changed files with 180 additions and 61 deletions
@@ -1,5 +1,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
@@ -119,10 +121,45 @@ static bool isConstantIndexLike(Value value) {
return matchPattern(value, m_ConstantInt(&constantValue));
}
static bool isSupportedLaneAffineExpr(AffineExpr expr) {
switch (expr.getKind()) {
case AffineExprKind::Constant:
case AffineExprKind::DimId: return true;
case AffineExprKind::SymbolId: return false;
case AffineExprKind::Add: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isSupportedLaneAffineExpr(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS());
}
case AffineExprKind::Mul: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return (isa<AffineConstantExpr>(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS()))
|| (isa<AffineConstantExpr>(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS()));
}
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isa<AffineConstantExpr>(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS());
}
}
llvm_unreachable("unexpected affine expression kind");
}
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
if (value == laneArg || isConstantIndexLike(value))
return true;
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
if (affineApply) {
if (affineApply.getAffineMap().getNumResults() != 1 || affineApply.getAffineMap().getNumSymbols() != 0)
return false;
if (!llvm::all_of(affineApply.getMapOperands(),
[&](Value operand) { return isSupportedLaneOffsetExpr(operand, laneArg); })) {
return false;
}
return isSupportedLaneAffineExpr(affineApply.getAffineMap().getResult(0));
}
auto extractOp = value.getDefiningOp<tensor::ExtractOp>();
if (extractOp) {
auto constantTensor = extractOp.getTensor().getDefiningOp<arith::ConstantOp>();