#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/MathExtras.h"

#include "AffineUtils.hpp"
#include "ConstantUtils.hpp"

using namespace mlir;

namespace onnx_mlir {

/// Signed integer division rounding toward negative infinity.
/// Fails when `rhs` is not strictly positive.
static FailureOr<int64_t> floorDivSigned(int64_t lhs, int64_t rhs) {
  if (rhs <= 0)
    return failure();
  int64_t quotient = lhs / rhs;
  int64_t remainder = lhs % rhs;
  // C++ `/` truncates toward zero; step down for inexact negative dividends.
  // Cannot overflow: with rhs >= 2 the truncated quotient is far from
  // INT64_MIN, and with rhs == 1 the remainder is always zero.
  if (remainder != 0 && lhs < 0)
    --quotient;
  return quotient;
}

/// Signed integer division rounding toward positive infinity.
/// Fails when `rhs` is not strictly positive.
static FailureOr<int64_t> ceilDivSigned(int64_t lhs, int64_t rhs) {
  if (rhs <= 0)
    return failure();
  int64_t quotient = lhs / rhs;
  int64_t remainder = lhs % rhs;
  // C++ `/` truncates toward zero; step up for inexact positive dividends.
  if (remainder != 0 && lhs > 0)
    ++quotient;
  return quotient;
}

/// Remainder of floor division, i.e. a value in [0, rhs) for positive `rhs`.
/// Fails when `rhs` is not strictly positive.
static FailureOr<int64_t> modSigned(int64_t lhs, int64_t rhs) {
  if (rhs <= 0)
    return failure();
  // Normalize the truncating remainder rather than computing
  // `lhs - floorDiv(lhs, rhs) * rhs`: that product overflows int64_t for
  // dividends close to INT64_MIN, whereas `remainder + rhs` stays in (0, rhs).
  int64_t remainder = lhs % rhs;
  return remainder < 0 ? remainder + rhs : remainder;
}

/// Converts the result of an overflow-checked LLVM arithmetic helper into a
/// FailureOr, reporting failure when the operation overflowed.
static FailureOr<int64_t> failureIfOverflow(std::optional<int64_t> value) {
  if (!value)
    return failure();
  return *value;
}

/// Builds the single-result affine map `map` applied to `operands`.
///
/// When every operand is recognized as a constant index by
/// matchConstantIndexValue and the map constant-folds to an integer, the
/// result is an index constant obtained from getOrCreateIndexConstant with
/// `constantAnchor`; otherwise an affine.apply op is created at `loc`.
///
/// @param constantAnchor must be non-null; forwarded to
///        getOrCreateIndexConstant when the result folds.
Value createOrFoldAffineApply(RewriterBase& rewriter, Location loc,
    AffineMap map, ValueRange operands, Operation* constantAnchor) {
  assert(constantAnchor && "expected a valid constant anchor");
  assert(map.getNumResults() == 1 &&
         "affine.apply expects a single-result affine map");
  assert(operands.size() == map.getNumInputs() &&
         "operand count must match the affine map's dims + symbols");

  SmallVector<Attribute> operandConstants;
  operandConstants.reserve(operands.size());
  for (Value operand : operands) {
    std::optional<int64_t> constantValue = matchConstantIndexValue(operand);
    // A single non-constant operand rules out folding.
    if (!constantValue)
      return affine::AffineApplyOp::create(rewriter, loc, map, operands)
          .getResult();
    operandConstants.push_back(rewriter.getIndexAttr(*constantValue));
  }

  SmallVector<Attribute> foldedResults;
  if (succeeded(map.constantFold(operandConstants, foldedResults)) &&
      foldedResults.size() == 1)
    if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
      return getOrCreateIndexConstant(
          rewriter, constantAnchor, constantResult.getInt());

  // Folding failed or did not produce an integer: keep the computation as IR.
  return affine::AffineApplyOp::create(rewriter, loc, map, operands)
      .getResult();
}

/// Convenience overload: wraps `expr` into a map with `dims.size()` dims and
/// no symbols, then defers to the AffineMap overload.
Value createOrFoldAffineApply(RewriterBase& rewriter, Location loc,
    AffineExpr expr, ValueRange dims, Operation* constantAnchor) {
  AffineMap map =
      AffineMap::get(/*dimCount=*/dims.size(), /*symbolCount=*/0, expr);
  return createOrFoldAffineApply(rewriter, loc, map, dims, constantAnchor);
}

/// Returns `value * multiplier` as an index value. A multiplier of 0 yields
/// the constant 0, a multiplier of 1 returns `value` unchanged, and anything
/// else goes through createOrFoldAffineApply with `d0 * multiplier`.
Value affineMulConst(RewriterBase& rewriter, Location loc, Value value,
    int64_t multiplier, Operation* constantAnchor) {
  assert(constantAnchor && "expected a valid constant anchor");
  if (multiplier == 0)
    return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
  if (multiplier == 1)
    return value;
  AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
  return createOrFoldAffineApply(
      rewriter, loc, d0 * multiplier, ValueRange{value}, constantAnchor);
}

/// Returns `value mod divisor` (affine `mod`) as an index value.
/// `divisor` must be positive; a divisor of 1 yields the constant 0.
Value affineModConst(RewriterBase& rewriter, Location loc, Value value,
    int64_t divisor, Operation* constantAnchor) {
  assert(constantAnchor && "expected a valid constant anchor");
  assert(divisor > 0 && "expected a positive affine.mod divisor");
  if (divisor == 1)
    return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
  AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
  return createOrFoldAffineApply(
      rewriter, loc, d0 % divisor, ValueRange{value}, constantAnchor);
}

/// Returns `value floordiv divisor` as an index value.
/// `divisor` must be positive; a divisor of 1 returns `value` unchanged.
Value affineFloorDivConst(RewriterBase& rewriter, Location loc, Value value,
    int64_t divisor, Operation* constantAnchor) {
  assert(constantAnchor && "expected a valid constant anchor");
  assert(divisor > 0 && "expected a positive affine.floor_div divisor");
  if (divisor == 1)
    return value;
  AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
  return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor),
      ValueRange{value}, constantAnchor);
}

