183 lines
6.7 KiB
C++
183 lines
6.7 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"

#include "AffineUtils.hpp"
#include "ConstantUtils.hpp"

#include <limits>
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Floor division of `lhs` by a strictly positive `rhs`: the quotient is
/// rounded toward negative infinity. Returns failure when `rhs` is not
/// positive.
static FailureOr<int64_t> floorDivSigned(int64_t lhs, int64_t rhs) {
  if (rhs > 0) {
    // C++ `/` truncates toward zero, which is one above the floor whenever a
    // negative dividend leaves a non-zero remainder.
    const bool roundDown = (lhs < 0) && (lhs % rhs != 0);
    return lhs / rhs - (roundDown ? 1 : 0);
  }
  return failure();
}
|
|
|
|
/// Ceiling division of `lhs` by a strictly positive `rhs`: the quotient is
/// rounded toward positive infinity. Returns failure when `rhs` is not
/// positive.
static FailureOr<int64_t> ceilDivSigned(int64_t lhs, int64_t rhs) {
  if (rhs > 0) {
    // Truncation already rounds negative quotients up; only a positive
    // dividend with a leftover remainder needs to be bumped by one.
    const bool roundUp = (lhs > 0) && (lhs % rhs != 0);
    return lhs / rhs + (roundUp ? 1 : 0);
  }
  return failure();
}
|
|
|
|
/// Materializes `map(operands)` as an index value.
///
/// If every operand matches a constant index (via `matchConstantIndexValue`)
/// and the single-result `map` constant-folds to exactly one integer, the
/// result is returned through `getOrCreateIndexConstant` using
/// `constantAnchor`. Otherwise an `affine.apply` op is created at `loc`.
Value createOrFoldAffineApply(
    RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* constantAnchor) {
  assert(constantAnchor && "expected a valid constant anchor");
  assert(map.getNumResults() == 1 && "affine.apply expects a single-result affine map");

  // Non-foldable path: emit an explicit affine.apply.
  auto emitApply = [&]() -> Value { return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); };

  SmallVector<Attribute> constantInputs;
  constantInputs.reserve(operands.size());
  for (Value input : operands) {
    std::optional<int64_t> maybeConstant = matchConstantIndexValue(input);
    if (!maybeConstant.has_value())
      return emitApply();
    constantInputs.push_back(rewriter.getIndexAttr(*maybeConstant));
  }

  SmallVector<Attribute> folded;
  if (failed(map.constantFold(constantInputs, folded)) || folded.size() != 1)
    return emitApply();

  auto foldedInt = dyn_cast<IntegerAttr>(folded.front());
  if (!foldedInt)
    return emitApply();
  return getOrCreateIndexConstant(rewriter, constantAnchor, foldedInt.getInt());
}
|
|
|
|
/// Convenience overload: wraps `expr` in a single-result affine map with
/// `dims.size()` dimensions and no symbols, then defers to the map-based
/// `createOrFoldAffineApply`.
Value createOrFoldAffineApply(
    RewriterBase& rewriter, Location loc, AffineExpr expr, ValueRange dims, Operation* constantAnchor) {
  const unsigned numDims = dims.size();
  AffineMap wrapped = AffineMap::get(numDims, /*symbolCount=*/0, expr);
  return createOrFoldAffineApply(rewriter, loc, wrapped, dims, constantAnchor);
}
|
|
|
|
/// Returns `value * multiplier` as an index value.
/// A multiplier of 0 yields an index constant 0 from `getOrCreateIndexConstant`,
/// and a multiplier of 1 returns `value` unchanged. Any other multiplier goes
/// through `createOrFoldAffineApply`.
Value affineMulConst(RewriterBase& rewriter, Location loc, Value value, int64_t multiplier, Operation* constantAnchor) {
  assert(constantAnchor && "expected a valid constant anchor");
  switch (multiplier) {
  case 0: return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
  case 1: return value;
  default: break;
  }

  AffineExpr scaled = getAffineDimExpr(0, rewriter.getContext()) * multiplier;
  return createOrFoldAffineApply(rewriter, loc, scaled, ValueRange {value}, constantAnchor);
}
|
|
|
|
/// Returns `value mod divisor` (affine `mod`, `divisor` must be positive) as an
/// index value. A divisor of 1 short-circuits to an index constant 0.
Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
  assert(constantAnchor && "expected a valid constant anchor");
  assert(divisor > 0 && "expected a positive affine.mod divisor");
  if (divisor != 1) {
    AffineExpr remainderExpr = getAffineDimExpr(0, rewriter.getContext()) % divisor;
    return createOrFoldAffineApply(rewriter, loc, remainderExpr, ValueRange {value}, constantAnchor);
  }
  // Any value modulo 1 is 0.
  return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
}
|
|
|
|
/// Returns `value floordiv divisor` (`divisor` must be positive) as an index
/// value. A divisor of 1 returns `value` unchanged.
Value affineFloorDivConst(
    RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
  assert(constantAnchor && "expected a valid constant anchor");
  assert(divisor > 0 && "expected a positive affine.floor_div divisor");
  if (divisor != 1) {
    AffineExpr quotientExpr = getAffineDimExpr(0, rewriter.getContext()).floorDiv(divisor);
    return createOrFoldAffineApply(rewriter, loc, quotientExpr, ValueRange {value}, constantAnchor);
  }
  return value;
}
|
|
|
|
/// Adds two signed 64-bit integers. Returns failure, instead of invoking
/// undefined behavior, when the exact sum does not fit in int64_t.
static FailureOr<int64_t> checkedAddSigned(int64_t lhs, int64_t rhs) {
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
  if ((rhs > 0 && lhs > kMax - rhs) || (rhs < 0 && lhs < kMin - rhs))
    return failure();
  return lhs + rhs;
}

