#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "llvm/ADT/SmallVector.h"

#include <cassert>
#include <functional>
#include <numeric>
#include <optional>

#include "IndexingUtils.hpp"
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

/// Returns true when every dimension in `shape` is strictly positive.
/// Zero-sized and negative (e.g. dynamic-sentinel) dimensions make this false.
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
  return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}

/// Returns true when `type` has a fully static shape whose dimensions are all
/// strictly positive.
bool hasStaticPositiveShape(RankedTensorType type) {
  return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
}

/// Returns the product of all dimensions in `shape` (1 for an empty shape).
/// The caller is expected to pass a static shape; overflow is not checked.
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}

/// Returns `shape` reordered so that result[i] == shape[permutation[i]].
/// Every entry of `permutation` must be a valid index into `shape`.
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
  SmallVector<int64_t> permutedShape;
  permutedShape.reserve(permutation.size());
  for (int64_t axis : permutation) {
    assert(axis >= 0 && static_cast<size_t>(axis) < shape.size() && "permutation axis out of range");
    permutedShape.push_back(shape[axis]);
  }
  return permutedShape;
}

/// Returns the inverse of `permutation`, i.e. result[permutation[i]] == i.
/// `permutation` must be a valid permutation of [0, permutation.size()).
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
  SmallVector<int64_t> inversePermutation(permutation.size());
  for (auto [newIndex, oldIndex] : llvm::enumerate(permutation)) {
    assert(oldIndex >= 0 && static_cast<size_t>(oldIndex) < permutation.size() &&
           "permutation entry out of range");
    inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
  }
  return inversePermutation;
}

/// Builds a validated transpose permutation for a tensor of rank `rank`.
///
/// @param permAttr optional array attribute of axis indices. When absent, the
///        reversed axis order [rank-1, ..., 1, 0] is returned.
/// @param rank the rank of the tensor being transposed.
/// @return the permutation, or failure() if `rank` is negative, the attribute
///         length differs from `rank`, an element is not an IntegerAttr, an
///         axis is out of [0, rank), or an axis is repeated.
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
  if (rank < 0)
    return failure();

  SmallVector<int64_t> permutation;
  permutation.reserve(rank);

  if (!permAttr) {
    for (int64_t dim = rank - 1; dim >= 0; --dim)
      permutation.push_back(dim);
    return permutation;
  }

  if (static_cast<int64_t>(permAttr->size()) != rank)
    return failure();

  SmallVector<bool> seen(rank, false);
  for (Attribute element : *permAttr) {
    // dyn_cast instead of getAsRange<IntegerAttr>(): the latter uses llvm::cast
    // and would assert on a malformed attribute instead of reporting failure.
    auto axisAttr = dyn_cast<IntegerAttr>(element);
    if (!axisAttr)
      return failure();
    int64_t axis = axisAttr.getInt();
    if (axis < 0 || axis >= rank || seen[axis])
      return failure();
    seen[axis] = true;
    permutation.push_back(axis);
  }
  return permutation;
}
Value transposeMaybeInCompute( Value value, RankedTensorType resultType, ArrayRef permutation, PatternRewriter& rewriter, Location loc) { auto buildTranspose = [&](Value input) -> Value { return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult(); }; return materializeOrComputeUnary(value, resultType, rewriter, loc, buildTranspose); } SmallVector getUnitStrides(PatternRewriter& rewriter, int64_t rank) { return SmallVector(rank, rewriter.getIndexAttr(1)); } SmallVector getZeroOffsets(PatternRewriter& rewriter, int64_t rank) { return SmallVector(rank, rewriter.getIndexAttr(0)); } SmallVector getStaticSizes(PatternRewriter& rewriter, ArrayRef shape) { SmallVector sizes; sizes.reserve(shape.size()); for (int64_t dim : shape) sizes.push_back(rewriter.getIndexAttr(dim)); return sizes; } SmallVector sliceTensor( const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) { ArrayRef shape = getTensorShape(tensorToSlice); assert("Invalid axis" && axis < shape.size()); SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); SmallVector offsets = getZeroOffsets(rewriter, shape.size()); SmallVector sizes = getStaticSizes(rewriter, shape); sizes[axis] = rewriter.getIndexAttr(sliceSize); long length = shape[axis]; auto [numSlices, lastSliceSize] = ceilIntegerDivideWithRemainder(length, sliceSize); SmallVector slices; slices.reserve(numSlices); for (int64_t i = 0; i < numSlices; i++) { offsets[axis] = rewriter.getIndexAttr(i * sliceSize); int64_t currentSliceSize = sliceSize; if (i == numSlices - 1 && lastSliceSize != 0) { currentSliceSize = lastSliceSize; sizes[axis] = rewriter.getIndexAttr(lastSliceSize); } SmallVector sliceShape(shape.begin(), shape.end()); sliceShape[axis] = currentSliceSize; auto sliceType = RankedTensorType::get(sliceShape, cast(tensorToSlice.getType()).getElementType()); Value slice; if (isCompileTimeComputable(tensorToSlice)) { slice = 
tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides); } else { auto sliceCompute = createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) { Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides); spatial::SpatYieldOp::create(rewriter, loc, computedSlice); }); slice = sliceCompute.getResult(0); } slices.push_back(slice); } return slices; } SmallVector sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) { ArrayRef shape = getTensorShape(vectorToSlice); assert("Not a vector" && isVectorShape(shape)); size_t axis = shape[0] != 1 ? 0 : 1; return sliceTensor(vectorToSlice, axis, sliceSize, rewriter, loc); } DenseMap> sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) { SmallVector slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc); DenseMap> slicesPerCore; for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) { size_t coreId = sliceId / crossbarCountInCore; slicesPerCore[coreId].push_back(slices[sliceId]); } return slicesPerCore; } Value extractAxisSlice( PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) { auto sourceType = cast(source.getType()); SmallVector resultShape(sourceType.getShape()); resultShape[axis] = size; auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding()); SmallVector offsets = getZeroOffsets(rewriter, sourceType.getRank()); SmallVector sizes = getStaticSizes(rewriter, sourceType.getShape()); offsets[axis] = rewriter.getIndexAttr(offset); sizes[axis] = rewriter.getIndexAttr(size); return tensor::ExtractSliceOp::create( rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank())) .getResult(); } Value insertStaticSlice( PatternRewriter& rewriter, 
Location loc, Value source, Value dest, ArrayRef offsets) { auto sourceType = cast(source.getType()); return tensor::InsertSliceOp::create(rewriter, loc, source, dest, offsets, getStaticSizes(rewriter, sourceType.getShape()), getUnitStrides(rewriter, sourceType.getRank())) .getResult(); } } // namespace onnx_mlir