46 lines
1.4 KiB
C++
46 lines
1.4 KiB
C++
#include <algorithm>
|
|
|
|
#include "IndexingUtils.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Maps a possibly negative `axis` onto the non-negative range by adding
/// `rank` to it when it is negative; non-negative axes are returned as-is.
/// No bounds checking is performed — the result may still be outside
/// [0, rank) (normalizeAxisChecked performs that validation).
int64_t normalizeAxis(int64_t axis, int64_t rank) {
  if (axis < 0)
    return axis + rank;
  return axis;
}
|
|
|
|
/// Normalizes `axis` against `rank` (negative axes count from the end) and
/// verifies the result lies in [0, rank).
/// Returns the normalized axis on success, or failure() when it is out of
/// range.
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
  const int64_t candidate = normalizeAxis(axis, rank);
  const bool inRange = candidate >= 0 && candidate < rank;
  if (!inRange)
    return failure();
  return candidate;
}
|
|
|
|
/// Maps a possibly negative `index` onto the non-negative range by adding
/// `dimSize` to it when it is negative; non-negative indices are returned
/// unchanged. No bounds checking is performed.
int64_t normalizeIndex(int64_t index, int64_t dimSize) {
  if (index < 0)
    return index + dimSize;
  return index;
}
|
|
|
|
/// Builds the sorted, duplicate-free list of axes described by `axesAttr`.
///
/// - If `axesAttr` is absent, every axis 0..rank-1 is returned in order
///   (an empty list when `rank` <= 0).
/// - Otherwise each element is read as an IntegerAttr (cast<> asserts on a
///   mismatch in assertion-enabled builds), normalized with normalizeAxis,
///   and the resulting list is sorted and deduplicated.
///
/// Out-of-range axes are NOT rejected here; callers needing validation must
/// check the result themselves (normalizeAxesChecked does this).
static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
  SmallVector<int64_t> normalizedAxes;

  if (!axesAttr) {
    // reserve() takes an unsigned size, so a negative rank would wrap to a
    // huge request and abort; clamp it (the loop below adds nothing then).
    normalizedAxes.reserve(std::max<int64_t>(rank, 0));
    for (int64_t axis = 0; axis < rank; ++axis)
      normalizedAxes.push_back(axis);
    return normalizedAxes;
  }

  normalizedAxes.reserve(axesAttr->size());
  for (Attribute attr : *axesAttr)
    normalizedAxes.push_back(normalizeAxis(cast<IntegerAttr>(attr).getInt(), rank));

  // Sort, then drop adjacent duplicates so aliases such as {-1, rank-1}
  // collapse to a single axis.
  llvm::sort(normalizedAxes);
  normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
  return normalizedAxes;
}
|
|
|
|
/// Normalizes and validates the axes described by `axesAttr` against `rank`.
///
/// Returns the sorted, duplicate-free list of normalized axes (all axes
/// 0..rank-1 when `axesAttr` is absent), or failure() if any element of
/// `axesAttr` is not an IntegerAttr or any normalized axis falls outside
/// [0, rank).
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
  // Reject non-integer elements up front: normalizeAxesImpl reads elements
  // with cast<>, which asserts instead of reporting a recoverable failure.
  if (axesAttr) {
    for (Attribute attr : *axesAttr)
      if (!isa<IntegerAttr>(attr))
        return failure();
  }

  SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
  for (int64_t axis : normalizedAxes)
    if (axis < 0 || axis >= rank)
      return failure();
  return normalizedAxes;
}
|
|
|
|
} // namespace onnx_mlir
|