/// Multiplies two signed 64-bit integers. Returns failure, instead of invoking
/// undefined behavior, when the exact product does not fit in int64_t.
static FailureOr<int64_t> checkedMulSigned(int64_t lhs, int64_t rhs) {
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
  // Each bound divides by an operand whose sign is known at that point, so
  // the bound computations cannot overflow themselves (kMin is never divided
  // by a negative value).
  const bool overflows = lhs > 0 ? (rhs > 0 ? lhs > kMax / rhs : rhs < kMin / lhs)
                                 : (rhs > 0 ? lhs < kMin / rhs : lhs != 0 && rhs < kMax / lhs);
  if (overflows)
    return failure();
  return lhs * rhs;
}

/// Affine `mod` for a strictly positive `rhs`; the result lies in [0, rhs).
/// The result is derived from the truncating remainder. Computing
/// `lhs - floorDiv(lhs, rhs) * rhs` instead can overflow in the intermediate
/// product (e.g. lhs == INT64_MIN, rhs == 3) even though the result fits.
/// Returns failure when `rhs` is not positive.
static FailureOr<int64_t> modSigned(int64_t lhs, int64_t rhs) {
  if (rhs <= 0)
    return failure();
  int64_t remainder = lhs % rhs;
  return remainder < 0 ? remainder + rhs : remainder;
}

