This commit is contained in:
@@ -3,9 +3,6 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
@@ -15,73 +12,6 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(RankedTensorType type) {
|
||||
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
|
||||
}
|
||||
|
||||
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> permutedShape;
|
||||
permutedShape.reserve(permutation.size());
|
||||
for (int64_t axis : permutation)
|
||||
permutedShape.push_back(shape[axis]);
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> inversePermutation(permutation.size());
|
||||
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
|
||||
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
|
||||
return inversePermutation;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
|
||||
SmallVector<int64_t> permutation;
|
||||
if (!permAttr) {
|
||||
permutation.reserve(rank);
|
||||
for (int64_t dim = rank - 1; dim >= 0; --dim)
|
||||
permutation.push_back(dim);
|
||||
return permutation;
|
||||
}
|
||||
|
||||
if (static_cast<int64_t>(permAttr->size()) != rank)
|
||||
return failure();
|
||||
|
||||
permutation.reserve(permAttr->size());
|
||||
SmallVector<bool> seen(rank, false);
|
||||
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
|
||||
int64_t axis = attr.getInt();
|
||||
if (axis < 0 || axis >= rank || seen[axis])
|
||||
return failure();
|
||||
seen[axis] = true;
|
||||
permutation.push_back(axis);
|
||||
}
|
||||
return permutation;
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(shape.size());
|
||||
for (int64_t dim : shape)
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
return sizes;
|
||||
}
|
||||
|
||||
SmallVector<Value> sliceTensor(
|
||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||
|
||||
Reference in New Issue
Block a user