#include "IndexingUtils.hpp"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "llvm/ADT/APInt.h"

#include <algorithm>
#include <cassert>
#include <optional>

#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"

using namespace mlir;

namespace onnx_mlir {

/// Maps a possibly negative axis onto its non-negative form by adding `rank`
/// to negative values (Python-style indexing). No range check is performed;
/// use normalizeAxisChecked when the input may be out of range.
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }

/// Like normalizeAxis, but returns failure() if the normalized axis does not
/// lie in [0, rank).
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
  int64_t normalizedAxis = normalizeAxis(axis, rank);
  if (normalizedAxis < 0 || normalizedAxis >= rank)
    return failure();
  return normalizedAxis;
}

/// Maps a possibly negative index onto its non-negative form by adding
/// `dimSize` to negative values. No range check is performed.
int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }

/// Builds the sorted, de-duplicated list of normalized axes.
/// - If `axesAttr` is absent, every axis in [0, rank) is returned.
/// - Otherwise each element is normalized with normalizeAxis (no range check
///   here).
/// Returns failure() if any element of `axesAttr` is not an IntegerAttr.
static FailureOr<SmallVector<int64_t>> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
  SmallVector<int64_t> normalizedAxes;
  if (!axesAttr) {
    normalizedAxes.reserve(rank);
    for (int64_t axis = 0; axis < rank; ++axis)
      normalizedAxes.push_back(axis);
    return normalizedAxes;
  }

  normalizedAxes.reserve(axesAttr->size());
  for (Attribute attr : *axesAttr) {
    // Reject malformed attributes gracefully instead of asserting in cast<>.
    auto intAttr = dyn_cast<IntegerAttr>(attr);
    if (!intAttr)
      return failure();
    normalizedAxes.push_back(normalizeAxis(intAttr.getInt(), rank));
  }
  // Sort and drop duplicates, e.g. -1 and rank-1 collapse to the same axis.
  llvm::sort(normalizedAxes);
  normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
  return normalizedAxes;
}

/// Normalizes an optional list of axes against `rank`.
///
/// The result is sorted and de-duplicated. It returns failure() if:
/// - `rank` is negative,
/// - any element of `axesAttr` is not an integer, or
/// - any normalized axis falls outside [0, rank).
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
  // A negative rank would otherwise flow into SmallVector::reserve as a huge
  // unsigned size.
  if (rank < 0)
    return failure();

  FailureOr<SmallVector<int64_t>> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
  if (failed(normalizedAxes))
    return failure();

  for (int64_t axis : *normalizedAxes)
    if (axis < 0 || axis >= rank)
      return failure();
  return normalizedAxes;
}

/// Builds a single-result affine map from `expr` (one dim per operand, no
/// symbols) and applies it to `operands`. The map-based
/// createAffineApplyOrFoldedConstant overload is passed the parent op of the
/// current insertion block as its anchor.
Value createAffineApplyOrFoldedConstant(PatternRewriter& rewriter,
                                        Location loc,
                                        AffineExpr expr,
                                        ValueRange operands) {
  AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
}

/// Returns `value * multiplier` as an index value.
/// - multiplier == 0 yields an index constant 0.
/// - multiplier == 1 returns `value` unchanged.
/// - Otherwise an affine expression `d0 * multiplier` is applied at
///   `anchorOp`'s location.
Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) {
  if (multiplier == 0)
    return getOrCreateIndexConstant(rewriter, anchorOp, 0);
  if (multiplier == 1)
    return value;
  MLIRContext* context = rewriter.getContext();
  AffineExpr d0 = getAffineDimExpr(0, context);
  return createAffineApplyOrFoldedConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
}

/// Returns `value mod divisor` as an index value. `divisor` must be positive,
/// since affine `mod` is only defined for positive constants. A divisor of 1
/// short-circuits to an index constant 0.
Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
  assert(divisor > 0 && "modIndexByConstant requires a positive divisor");
  if (divisor == 1)
    return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
  MLIRContext* context = rewriter.getContext();
  AffineExpr d0 = getAffineDimExpr(0, context);
  return createAffineApplyOrFoldedConstant(rewriter, loc, d0 % divisor, ValueRange {value});
}

/// Returns `floordiv(value, divisor)` as an index value. `divisor` must be
/// positive, since affine `floordiv` is only defined for positive constants.
/// A divisor of 1 returns `value` unchanged.
Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
  assert(divisor > 0 && "floorDivIndexByConstant requires a positive divisor");
  if (divisor == 1)
    return value;
  MLIRContext* context = rewriter.getContext();
  AffineExpr d0 = getAffineDimExpr(0, context);
  return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
}

/// Converts an OpFoldResult into a Value.
/// - Attribute case: the attribute is expected to be an IntegerAttr, and is
///   materialized as an index constant anchored at the parent op of the
///   current insertion block.
/// - Value case: the value is returned as-is.
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) {
  if (auto attr = dyn_cast<Attribute>(value))
    return getOrCreateIndexConstant(
        rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
  return cast<Value>(value);
}

} // namespace onnx_mlir