#include "llvm/ADT/STLExtras.h"

#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"

#include <algorithm>
#include <cassert>
#include <cstdint>

namespace onnx_mlir {

/// Computes the row-major (C-order) strides of `shape`, in elements.
///
/// The innermost stride is 1 and each outer stride is the product of all
/// extents inside it, e.g. shape [2, 3, 4] yields strides [12, 4, 1].
/// An empty shape yields an empty stride vector.
llvm::SmallVector<int64_t> computeRowMajorStrides(
    llvm::ArrayRef<int64_t> shape) {
  llvm::SmallVector<int64_t> strides(shape.size(), 1);
  // Signed loop index so that rank-0 and rank-1 shapes simply skip the loop
  // instead of underflowing an unsigned counter.
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
    strides[dim] = strides[dim + 1] * shape[dim + 1];
  return strides;
}

/// Converts the flat element index `linearIndex` into one coordinate per
/// dimension.
///
/// `shape` only determines the rank of the result; `strides` supplies the
/// per-dimension strides. Dimensions are processed in order: each coordinate
/// is the remaining index divided by that dimension's stride, and the
/// remainder carries on to the next dimension.
/// NOTE(review): this assumes `strides` are ordered outermost-to-innermost and
/// are all non-zero (e.g. the output of computeRowMajorStrides); a zero stride,
/// which that function produces when `shape` contains a 0 extent, would divide
/// by zero.
llvm::SmallVector<int64_t> delinearizeIndex(int64_t linearIndex,
    llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
  // `indices` is sized from `shape` but written once per stride, so the ranks
  // must agree or the loop below would write past the end of `indices`.
  assert(shape.size() == strides.size() &&
         "delinearizeIndex: shape and strides must have the same rank");
  llvm::SmallVector<int64_t> indices(shape.size(), 0);
  for (auto [dim, stride] : llvm::enumerate(strides)) {
    indices[dim] = linearIndex / stride;
    linearIndex %= stride;
  }
  return indices;
}

/// Converts a multi-dimensional index into a flat element index: the dot
/// product of `indices` and `strides`. Both must have the same length
/// (llvm::zip_equal asserts this).
int64_t linearizeIndex(
    llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
  int64_t linearIndex = 0;
  for (auto [index, stride] : llvm::zip_equal(indices, strides))
    linearIndex += index * stride;
  return linearIndex;
}

/// Returns the product of all extents in `shape`; an empty (rank-0) shape
/// yields 1.
/// NOTE(review): no special handling of dynamic extents or overflow - callers
/// presumably pass fully static shapes; confirm.
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
  int64_t numElements = 1;
  for (int64_t dim : shape)
    numElements *= dim;
  return numElements;
}

/// Returns true when the slice of a row-major buffer of shape `srcShape`
/// described by `offsets`, `sizes` and `strides` (one entry per dimension)
/// covers a single contiguous run of elements.
///
/// Returns false as soon as any of these checks fails:
///  - every stride must be 1;
///  - scanning innermost-first, at the first dimension whose offset is
///    non-zero the slice must fit in what remains of that dimension
///    (size <= extent - offset), and every dimension outside it must have
///    size 1;
///  - at the first dimension (innermost-first) whose size differs from the
///    source extent, every dimension outside it must have size 1.
/// offsets, sizes and srcShape must have the same rank (llvm::zip_equal
/// asserts this).
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
    llvm::ArrayRef<int64_t> offsets, llvm::ArrayRef<int64_t> sizes,
    llvm::ArrayRef<int64_t> strides) {
  // A non-unit stride skips elements, so the slice cannot be contiguous.
  if (std::any_of(strides.begin(), strides.end(),
          [](int64_t stride) -> bool { return stride != 1; }))
    return false;

  // All following scans walk dimensions from innermost to outermost.
  auto offsetsAndSizesAndShape =
      llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
          llvm::make_range(sizes.rbegin(), sizes.rend()),
          llvm::make_range(srcShape.rbegin(), srcShape.rend()));

  // Every dimension inside the first non-zero offset starts at index 0.
  auto firstNonZeroOffset = std::find_if(offsetsAndSizesAndShape.begin(),
      offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
        auto [offset, _size, _dimension] = offsetAndSizeAndShape;
        return offset != 0;
      });
  if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
    auto [offset, size, dimension] = *firstNonZeroOffset;
    // The slice must not run past the end of the offset dimension.
    if (size > dimension - offset)
      return false;
    // Any outer dimension with more than one index would resume at this
    // non-zero offset again, leaving a gap for the skipped prefix.
    ++firstNonZeroOffset;
    if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(),
            [](auto offsetAndSizeAndShape) -> bool {
              auto [_offset, outerSize, _dimension] = offsetAndSizeAndShape;
              return outerSize != 1;
            }))
      return false;
  }

  auto sizesAndShape =
      llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
          llvm::make_range(srcShape.rbegin(), srcShape.rend()));
  // The first partially-covered dimension leaves an uncovered tail; stepping
  // any outer dimension more than once would jump over that tail.
  auto firstDifferentSize = std::find_if(sizesAndShape.begin(),
      sizesAndShape.end(), [](auto sizeAndShape) -> bool {
        auto [size, dimension] = sizeAndShape;
        return size != dimension;
      });
  if (firstDifferentSize != sizesAndShape.end()) {
    ++firstDifferentSize;
    if (std::any_of(firstDifferentSize, sizesAndShape.end(),
            [](auto sizeAndShape) -> bool {
              auto [outerSize, _dimension] = sizeAndShape;
              return outerSize != 1;
            }))
      return false;
  }
  return true;
}

} // namespace onnx_mlir