8bb0babf1b
Validate Operations / validate-operations (push) Has been cancelled
use uniqued constant helpers everywhere materialize transposed constants directly
93 lines
3.4 KiB
C++
93 lines
3.4 KiB
C++
#include "IndexingUtils.hpp"
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
|
|
#include "llvm/ADT/APInt.h"
|
|
|
|
#include <algorithm>
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
|
|
|
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
|
|
int64_t normalizedAxis = normalizeAxis(axis, rank);
|
|
if (normalizedAxis < 0 || normalizedAxis >= rank)
|
|
return failure();
|
|
return normalizedAxis;
|
|
}
|
|
|
|
int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
|
|
|
|
static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
|
SmallVector<int64_t> normalizedAxes;
|
|
if (!axesAttr) {
|
|
normalizedAxes.reserve(rank);
|
|
for (int64_t axis = 0; axis < rank; ++axis)
|
|
normalizedAxes.push_back(axis);
|
|
}
|
|
else {
|
|
normalizedAxes.reserve(axesAttr->size());
|
|
for (Attribute attr : *axesAttr)
|
|
normalizedAxes.push_back(normalizeAxis(cast<IntegerAttr>(attr).getInt(), rank));
|
|
llvm::sort(normalizedAxes);
|
|
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
|
|
}
|
|
return normalizedAxes;
|
|
}
|
|
|
|
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
|
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
|
|
for (int64_t axis : normalizedAxes)
|
|
if (axis < 0 || axis >= rank)
|
|
return failure();
|
|
return normalizedAxes;
|
|
}
|
|
|
|
Value createAffineApplyOrFoldedConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
|
|
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
|
|
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
|
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
|
|
}
|
|
|
|
Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) {
|
|
if (multiplier == 0)
|
|
return getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
|
if (multiplier == 1)
|
|
return value;
|
|
|
|
MLIRContext* context = rewriter.getContext();
|
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
|
return createAffineApplyOrFoldedConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
|
|
}
|
|
|
|
Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
|
|
if (divisor == 1)
|
|
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
|
|
|
MLIRContext* context = rewriter.getContext();
|
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
|
return createAffineApplyOrFoldedConstant(rewriter, loc, d0 % divisor, ValueRange {value});
|
|
}
|
|
|
|
Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
|
|
if (divisor == 1)
|
|
return value;
|
|
|
|
MLIRContext* context = rewriter.getContext();
|
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
|
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
|
|
}
|
|
|
|
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) {
|
|
if (auto attr = dyn_cast<Attribute>(value))
|
|
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
|
|
return cast<Value>(value);
|
|
}
|
|
|
|
} // namespace onnx_mlir
|