#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"

#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"

#include <cassert>

namespace onnx_mlir {

// NOTE(review): every shape/offset/size/stride container in this file uses
// int64_t elements. The exact SmallVector/ArrayRef spellings must stay in sync
// with the declarations in ShapeUtils.hpp -- confirm when touching signatures.

/// Computes the row-major strides of `shape`, expressed in elements.
///
/// The innermost dimension gets stride 1 and each outer dimension gets the
/// product of all extents that are inner to it. An empty shape yields an empty
/// result.
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
  llvm::SmallVector<int64_t> strides(shape.size(), 1);
  // Start at the second-to-last dimension; the signed cast keeps the loop
  // well-defined for rank 0 and rank 1 (the loop body is simply skipped).
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
    strides[dim] = strides[dim + 1] * shape[dim + 1];
  return strides;
}

/// Splits `linearIndex` into one index per entry of `strides`.
///
/// For each stride, in order, the quotient becomes the index of that dimension
/// and the remainder is carried to the next one. The result has `shape.size()`
/// entries, zero-initialised; `shape` is only used for that sizing.
///
/// Preconditions (checked with assert): `strides` is not longer than `shape`,
/// and no stride is zero.
llvm::SmallVector<int64_t> delinearizeIndex(
    int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
  // `indices` is sized by `shape` but indexed by the position in `strides`.
  assert(strides.size() <= shape.size() && "more strides than dimensions in shape");

  llvm::SmallVector<int64_t> indices(shape.size(), 0);
  for (auto [dim, stride] : llvm::enumerate(strides)) {
    assert(stride != 0 && "cannot delinearize with a zero stride");
    indices[dim] = linearIndex / stride;
    linearIndex %= stride;
  }
  return indices;
}

/// Returns the sum of `indices[i] * strides[i]` over all dimensions.
///
/// `indices` and `strides` must have the same length (enforced by
/// llvm::zip_equal in assert-enabled builds).
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
  int64_t linearIndex = 0;
  for (auto [index, stride] : llvm::zip_equal(indices, strides))
    linearIndex += index * stride;
  return linearIndex;
}

/// Returns the product of all extents in `shape`; 1 for an empty shape.
///
/// NOTE(review): no special handling of dynamic-dimension sentinels or of
/// overflow -- presumably callers only pass fully static shapes; confirm.
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
  int64_t numElements = 1;
  for (int64_t dim : shape)
    numElements *= dim;
  return numElements;
}

/// Returns true when `elementType` occupies a whole number of bytes.
///
/// Index types always qualify. Integer and float types qualify when their bit
/// width is positive and a multiple of 8. Every other type yields false.
bool hasByteSizedElementType(mlir::Type elementType) {
  if (mlir::isa<mlir::IndexType>(elementType))
    return true;
  if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
    return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
  if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
    return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
  return false;
}

/// Returns the storage size of `elementType` in bytes.
///
/// Index types use `mlir::IndexType::kInternalStorageBitWidth`; integer and
/// float types use their bit width divided by 8. Any other type is a
/// programming error and hits llvm_unreachable.
///
/// Precondition (checked with assert): `hasByteSizedElementType(elementType)`.
/// Without it, sub-byte types such as i1 would silently report a size of 0.
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
  assert(hasByteSizedElementType(elementType) &&
         "expected byte-sized integer, float, or index element type");

  if (mlir::isa<mlir::IndexType>(elementType))
    return mlir::IndexType::kInternalStorageBitWidth / 8;
  if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
    return static_cast<size_t>(intType.getWidth() / 8);
  if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
    return static_cast<size_t>(floatType.getWidth() / 8);
  llvm_unreachable("expected byte-sized integer, float, or index element type");
}

/// Returns the number of elements of `shapedType` multiplied by the byte size
/// of its element type.
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
  return static_cast<size_t>(shapedType.getNumElements()) *
         getElementTypeSizeInBytes(shapedType.getElementType());
}

/// Decides whether the slice described by `offsets`/`sizes`/`strides` of a
/// row-major buffer with extents `srcShape` is accepted as one contiguous run.
///
/// The slice is rejected when:
///   * any stride differs from 1;
///   * at the innermost dimension with a nonzero offset, the slice does not
///     fit inside the dimension (`size > dimension - offset`);
///   * any dimension has a size other than 1 while some dimension inner to it
///     has a nonzero offset or a size different from the full extent.
/// Everything else (including rank 0) is accepted.
///
/// `offsets`, `sizes` and `srcShape` must have the same length (enforced by
/// llvm::zip_equal in assert-enabled builds).
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape, llvm::ArrayRef<int64_t> offsets,
    llvm::ArrayRef<int64_t> sizes, llvm::ArrayRef<int64_t> strides) {
  if (llvm::any_of(strides, [](int64_t stride) { return stride != 1; }))
    return false;

  // Walk the dimensions from innermost to outermost in a single pass.
  bool seenNonZeroOffset = false;
  // Set once an inner dimension is only partially covered; from then on every
  // outer dimension may contribute a single index only.
  bool outerSizesMustBeOne = false;
  for (auto [offset, size, dimension] : llvm::zip_equal(
           llvm::reverse(offsets), llvm::reverse(sizes), llvm::reverse(srcShape))) {
    if (outerSizesMustBeOne && size != 1)
      return false;

    // Only the innermost nonzero offset is range-checked.
    if (!seenNonZeroOffset && offset != 0) {
      if (size > dimension - offset)
        return false;
      seenNonZeroOffset = true;
      outerSizesMustBeOne = true;
    }

    if (size != dimension)
      outerSizesMustBeOne = true;
  }
  return true;
}

} // namespace onnx_mlir