ff36729140
fix codegen symlinks overwrite remove deprecated pim memcp_hd_batch op
190 lines
7.3 KiB
C++
190 lines
7.3 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include <functional>
|
|
|
|
#include "IndexingUtils.hpp"
|
|
#include "ShapeTilingUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Returns true when every extent in `shape` is strictly positive.
/// Any zero or negative extent makes the result false; an empty (rank-0)
/// shape yields true.
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
  return llvm::none_of(shape, [](int64_t extent) { return extent <= 0; });
}
|
|
|
|
/// Returns true when `type` has a fully static shape whose extents are all
/// strictly positive.
bool hasStaticPositiveShape(RankedTensorType type) {
  if (!type.hasStaticShape())
    return false;
  return hasStaticPositiveShape(type.getShape());
}
|
|
|
|
/// Returns the product of all extents in `shape` (1 for an empty shape).
/// NOTE(review): there is no overflow check; this assumes the extents are
/// static and their product fits in int64_t.
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
  int64_t elementCount = 1;
  for (int64_t extent : shape)
    elementCount *= extent;
  return elementCount;
}
|
|
|
|
/// Reorders `shape` according to `permutation`, so that
/// result[i] == shape[permutation[i]]. The result has permutation.size()
/// entries.
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
  SmallVector<int64_t> reordered(permutation.size());
  for (auto [position, sourceAxis] : llvm::enumerate(permutation))
    reordered[position] = shape[sourceAxis];
  return reordered;
}
|
|
|
|
/// Returns the inverse of `permutation`: the vector `inverse` such that
/// inverse[permutation[i]] == i for every i.
/// Assumes `permutation` is a valid permutation of [0, size). No validation
/// is performed here.
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
  const size_t count = permutation.size();
  SmallVector<int64_t> inverse(count);
  for (size_t position = 0; position < count; ++position)
    inverse[permutation[position]] = static_cast<int64_t>(position);
  return inverse;
}
|
|
|
|
/// Resolves the axis permutation described by a Transpose-style `perm`
/// attribute for a tensor of the given `rank`.
///
/// When `permAttr` is absent, the default permutation reverses all axes:
/// [rank-1, ..., 1, 0]. Otherwise the attribute must list exactly `rank`
/// integer entries that form a permutation of [0, rank).
///
/// \returns the permutation. Returns failure() instead if:
///   - `rank` is negative;
///   - the attribute has the wrong length;
///   - an entry is not an IntegerAttr;
///   - an axis is out of range or appears more than once.
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
  // A negative rank would otherwise be converted to a huge size_t by
  // reserve() / the `seen` vector below.
  if (rank < 0)
    return failure();

  SmallVector<int64_t> permutation;

  if (!permAttr) {
    // Default: reverse the axis order.
    permutation.reserve(rank);
    for (int64_t dim = rank - 1; dim >= 0; --dim)
      permutation.push_back(dim);
    return permutation;
  }

  if (static_cast<int64_t>(permAttr->size()) != rank)
    return failure();

  permutation.reserve(rank);
  SmallVector<bool> seen(rank, false);
  for (Attribute entry : permAttr->getValue()) {
    // getAsRange<IntegerAttr>() casts unconditionally and would assert on a
    // non-integer entry. This checked helper reports failure instead.
    auto axisAttr = dyn_cast<IntegerAttr>(entry);
    if (!axisAttr)
      return failure();
    int64_t axis = axisAttr.getInt();
    if (axis < 0 || axis >= rank || seen[axis])
      return failure();
    seen[axis] = true;
    permutation.push_back(axis);
  }
  return permutation;
}
|
|
|
|
/// Builds an ONNX Transpose of `value` with result type `resultType` and axis
/// order `permutation`.
/// Where the transpose ends up is decided by materializeOrComputeUnary, which
/// receives the op-building callback. NOTE(review): presumably it either
/// emits the op directly or wraps it in a compute region; confirm against
/// that helper.
Value transposeMaybeInCompute(
    Value value, RankedTensorType resultType, ArrayRef<int64_t> permutation, PatternRewriter& rewriter, Location loc) {
  ArrayAttr permAttr = rewriter.getI64ArrayAttr(permutation);
  auto emitTranspose = [&](Value operand) -> Value {
    auto transposeOp = ONNXTransposeOp::create(rewriter, loc, resultType, operand, permAttr);
    return transposeOp.getResult();
  };
  return materializeOrComputeUnary(value, resultType, rewriter, loc, emitTranspose);
}
|
|
|
|
/// Returns `rank` copies of the index attribute 1, for use as unit strides in
/// slice ops.
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
  OpFoldResult one = rewriter.getIndexAttr(1);
  return SmallVector<OpFoldResult>(rank, one);
}
|
|
|
|
/// Returns `rank` copies of the index attribute 0, for use as all-zero
/// offsets in slice ops.
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
  OpFoldResult zero = rewriter.getIndexAttr(0);
  return SmallVector<OpFoldResult>(rank, zero);
}
|
|
|
|
/// Converts each extent of `shape` into an index attribute, producing a
/// `sizes` operand list for slice ops.
/// Assumes `shape` is fully static: every extent, including any
/// dynamic-dimension sentinel, is emitted verbatim.
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
  const size_t rank = shape.size();
  SmallVector<OpFoldResult> sizeAttrs(rank);
  for (size_t dim = 0; dim < rank; ++dim)
    sizeAttrs[dim] = rewriter.getIndexAttr(shape[dim]);
  return sizeAttrs;
}
|
|
|
|
/// Splits `tensorToSlice` along `axis` into consecutive chunks of `sliceSize`
/// elements. If the extent along `axis` is not a multiple of `sliceSize`, the
/// final chunk holds the remainder.
///
/// The slicing strategy depends on the input:
///   - compile-time-computable inputs (per isCompileTimeComputable) are
///     sliced with a plain tensor.extract_slice;
///   - all other inputs get a separate spat compute op per slice that
///     performs the extract_slice and yields it.
///
/// \param tensorToSlice ranked tensor to split. Its shape is assumed to be
///        static, because slice sizes are built from it.
/// \param axis dimension to split. Must be < rank.
/// \param sliceSize chunk length along `axis`. Must be > 0.
/// \returns the slices, ordered by increasing offset along `axis`.
SmallVector<Value> sliceTensor(
    const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
  // Deduce the shape type rather than spelling ArrayRef<long>. The explicit
  // spelling only matches where int64_t is `long`, which is not the case on
  // Darwin or LLP64 targets.
  auto shape = getTensorShape(tensorToSlice);
  assert("Invalid axis" && axis < shape.size());
  // A non-positive slice size would divide by zero when computing the slice count.
  assert("Slice size must be positive" && sliceSize > 0);

  const size_t rank = shape.size();
  SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
  SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, rank);
  SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, shape);
  sizes[axis] = rewriter.getIndexAttr(sliceSize);

  Type elementType = cast<RankedTensorType>(tensorToSlice.getType()).getElementType();

  int64_t length = shape[axis];
  auto [numSlices, lastSliceSize] = ceilIntegerDivideWithRemainder(length, sliceSize);
  SmallVector<Value> slices;
  slices.reserve(numSlices);

  for (int64_t i = 0; i < numSlices; i++) {
    offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
    int64_t currentSliceSize = sliceSize;
    // A non-zero remainder shortens only the final slice.
    if (i == numSlices - 1 && lastSliceSize != 0) {
      currentSliceSize = lastSliceSize;
      sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
    }

    SmallVector<int64_t> sliceShape(shape.begin(), shape.end());
    sliceShape[axis] = currentSliceSize;
    auto sliceType = RankedTensorType::get(sliceShape, elementType);

    Value slice;
    if (isCompileTimeComputable(tensorToSlice)) {
      slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
    }
    else {
      auto sliceCompute =
          createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) {
            Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
            spatial::SpatYieldOp::create(rewriter, loc, computedSlice);
          });
      slice = sliceCompute.getResult(0);
    }
    slices.push_back(slice);
  }

  return slices;
}
|
|
|
|
/// Splits a vector-shaped tensor into chunks of `sliceSize` along its
/// non-unit axis.
///
/// Axis selection:
///   - a shape of rank >= 2 whose first extent is 1 (e.g. [1, N]) is split
///     along axis 1;
///   - every other vector shape, including any rank-1 shape such as [N] or
///     [1], is split along axis 0.
///
/// What counts as a vector shape is decided by isVectorShape (asserted).
SmallVector<Value>
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
  auto shape = getTensorShape(vectorToSlice);
  assert("Not a vector" && isVectorShape(shape));
  // Guard on rank so that a rank-1 shape [1] never selects the non-existent
  // axis 1.
  size_t axis = (shape.size() >= 2 && shape[0] == 1) ? 1 : 0;
  return sliceTensor(vectorToSlice, axis, sliceSize, rewriter, loc);
}
|
|
|
|
/// Splits `vectorToSlice` into chunks of `crossbarSize` elements (the last
/// chunk may be shorter) and groups consecutive chunks by core.
/// Chunk k is assigned to core k / crossbarCountInCore.
DenseMap<CoreId, SmallVector<Value>>
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
  SmallVector<Value> crossbarSlices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
  DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
  for (auto [sliceIndex, slice] : llvm::enumerate(crossbarSlices)) {
    size_t coreId = sliceIndex / crossbarCountInCore;
    slicesPerCore[coreId].push_back(slice);
  }
  return slicesPerCore;
}
|
|
|
|
/// Extracts the part of `source` that covers [offset, offset + size) along
/// `axis` and the full extent of every other dimension, with unit strides.
/// The result type keeps the source's element type and encoding.
/// Assumes the source shape is static, since sizes come from getStaticSizes.
Value extractAxisSlice(
    PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
  auto sourceType = cast<RankedTensorType>(source.getType());
  const int64_t rank = sourceType.getRank();
  ArrayRef<int64_t> sourceShape = sourceType.getShape();

  SmallVector<int64_t> sliceShape(sourceShape.begin(), sourceShape.end());
  sliceShape[axis] = size;
  auto sliceType = RankedTensorType::get(sliceShape, sourceType.getElementType(), sourceType.getEncoding());

  SmallVector<OpFoldResult> sliceOffsets = getZeroOffsets(rewriter, rank);
  sliceOffsets[axis] = rewriter.getIndexAttr(offset);
  SmallVector<OpFoldResult> sliceSizes = getStaticSizes(rewriter, sourceShape);
  sliceSizes[axis] = rewriter.getIndexAttr(size);
  SmallVector<OpFoldResult> sliceStrides = getUnitStrides(rewriter, rank);

  auto extractOp =
      tensor::ExtractSliceOp::create(rewriter, loc, sliceType, source, sliceOffsets, sliceSizes, sliceStrides);
  return extractOp.getResult();
}
|
|
|
|
/// Inserts all of `source` into `dest` at `offsets`, with unit strides.
/// The slice sizes are the source's (assumed static) shape.
/// \returns the result of the created tensor.insert_slice op.
Value insertStaticSlice(
    PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
  auto sourceType = cast<RankedTensorType>(source.getType());
  SmallVector<OpFoldResult> sliceSizes = getStaticSizes(rewriter, sourceType.getShape());
  SmallVector<OpFoldResult> sliceStrides = getUnitStrides(rewriter, sourceType.getRank());
  auto insertOp = tensor::InsertSliceOp::create(rewriter, loc, source, dest, offsets, sliceSizes, sliceStrides);
  return insertOp.getResult();
}
|
|
|
|
} // namespace onnx_mlir
|