This commit is contained in:
@@ -6,20 +6,21 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (auto attr = dyn_cast<Attribute>(result))
|
||||
return arith::ConstantIndexOp::create(rewriter, loc, cast<IntegerAttr>(attr).getInt()).getResult();
|
||||
return cast<Value>(result);
|
||||
return getOrMaterializeIndexValue(rewriter, loc, result);
|
||||
}
|
||||
|
||||
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
@@ -50,6 +51,84 @@ static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatt
|
||||
return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult();
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(RankedTensorType type) { return type.hasStaticShape() && hasStaticPositiveShape(type.getShape()); }
|
||||
|
||||
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
int64_t getStaticShapeElementCount(RankedTensorType type) { return getStaticShapeElementCount(type.getShape()); }
|
||||
|
||||
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> permutedShape;
|
||||
permutedShape.reserve(permutation.size());
|
||||
for (int64_t axis : permutation)
|
||||
permutedShape.push_back(shape[axis]);
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> inversePermutation(permutation.size());
|
||||
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
|
||||
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
|
||||
return inversePermutation;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
|
||||
SmallVector<int64_t> permutation;
|
||||
if (!permAttr) {
|
||||
permutation.reserve(rank);
|
||||
for (int64_t dim = rank - 1; dim >= 0; --dim)
|
||||
permutation.push_back(dim);
|
||||
return permutation;
|
||||
}
|
||||
|
||||
if (static_cast<int64_t>(permAttr->size()) != rank)
|
||||
return failure();
|
||||
|
||||
permutation.reserve(permAttr->size());
|
||||
SmallVector<bool> seen(rank, false);
|
||||
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
|
||||
int64_t axis = attr.getInt();
|
||||
if (axis < 0 || axis >= rank || seen[axis])
|
||||
return failure();
|
||||
seen[axis] = true;
|
||||
permutation.push_back(axis);
|
||||
}
|
||||
return permutation;
|
||||
}
|
||||
|
||||
Value transposeMaybeInCompute(Value value,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
|
||||
};
|
||||
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildTranspose);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(shape.size());
|
||||
for (int64_t dim : shape)
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
return sizes;
|
||||
}
|
||||
|
||||
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank())
|
||||
@@ -88,11 +167,8 @@ SmallVector<Value> sliceTensor(
|
||||
assert("Invalid axis" && axis < shape.size());
|
||||
|
||||
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(shape.size());
|
||||
for (const auto size : shape)
|
||||
sizes.push_back(rewriter.getIndexAttr(size));
|
||||
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, shape.size());
|
||||
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, shape);
|
||||
sizes[axis] = rewriter.getIndexAttr(sliceSize);
|
||||
|
||||
long length = shape[axis];
|
||||
@@ -276,4 +352,43 @@ Value materializeContiguousTensorSlice(Value source,
|
||||
return buildLoopNest(buildLoopNest, 0, init);
|
||||
}
|
||||
|
||||
Value extractStaticSlice(PatternRewriter& rewriter,
|
||||
Location loc,
|
||||
Value source,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<OpFoldResult> offsets) {
|
||||
return tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, resultType, source, offsets, getStaticSizes(rewriter, resultType.getShape()),
|
||||
getUnitStrides(rewriter, resultType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value extractAxisSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
SmallVector<int64_t> resultShape(sourceType.getShape());
|
||||
resultShape[axis] = size;
|
||||
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
|
||||
|
||||
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
|
||||
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(size);
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value insertStaticSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
return tensor::InsertSliceOp::create(rewriter,
|
||||
loc,
|
||||
source,
|
||||
dest,
|
||||
offsets,
|
||||
getStaticSizes(rewriter, sourceType.getShape()),
|
||||
getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user