Files
Raptor/src/PIM/Common/IR/AffineUtils.cpp
T
NiccoloN ab63498f3f
Validate Operations / validate-operations (push) Has been cancelled
normalize affine arithmetic helpers
2026-05-30 16:37:28 +02:00

183 lines
6.7 KiB
C++

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"

#include "llvm/Support/CheckedArithmetic.h"

#include "AffineUtils.hpp"
#include "ConstantUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
/// Floor division for a strictly positive divisor: the quotient is rounded
/// toward negative infinity (e.g. -7 floordiv 2 == -4).
/// Returns failure() when `rhs` is zero or negative.
static FailureOr<int64_t> floorDivSigned(int64_t lhs, int64_t rhs) {
  if (!(rhs > 0))
    return failure();
  const int64_t truncated = lhs / rhs;
  // C++ `/` truncates toward zero, so a negative dividend that does not divide
  // evenly must be stepped one further down.
  const bool needsAdjust = lhs < 0 && lhs % rhs != 0;
  return needsAdjust ? truncated - 1 : truncated;
}
/// Ceiling division for a strictly positive divisor: the quotient is rounded
/// toward positive infinity (e.g. 7 ceildiv 2 == 4).
/// Returns failure() when `rhs` is zero or negative.
static FailureOr<int64_t> ceilDivSigned(int64_t lhs, int64_t rhs) {
  if (!(rhs > 0))
    return failure();
  const int64_t truncated = lhs / rhs;
  // Truncation already rounds negative quotients up; only a positive dividend
  // that does not divide evenly must be bumped by one.
  const bool needsAdjust = lhs > 0 && lhs % rhs != 0;
  return needsAdjust ? truncated + 1 : truncated;
}
/// Materializes `map(operands)` as an index value.
///
/// If every operand is recognized by matchConstantIndexValue and the
/// single-result map constant-folds to an IntegerAttr, the folded value is
/// returned via getOrCreateIndexConstant(rewriter, constantAnchor, ...).
/// In every other case an `affine.apply` op is created at `loc`.
Value createOrFoldAffineApply(
    RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* constantAnchor) {
  assert(constantAnchor && "expected a valid constant anchor");
  assert(map.getNumResults() == 1 && "affine.apply expects a single-result affine map");

  auto buildApply = [&]() -> Value { return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); };

  SmallVector<Attribute> constantOperands;
  constantOperands.reserve(operands.size());
  for (Value operand : operands) {
    std::optional<int64_t> known = matchConstantIndexValue(operand);
    // A single non-constant operand rules out folding entirely.
    if (!known)
      return buildApply();
    constantOperands.push_back(rewriter.getIndexAttr(*known));
  }

  SmallVector<Attribute> folded;
  if (failed(map.constantFold(constantOperands, folded)) || folded.size() != 1)
    return buildApply();
  auto foldedInt = dyn_cast<IntegerAttr>(folded.front());
  if (!foldedInt)
    return buildApply();
  return getOrCreateIndexConstant(rewriter, constantAnchor, foldedInt.getInt());
}
/// Convenience overload: wraps `expr` in a symbol-free map with one dimension
/// per value in `dims` and forwards to the AffineMap overload.
Value createOrFoldAffineApply(
    RewriterBase& rewriter, Location loc, AffineExpr expr, ValueRange dims, Operation* constantAnchor) {
  const unsigned numDims = dims.size();
  return createOrFoldAffineApply(
      rewriter, loc, AffineMap::get(numDims, /*symbolCount=*/0, expr), dims, constantAnchor);
}
/// Returns `value * multiplier` as an index value.
/// A multiplier of 0 yields the index constant 0, and a multiplier of 1
/// returns `value` itself. Every other multiplier goes through
/// createOrFoldAffineApply with the expression `d0 * multiplier`.
Value affineMulConst(RewriterBase& rewriter, Location loc, Value value, int64_t multiplier, Operation* constantAnchor) {
  assert(constantAnchor && "expected a valid constant anchor");
  switch (multiplier) {
  case 0: return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
  case 1: return value;
  default: break;
  }
  AffineExpr scaled = getAffineDimExpr(0, rewriter.getContext()) * multiplier;
  return createOrFoldAffineApply(rewriter, loc, scaled, ValueRange {value}, constantAnchor);
}
/// Returns `value mod divisor` as an index value, built from the affine
/// expression `d0 % divisor`. `divisor` must be positive (asserted).
/// A divisor of 1 short-circuits to the index constant 0.
Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
  assert(constantAnchor && "expected a valid constant anchor");
  assert(divisor > 0 && "expected a positive affine.mod divisor");
  // Anything modulo 1 is 0, so no affine.apply is needed in that case.
  if (divisor != 1) {
    AffineExpr remainder = getAffineDimExpr(0, rewriter.getContext()) % divisor;
    return createOrFoldAffineApply(rewriter, loc, remainder, ValueRange {value}, constantAnchor);
  }
  return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
}
/// Returns `value floordiv divisor` as an index value, built from the affine
/// expression `d0 floordiv divisor`. `divisor` must be positive (asserted).
/// A divisor of 1 returns `value` unchanged.
Value affineFloorDivConst(
    RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
  assert(constantAnchor && "expected a valid constant anchor");
  assert(divisor > 0 && "expected a positive affine.floor_div divisor");
  // Dividing by 1 is the identity, so only real divisors build an expression.
  if (divisor != 1) {
    AffineExpr quotient = getAffineDimExpr(0, rewriter.getContext()).floorDiv(divisor);
    return createOrFoldAffineApply(rewriter, loc, quotient, ValueRange {value}, constantAnchor);
  }
  return value;
}
/// Evaluates `expr` on concrete integer values.
///
/// `dims[i]` supplies the value of dimension `d<i>` and `symbols[i]` the value
/// of symbol `s<i>`. Returns failure() when:
///   - a dim/symbol position is outside the provided value arrays,
///   - a floordiv/ceildiv/mod right-hand side is zero or negative,
///   - an addition or multiplication overflows int64_t,
///   - the expression kind is not handled.
/// `mod` uses floor semantics: with the (required) positive divisor the result
/// lies in [0, rhs).
FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
  if (auto constant = dyn_cast<AffineConstantExpr>(expr))
    return constant.getValue();
  if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
    unsigned position = dim.getPosition();
    if (position >= dims.size())
      return failure();
    return dims[position];
  }
  if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
    unsigned position = symbol.getPosition();
    if (position >= symbols.size())
      return failure();
    return symbols[position];
  }
  auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binary)
    return failure();

  FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
  if (failed(lhs))
    return failure();
  FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
  if (failed(rhs))
    return failure();

  // Signed overflow is undefined behavior in C++, so arithmetic that can wrap
  // goes through LLVM's checked helpers, and overflow is reported as failure().
  auto fromChecked = [](std::optional<int64_t> value) -> FailureOr<int64_t> {
    if (!value)
      return failure();
    return *value;
  };

  switch (binary.getKind()) {
  case AffineExprKind::Add: return fromChecked(llvm::checkedAdd(*lhs, *rhs));
  case AffineExprKind::Mul: return fromChecked(llvm::checkedMul(*lhs, *rhs));
  case AffineExprKind::FloorDiv: return floorDivSigned(*lhs, *rhs);
  case AffineExprKind::CeilDiv: return ceilDivSigned(*lhs, *rhs);
  case AffineExprKind::Mod: {
    if (*rhs <= 0)
      return failure();
    // `lhs - floorDiv(lhs, rhs) * rhs` overflows when lhs is near INT64_MIN
    // and rhs does not divide it. Shifting the truncated remainder into
    // [0, rhs) gives the same value and cannot overflow.
    int64_t remainder = *lhs % *rhs;
    if (remainder < 0)
      remainder += *rhs;
    return remainder;
  }
  default: return failure();
  }
}
/// Evaluates the single result of `map` for concrete `operands`.
///
/// `operands` holds all dimension values first, followed by all symbol values.
/// Returns failure() in three cases:
///   - the map does not have exactly one result,
///   - the operand count differs from the map's input count,
///   - evaluateAffineExpr fails.
FailureOr<int64_t> evaluateSingleResultAffineMap(AffineMap map, ArrayRef<int64_t> operands) {
  const bool shapeMatches = map.getNumResults() == 1 && operands.size() == map.getNumInputs();
  if (!shapeMatches)
    return failure();
  const unsigned numDims = map.getNumDims();
  // Having checked the input count, the tail after the dims is exactly the
  // symbol values.
  return evaluateAffineExpr(map.getResult(0), operands.take_front(numDims), operands.drop_front(numDims));
}
/// Evaluates `affineApply` to a constant.
///
/// Each map operand is resolved through `resolver`, and the resolved values are
/// passed to evaluateSingleResultAffineMap. Fails as soon as any operand cannot
/// be resolved.
FailureOr<int64_t> evaluateAffineApply(affine::AffineApplyOp affineApply, IndexValueResolver resolver) {
  auto mapOperands = affineApply.getMapOperands();
  SmallVector<int64_t, 4> resolved;
  resolved.reserve(mapOperands.size());
  for (Value operand : mapOperands) {
    FailureOr<int64_t> value = resolver(operand);
    if (failed(value))
      return failure();
    resolved.push_back(*value);
  }
  return evaluateSingleResultAffineMap(affineApply.getAffineMap(), resolved);
}
/// True when `map` has no symbols and produces exactly one result.
bool isSingleResultSymbolFreeAffineMap(AffineMap map) {
  if (map.getNumSymbols() != 0)
    return false;
  return map.getNumResults() == 1;
}
/// Returns true if `expr` is built only from dimensions and constants, with
/// two further restrictions:
///   - every multiplication has a constant on at least one side,
///   - every floordiv/ceildiv/mod has a constant right-hand side.
/// Any symbol anywhere in the expression makes the result false.
bool isDimAndConstantAffineExpr(AffineExpr expr) {
  // Leaf kinds.
  if (isa<AffineConstantExpr, AffineDimExpr>(expr))
    return true;
  if (isa<AffineSymbolExpr>(expr))
    return false;

  // All remaining kinds are binary operations.
  auto binary = cast<AffineBinaryOpExpr>(expr);
  AffineExpr lhs = binary.getLHS();
  AffineExpr rhs = binary.getRHS();
  switch (expr.getKind()) {
  case AffineExprKind::Add: return isDimAndConstantAffineExpr(lhs) && isDimAndConstantAffineExpr(rhs);
  case AffineExprKind::Mul:
    if (isa<AffineConstantExpr>(lhs) && isDimAndConstantAffineExpr(rhs))
      return true;
    return isa<AffineConstantExpr>(rhs) && isDimAndConstantAffineExpr(lhs);
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv:
  case AffineExprKind::Mod: return isa<AffineConstantExpr>(rhs) && isDimAndConstantAffineExpr(lhs);
  default: break;
  }
  llvm_unreachable("unexpected affine expression kind");
}
} // namespace onnx_mlir