#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"

#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"

namespace onnx_mlir {

namespace {

/// Returns true when every entry of `sizes` at an index strictly below `dim`
/// (i.e. every dimension outside `dim`) equals 1.
bool outerSizesAreAllOne(llvm::ArrayRef<int64_t> sizes, size_t dim) {
  return llvm::all_of(
      sizes.take_front(dim), [](int64_t size) { return size == 1; });
}

/// Scans from the innermost dimension outwards for the first dimension whose
/// slice size differs from the source extent; every dimension outside that one
/// must then have size 1. Returns true when no dimension differs.
///
/// Both arrays must have the same length; callers check this beforehand.
bool sizesKeepLayoutContiguous(
    llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> sizes) {
  for (size_t dim = shape.size(); dim-- > 0;)
    if (sizes[dim] != shape[dim])
      return outerSizesAreAllOne(sizes, dim);
  return true;
}

} // namespace

/// Computes the row-major strides of `shape`: the innermost stride is 1 and
/// each outer stride is the product of all extents inside it.
llvm::SmallVector<int64_t> computeRowMajorStrides(
    llvm::ArrayRef<int64_t> shape) {
  llvm::SmallVector<int64_t> strides(shape.size(), 1);
  // Start at the second-to-last dimension; the signed counter lets the loop
  // terminate cleanly for rank-0 and rank-1 shapes.
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
    strides[dim] = strides[dim + 1] * shape[dim + 1];
  return strides;
}

/// Splits `linearIndex` into one index per entry of `strides` by repeated
/// division, walking `strides` front to back.
///
/// The result has `shape.size()` entries, so `strides` is expected to have no
/// more entries than `shape`. Every stride must be non-zero.
llvm::SmallVector<int64_t> delinearizeIndex(int64_t linearIndex,
    llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
  llvm::SmallVector<int64_t> indices(shape.size(), 0);
  for (auto [dim, stride] : llvm::enumerate(strides)) {
    indices[dim] = linearIndex / stride;
    linearIndex %= stride;
  }
  return indices;
}

/// Returns the dot product of `indices` and `strides`. Both arrays must have
/// the same length (enforced by `llvm::zip_equal`).
int64_t linearizeIndex(
    llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
  int64_t linearIndex = 0;
  for (auto [index, stride] : llvm::zip_equal(indices, strides))
    linearIndex += index * stride;
  return linearIndex;
}

/// Returns the product of all extents in `shape` (1 for an empty shape).
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
  int64_t numElements = 1;
  for (int64_t dim : shape)
    numElements *= dim;
  return numElements;
}

/// Returns true for `index` types and for integer or float types whose bit
/// width is a positive multiple of 8; false for everything else.
bool hasByteSizedElementType(mlir::Type elementType) {
  if (mlir::isa<mlir::IndexType>(elementType))
    return true;
  if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
    return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
  if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
    return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
  return false;
}

/// Returns the size in bytes of one element of `elementType`.
///
/// The width is divided by 8 with integer division, so a width that is not a
/// multiple of 8 is truncated (a 1-bit integer yields 0); callers that need an
/// exact answer should check `hasByteSizedElementType` first. Types other than
/// integer, float, or index hit `llvm_unreachable`.
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
  if (mlir::isa<mlir::IndexType>(elementType))
    return mlir::IndexType::kInternalStorageBitWidth / 8;
  if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
    return static_cast<size_t>(intType.getWidth() / 8);
  if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
    return static_cast<size_t>(floatType.getWidth() / 8);
  llvm_unreachable(
      "expected byte-sized integer, float, or index element type");
}

/// Returns the number of elements of `shapedType` multiplied by the byte size
/// of its element type.
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
  return static_cast<size_t>(shapedType.getNumElements()) *
         getElementTypeSizeInBytes(shapedType.getElementType());
}

/// Conservative check that the slice of `srcShape` described by `offsets`,
/// `sizes` and `strides` passes this file's contiguity conditions.
///
/// Returns false when:
///  - `offsets` or `sizes` does not have one entry per source dimension;
///  - any stride differs from 1;
///  - the innermost dimension with a non-zero offset has a size that does not
///    fit in the remaining extent, or any dimension outside it has size != 1;
///  - any dimension outside the innermost one whose size differs from the
///    source extent has size != 1.
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
    llvm::ArrayRef<int64_t> offsets, llvm::ArrayRef<int64_t> sizes,
    llvm::ArrayRef<int64_t> strides) {
  // Reject mismatched ranks up front instead of indexing past the end of the
  // shorter array.
  const size_t rank = srcShape.size();
  if (offsets.size() != rank || sizes.size() != rank)
    return false;

  if (llvm::any_of(strides, [](int64_t stride) { return stride != 1; }))
    return false;

  // Only the innermost dimension with a non-zero offset is inspected.
  for (size_t dim = rank; dim-- > 0;) {
    if (offsets[dim] == 0)
      continue;
    if (sizes[dim] > srcShape[dim] - offsets[dim])
      return false;
    if (!outerSizesAreAllOne(sizes, dim))
      return false;
    break;
  }

  return sizesKeepLayoutContiguous(srcShape, sizes);
}

/// Variant of `isMemoryContiguous` whose offsets may be dynamic values.
///
/// An offset that is not an attribute is treated as potentially non-zero and
/// its bounds check is skipped, since its value is not known here. Static
/// offsets are read as integer attributes. Returns false when the four arrays
/// do not all have the same length or any stride differs from 1.
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
    llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
    llvm::ArrayRef<int64_t> staticSizes,
    llvm::ArrayRef<int64_t> staticStrides) {
  const size_t rank = sourceShape.size();
  if (mixedOffsets.size() != rank || staticSizes.size() != rank ||
      staticStrides.size() != rank)
    return false;

  if (llvm::any_of(staticStrides, [](int64_t stride) { return stride != 1; }))
    return false;

  // Only the innermost dimension whose offset is dynamic or a non-zero
  // constant is inspected.
  for (size_t dim = rank; dim-- > 0;) {
    if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffsets[dim])) {
      const int64_t staticOffset =
          mlir::cast<mlir::IntegerAttr>(attr).getInt();
      if (staticOffset == 0)
        continue;
      if (staticSizes[dim] > sourceShape[dim] - staticOffset)
        return false;
    }
    if (!outerSizesAreAllOne(staticSizes, dim))
      return false;
    break;
  }

  return sizesKeepLayoutContiguous(sourceShape, staticSizes);
}

} // namespace onnx_mlir