90 lines
3.1 KiB
C++
90 lines
3.1 KiB
C++
#include "llvm/ADT/STLExtras.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Computes row-major (C-order) strides for `shape`.
///
/// The last dimension gets stride 1. Each earlier dimension's stride is the
/// product of all extents that follow it. An empty shape yields an empty
/// vector, and a rank-1 shape yields `{1}`.
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
  const size_t rank = shape.size();
  llvm::SmallVector<int64_t> strides(rank, 1);

  // Accumulate the product of trailing extents while walking outward.
  int64_t trailingProduct = 1;
  for (size_t pos = rank; pos > 1; --pos) {
    trailingProduct *= shape[pos - 1];
    strides[pos - 2] = trailingProduct;
  }
  return strides;
}
|
|
|
|
/// Splits a flat `linearIndex` into one index per dimension using `strides`.
///
/// Strides are applied in order: each index is the quotient by its stride,
/// and the remainder carries on to the next dimension. Strides are presumably
/// ordered outermost (largest) first, e.g. as produced by
/// computeRowMajorStrides.
///
/// The result always has `shape.size()` entries; `shape` is otherwise unused.
/// Preconditions (not checked here): every stride is non-zero, since it is
/// used as a divisor, and `strides.size() <= shape.size()`.
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
  llvm::SmallVector<int64_t> indices(shape.size(), 0);

  int64_t remaining = linearIndex;
  for (size_t dim = 0; dim < strides.size(); ++dim) {
    const int64_t stride = strides[dim];
    indices[dim] = remaining / stride;
    remaining %= stride;
  }
  return indices;
}
|
|
|
|
/// Flattens per-dimension `indices` into a single offset.
///
/// Returns the dot product of `indices` and `strides`. The two arrays must
/// have the same length; `llvm::zip_equal` asserts this in builds with
/// assertions enabled.
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
  int64_t flatOffset = 0;
  for (auto &&pair : llvm::zip_equal(indices, strides))
    flatOffset += std::get<0>(pair) * std::get<1>(pair);
  return flatOffset;
}
|
|
|
|
/// Returns the product of all extents in `shape`.
///
/// An empty (rank-0) shape yields 1. There is no special handling of negative
/// or overflowing extents.
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
  int64_t product = 1;
  for (size_t i = 0; i < shape.size(); ++i)
    product *= shape[i];
  return product;
}
|
|
|
|
/// Decides whether a strided window over a row-major source of shape
/// `srcShape` selects a single gap-free run of elements.
///
/// The window is described by `offsets`, `sizes` and `strides`, one entry per
/// dimension. `offsets`, `sizes` and `srcShape` must have the same length;
/// `llvm::zip_equal` asserts this in builds with assertions enabled.
///
/// The function returns false if any of the following holds:
///  - some stride differs from 1;
///  - walking dimensions innermost → outermost, the first dimension with a
///    non-zero offset has `size > srcShape - offset`;
///  - any dimension strictly outside that first non-zero-offset dimension has
///    a size other than 1;
///  - any dimension strictly outside the first dimension whose size differs
///    from its source extent has a size other than 1.
///
/// Otherwise it returns true. An empty (rank-0) window is reported as
/// contiguous.
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
                        llvm::ArrayRef<int64_t> offsets,
                        llvm::ArrayRef<int64_t> sizes,
                        llvm::ArrayRef<int64_t> strides) {
  for (int64_t stride : strides)
    if (stride != 1)
      return false;

  // Each flag is set once the walk has passed the corresponding "special"
  // dimension. From then on, every outer dimension may select only one element.
  bool passedOffsetDim = false;
  bool passedPartialDim = false;

  auto innerToOuter = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
                                      llvm::make_range(sizes.rbegin(), sizes.rend()),
                                      llvm::make_range(srcShape.rbegin(), srcShape.rend()));

  for (auto [offset, size, dimension] : innerToOuter) {
    if ((passedOffsetDim || passedPartialDim) && size != 1)
      return false;

    if (!passedOffsetDim && offset != 0) {
      // Only the innermost offset dimension is bounds-checked.
      if (size > dimension - offset)
        return false;
      passedOffsetDim = true;
    }

    if (!passedPartialDim && size != dimension)
      passedPartialDim = true;
  }

  return true;
}
|
|
|
|
} // namespace onnx_mlir
|