/// Evaluates `expr` with concrete values for its dims and symbols.
///
/// Fails when a dim/symbol position is out of range, when a floordiv,
/// ceildiv or mod has a non-positive divisor, when an addition or
/// multiplication overflows int64_t, or for unhandled expression kinds.
FailureOr<int64_t> evaluateAffineExpr(
    AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
  if (auto constant = dyn_cast<AffineConstantExpr>(expr))
    return constant.getValue();
  if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
    unsigned position = dim.getPosition();
    if (position >= dims.size())
      return failure();
    return dims[position];
  }
  if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
    unsigned position = symbol.getPosition();
    if (position >= symbols.size())
      return failure();
    return symbols[position];
  }

  auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binary)
    return failure();
  FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
  if (failed(lhs))
    return failure();
  FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
  if (failed(rhs))
    return failure();

  switch (binary.getKind()) {
  case AffineExprKind::Add:
    // Signed overflow is UB in C++; report it as an evaluation failure.
    return failureIfOverflow(llvm::checkedAdd(*lhs, *rhs));
  case AffineExprKind::Mul:
    return failureIfOverflow(llvm::checkedMul(*lhs, *rhs));
  case AffineExprKind::FloorDiv:
    return floorDivSigned(*lhs, *rhs);
  case AffineExprKind::CeilDiv:
    return ceilDivSigned(*lhs, *rhs);
  case AffineExprKind::Mod:
    return modSigned(*lhs, *rhs);
  default:
    return failure();
  }
}

/// Evaluates the single result of `map`. `operands` holds the dim values
/// followed by the symbol values and must match map.getNumInputs().
FailureOr<int64_t> evaluateSingleResultAffineMap(
    AffineMap map, ArrayRef<int64_t> operands) {
  if (map.getNumResults() != 1 || operands.size() != map.getNumInputs())
    return failure();
  ArrayRef<int64_t> dims = operands.take_front(map.getNumDims());
  ArrayRef<int64_t> symbols = operands.drop_front(map.getNumDims());
  return evaluateAffineExpr(map.getResult(0), dims, symbols);
}

/// Evaluates `affineApply` by resolving each map operand to an integer with
/// `resolver`; fails if any operand cannot be resolved or evaluation fails.
FailureOr<int64_t> evaluateAffineApply(
    affine::AffineApplyOp affineApply, IndexValueResolver resolver) {
  SmallVector<int64_t> operands;
  operands.reserve(affineApply.getMapOperands().size());
  for (Value operand : affineApply.getMapOperands()) {
    FailureOr<int64_t> folded = resolver(operand);
    if (failed(folded))
      return failure();
    operands.push_back(*folded);
  }
  return evaluateSingleResultAffineMap(affineApply.getAffineMap(), operands);
}

/// True when `map` has exactly one result and no symbols.
bool isSingleResultSymbolFreeAffineMap(AffineMap map) {
  return map.getNumResults() == 1 && map.getNumSymbols() == 0;
}

/// True when `expr` is built only from dims and constants, where every
/// multiplication has at least one constant operand and every floordiv,
/// ceildiv and mod has a constant right-hand side. Symbols are rejected.
bool isDimAndConstantAffineExpr(AffineExpr expr) {
  switch (expr.getKind()) {
  case AffineExprKind::Constant:
  case AffineExprKind::DimId:
    return true;
  case AffineExprKind::SymbolId:
    return false;
  case AffineExprKind::Add: {
    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
    return isDimAndConstantAffineExpr(binaryExpr.getLHS()) &&
           isDimAndConstantAffineExpr(binaryExpr.getRHS());
  }
  case AffineExprKind::Mul: {
    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
    return (isa<AffineConstantExpr>(binaryExpr.getLHS()) &&
               isDimAndConstantAffineExpr(binaryExpr.getRHS())) ||
           (isa<AffineConstantExpr>(binaryExpr.getRHS()) &&
               isDimAndConstantAffineExpr(binaryExpr.getLHS()));
  }
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv:
  case AffineExprKind::Mod: {
    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
    return isa<AffineConstantExpr>(binaryExpr.getRHS()) &&
           isDimAndConstantAffineExpr(binaryExpr.getLHS());
  }
  }
  llvm_unreachable("unexpected affine expression kind");
}

} // namespace onnx_mlir