Files
Raptor/src/PIM/Common/IR/ShapeUtils.hpp
T
2026-05-22 21:52:28 +02:00

34 lines
977 B
C++

#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool hasByteSizedElementType(mlir::Type elementType);
size_t getElementTypeSizeInBytes(mlir::Type elementType);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
} // namespace onnx_mlir