/// Evaluates `expr` on concrete integer values.
///
/// `dims[i]` supplies the value of dimension `i` and `symbols[i]` the value of
/// symbol `i`. `floordiv`, `ceildiv` and `mod` require a strictly positive
/// right-hand side.
///
/// Returns failure in any of these cases:
/// - a dimension or symbol position is out of range;
/// - a divisor is not positive;
/// - an intermediate sum or product would overflow int64_t;
/// - the expression kind is not supported.
FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
  if (auto constant = dyn_cast<AffineConstantExpr>(expr))
    return constant.getValue();
  if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
    unsigned position = dim.getPosition();
    if (position >= dims.size())
      return failure();
    return dims[position];
  }
  if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
    unsigned position = symbol.getPosition();
    if (position >= symbols.size())
      return failure();
    return symbols[position];
  }

  auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binary)
    return failure();

  // Stop as soon as either operand cannot be evaluated.
  FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
  if (failed(lhs))
    return failure();
  FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
  if (failed(rhs))
    return failure();

  switch (binary.getKind()) {
  // Signed overflow is undefined behavior in C++, so sums and products are
  // range-checked and reported as failure rather than wrapping silently.
  case AffineExprKind::Add: return checkedAddSigned(*lhs, *rhs);
  case AffineExprKind::Mul: return checkedMulSigned(*lhs, *rhs);
  case AffineExprKind::FloorDiv: return floorDivSigned(*lhs, *rhs);
  case AffineExprKind::CeilDiv: return ceilDivSigned(*lhs, *rhs);
  case AffineExprKind::Mod: return modSigned(*lhs, *rhs);
  default: return failure();
  }
}
|
|
|
|
/// Evaluates the single result of `map` on concrete `operands`. The operands
/// are laid out as all dimension values followed by all symbol values.
/// Returns failure in any of these cases:
/// - `map` does not have exactly one result;
/// - the operand count differs from `map.getNumInputs()`;
/// - the expression cannot be evaluated.
FailureOr<int64_t> evaluateSingleResultAffineMap(AffineMap map, ArrayRef<int64_t> operands) {
  const bool shapeMatches = map.getNumResults() == 1 && operands.size() == map.getNumInputs();
  if (!shapeMatches)
    return failure();

  // With the size check above, everything after the dimensions is exactly the
  // symbol operands.
  const unsigned numDims = map.getNumDims();
  ArrayRef<int64_t> dimValues = operands.take_front(numDims);
  ArrayRef<int64_t> symbolValues = operands.drop_front(numDims);
  return evaluateAffineExpr(map.getResult(0), dimValues, symbolValues);
}
|
|
|
|
/// Evaluates `affineApply` in two steps:
/// 1. Resolve each map operand to an integer with `resolver`.
/// 2. Evaluate the op's affine map on those integers.
/// Returns failure as soon as an operand cannot be resolved, or when the map
/// evaluation fails.
FailureOr<int64_t> evaluateAffineApply(affine::AffineApplyOp affineApply, IndexValueResolver resolver) {
  auto mapOperands = affineApply.getMapOperands();
  SmallVector<int64_t, 4> resolvedValues;
  resolvedValues.reserve(mapOperands.size());
  for (Value mapOperand : mapOperands) {
    FailureOr<int64_t> resolved = resolver(mapOperand);
    if (!succeeded(resolved))
      return failure();
    resolvedValues.push_back(*resolved);
  }
  return evaluateSingleResultAffineMap(affineApply.getAffineMap(), resolvedValues);
}
|
|
|
|
/// Returns true when `map` has exactly one result and no symbol inputs.
bool isSingleResultSymbolFreeAffineMap(AffineMap map) {
  if (map.getNumSymbols() != 0)
    return false;
  return map.getNumResults() == 1;
}
|
|
|
|
/// Returns true when `expr` is built only from dimensions and constants.
/// The accepted forms are:
/// - sums of accepted sub-expressions;
/// - products in which one side is a literal constant and the other is
///   accepted;
/// - floordiv/ceildiv/mod with a literal constant right-hand side and an
///   accepted left-hand side.
/// Any symbol occurrence makes the result false.
bool isDimAndConstantAffineExpr(AffineExpr expr) {
  const AffineExprKind kind = expr.getKind();
  if (kind == AffineExprKind::Constant || kind == AffineExprKind::DimId)
    return true;
  if (kind == AffineExprKind::SymbolId)
    return false;

  // Every remaining kind is a binary expression.
  auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
  AffineExpr lhs = binaryExpr.getLHS();
  AffineExpr rhs = binaryExpr.getRHS();
  switch (kind) {
  case AffineExprKind::Add: return isDimAndConstantAffineExpr(lhs) && isDimAndConstantAffineExpr(rhs);
  case AffineExprKind::Mul:
    return (isa<AffineConstantExpr>(lhs) && isDimAndConstantAffineExpr(rhs))
        || (isa<AffineConstantExpr>(rhs) && isDimAndConstantAffineExpr(lhs));
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv:
  case AffineExprKind::Mod: return isa<AffineConstantExpr>(rhs) && isDimAndConstantAffineExpr(lhs);
  default: break;
  }

  llvm_unreachable("unexpected affine expression kind");
}
|
|
|
|
} // namespace onnx_mlir
|