ff36729140
fix codegen symlinks overwrite remove deprecated pim memcp_hd_batch op
167 lines
6.2 KiB
C++
167 lines
6.2 KiB
C++
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"

#include <cassert>

#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Computes row-major (C-order) strides, in elements, for `shape`.
///
/// The innermost dimension gets stride 1, and every other dimension gets the stride of the next-inner
/// dimension multiplied by that inner dimension's extent. The outermost extent never participates in the
/// product. An empty shape produces an empty stride vector.
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
  const size_t rank = shape.size();
  llvm::SmallVector<int64_t> result(rank, 1);

  // Walk from the second-innermost dimension outward; `pos` is one past the dimension being filled.
  for (size_t pos = rank; pos > 1; --pos)
    result[pos - 2] = result[pos - 1] * shape[pos - 1];

  return result;
}
|
|
|
|
/// Decomposes `linearIndex` into one index per dimension using `strides` (typically the output of
/// computeRowMajorStrides for the same shape).
///
/// For each dimension, from first to last, the index is `linearIndex / stride` and the remainder is carried
/// into the next dimension.
///
/// @param linearIndex Flat element index to decompose.
/// @param shape       Only its rank is used; it sizes the result and must match `strides`.
/// @param strides     Per-dimension strides; every stride must be non-zero.
/// @return            Per-dimension indices, one entry per dimension of `shape`.
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
  // The result is sized from `shape` but filled by iterating `strides`. A longer `strides` would write past
  // the end of `indices`, so state the precondition instead of relying on it silently.
  assert(shape.size() == strides.size() && "shape and strides must have the same rank");

  llvm::SmallVector<int64_t> indices(shape.size(), 0);
  for (auto [dim, stride] : llvm::enumerate(strides)) {
    // A zero stride (e.g. a row-major stride outside a zero-sized dimension) would divide by zero.
    assert(stride != 0 && "cannot delinearize with a zero stride");
    indices[dim] = linearIndex / stride;
    linearIndex %= stride;
  }
  return indices;
}
|
|
|
|
/// Flattens per-dimension `indices` into a single element offset: the sum of `indices[i] * strides[i]`.
/// `indices` and `strides` are expected to have the same length (checked by llvm::zip_equal in
/// assertion-enabled builds).
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
  int64_t flatOffset = 0;
  for (const auto &[position, step] : llvm::zip_equal(indices, strides)) {
    const int64_t contribution = position * step;
    flatOffset += contribution;
  }
  return flatOffset;
}
|
|
|
|
/// Returns the product of all extents in `shape`. An empty shape (a scalar) yields 1.
/// Extents are multiplied as-is, with no special handling for any value.
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
  int64_t product = 1;
  for (size_t axis = 0, rank = shape.size(); axis < rank; ++axis)
    product *= shape[axis];
  return product;
}
|
|
|
|
/// Returns true if `elementType` occupies a whole, non-zero number of bytes.
///
/// `index` is always accepted. Integer and float types are accepted when their bit width is a positive
/// multiple of 8. Every other type is rejected.
bool hasByteSizedElementType(mlir::Type elementType) {
  if (mlir::isa<mlir::IndexType>(elementType))
    return true;

  unsigned bitWidth = 0;
  if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
    bitWidth = intType.getWidth();
  else if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
    bitWidth = floatType.getWidth();
  else
    return false;

  return bitWidth != 0 && bitWidth % 8 == 0;
}
|
|
|
|
/// Returns the size in bytes of an integer, float, or index element type.
///
/// The size is the bit width divided by 8 using integer division, so widths that are not a multiple of 8
/// are truncated. Callers can check hasByteSizedElementType first. `index` uses
/// IndexType::kInternalStorageBitWidth. Any other type is a programming error.
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
  unsigned bitWidth = 0;
  if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
    bitWidth = intType.getWidth();
  else if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
    bitWidth = floatType.getWidth();
  else if (mlir::isa<mlir::IndexType>(elementType))
    bitWidth = mlir::IndexType::kInternalStorageBitWidth;
  else
    llvm_unreachable("expected byte-sized integer, float, or index element type");

  return static_cast<size_t>(bitWidth / 8);
}
|
|
|
|
/// Returns the total byte size of `shapedType`: its element count times the byte size of its element
/// type (as computed by getElementTypeSizeInBytes).
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
  const size_t elementCount = static_cast<size_t>(shapedType.getNumElements());
  const size_t bytesPerElement = getElementTypeSizeInBytes(shapedType.getElementType());
  return elementCount * bytesPerElement;
}
|
|
|
|
/// Returns true if the subview of a row-major buffer of shape `srcShape`, selected by static `offsets`,
/// `sizes` and `strides`, covers a single contiguous run of elements.
///
/// The subview is contiguous when all of the following hold:
///  * every stride is 1;
///  * walking from the innermost dimension outward, the first dimension with a non-zero offset still fits
///    in the source (`size <= dimension - offset`), and every dimension outside it has size 1;
///  * walking from the innermost dimension outward, every dimension outside the first one whose size
///    differs from the source extent has size 1.
///
/// Returns false when `offsets` or `sizes` do not have the same rank as `srcShape`.
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
                        llvm::ArrayRef<int64_t> offsets,
                        llvm::ArrayRef<int64_t> sizes,
                        llvm::ArrayRef<int64_t> strides) {
  // offsets, sizes and shape are zipped below. llvm::zip_equal only checks lengths in assertion-enabled
  // builds and otherwise stops at the end of the first range, so a shorter `sizes`/`srcShape` would be read
  // out of bounds. Reject malformed input explicitly, as isContiguousSubviewWithDynamicOffsets does.
  // `strides` is not rank-checked here, to keep the original behavior for callers.
  if (offsets.size() != srcShape.size() || sizes.size() != srcShape.size())
    return false;

  if (llvm::any_of(strides, [](int64_t stride) { return stride != 1; }))
    return false;

  // Iterate innermost-first.
  auto offsetsAndSizesAndShape =
      llvm::zip_equal(llvm::reverse(offsets), llvm::reverse(sizes), llvm::reverse(srcShape));

  auto firstNonZeroOffset = llvm::find_if(offsetsAndSizesAndShape, [](auto offsetAndSizeAndShape) {
    return std::get<0>(offsetAndSizeAndShape) != 0;
  });

  if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
    auto [offset, size, dimension] = *firstNonZeroOffset;
    // The shifted dimension must not run past the end of the source dimension.
    if (size > dimension - offset)
      return false;

    // Any dimension outside the shifted one must be a single slice; otherwise the selected rows are
    // separated by the skipped prefix of the shifted dimension.
    ++firstNonZeroOffset;
    for (auto it = firstNonZeroOffset; it != offsetsAndSizesAndShape.end(); ++it)
      if (std::get<1>(*it) != 1)
        return false;
  }

  auto sizesAndShape = llvm::zip_equal(llvm::reverse(sizes), llvm::reverse(srcShape));

  auto firstDifferentSize = llvm::find_if(sizesAndShape, [](auto sizeAndShape) {
    auto [size, dimension] = sizeAndShape;
    return size != dimension;
  });

  if (firstDifferentSize != sizesAndShape.end()) {
    // Once a dimension is only partially selected, every dimension outside it must be a single slice.
    ++firstDifferentSize;
    for (auto it = firstDifferentSize; it != sizesAndShape.end(); ++it)
      if (std::get<0>(*it) != 1)
        return false;
  }

  return true;
}
|
|
|
|
/// Returns true if a subview of a row-major buffer of shape `sourceShape` is one contiguous run of elements.
/// Offsets may be constant (IntegerAttr) or dynamic (Value); sizes and strides are static.
///
/// Rules, applied from the innermost dimension outward:
///  * all ranks must match and every stride must be 1;
///  * at the innermost dimension whose offset is dynamic or a non-zero constant, a constant offset must
///    leave room for the size (`size <= sourceDim - offset`), and every dimension outside it must have size 1;
///  * at the innermost dimension whose size differs from the source extent, every dimension outside it must
///    have size 1.
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
                                           llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
                                           llvm::ArrayRef<int64_t> staticSizes,
                                           llvm::ArrayRef<int64_t> staticStrides) {
  const size_t rank = sourceShape.size();
  if (mixedOffsets.size() != rank || staticSizes.size() != rank || staticStrides.size() != rank)
    return false;

  for (int64_t stride : staticStrides)
    if (stride != 1)
      return false;

  // True when every dimension with an index smaller than `dim` (i.e. outside it) selects a single slice.
  auto outerDimsAreUnit = [&](size_t dim) {
    for (size_t outer = 0; outer < dim; ++outer)
      if (staticSizes[outer] != 1)
        return false;
    return true;
  };

  // Innermost dimension whose offset is dynamic or a non-zero constant.
  for (size_t dim = rank; dim-- > 0;) {
    if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffsets[dim])) {
      const int64_t staticOffset = mlir::cast<mlir::IntegerAttr>(attr).getInt();
      if (staticOffset == 0)
        continue;
      if (staticSizes[dim] > sourceDim(sourceShape, dim) - staticOffset)
        return false;
    }
    if (!outerDimsAreUnit(dim))
      return false;
    break;
  }

  // Innermost dimension that is only partially selected.
  for (size_t dim = rank; dim-- > 0;) {
    if (staticSizes[dim] != sourceShape[dim])
      return outerDimsAreUnit(dim);
  }

  return true;
}
|
|
|
|
} // namespace onnx_mlir
|