#pragma once #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include #include #include #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" #define DEFINE_MAP_OP(opname) opname, namespace onnx_mlir { template inline auto getImageWidth(const ShapedType& shapedType) { return shapedType.getDimSize(2); } template inline auto getImageHeight(const ShapedType& shapedType) { return shapedType.getDimSize(3); } template inline auto getImageChannel(const ShapedType& shapedType) { return shapedType.getDimSize(1); } template inline auto getImageN(const ShapedType& shapedType) { return shapedType.getDimSize(0); } template inline auto getKernelWidth(const ShapedType& shapedType) { return shapedType.getDimSize(2); } template inline auto getKernelHeight(const ShapedType& shapedType) { return shapedType.getDimSize(3); } template inline auto getFilterCount(const ShapedType& shapedType) { return shapedType.getDimSize(0); } using HSliceId = size_t; using CoreId = size_t; template > constexpr C ceilIntegerDivide(A a, B b) { static_assert(std::is_integral_v, "A must be an integer type"); static_assert(std::is_integral_v, "B must be an integer type"); C ac = static_cast(a); C bc = static_cast(b); return 1 + (ac - 1) / bc; } template > constexpr std::pair ceilIntegerDivideWithRemainder(A a, B b) { static_assert(std::is_integral_v, "A must be an integer type"); static_assert(std::is_integral_v, "B must be an integer type"); C ac = static_cast(a); C bc = static_cast(b); return {ceilIntegerDivide(ac, bc), ac % bc}; } template bool isVectorShape(mlir::ArrayRef shape) { return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1); } template bool isMatrixShape(mlir::ArrayRef shape) { return shape.size() == 2; } template bool isHVectorShape(mlir::ArrayRef shape) { return shape.size() == 2 && shape[0] == 1; } template bool isVVectorShape(mlir::ArrayRef shape) { return shape.size() == 2 && shape[1] == 1; } template T getVectorLength(mlir::ArrayRef shape) { assert(isVectorShape(shape)); return shape[0] != 1 ? shape[0] : shape[1]; } inline auto getTensorShape(mlir::Value tensor) { return mlir::cast(tensor.getType()).getShape(); } namespace detail { template void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence) { std::forward(fn)(block->getArgument(Is)...); } } // namespace detail template spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter, mlir::Location loc, mlir::TypeRange resultTypes, mlir::ValueRange weights, mlir::ValueRange inputs, BodyFn&& body) { assert(inputs.size() == NumInputs && "NumInputs must match the number of input values"); auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); for (mlir::Value input : inputs) block->addArgument(input.getType(), loc); computeOp.getBody().push_back(block); rewriter.setInsertionPointToStart(block); detail::invokeWithBlockArgs(std::forward(body), block, std::make_index_sequence {}); rewriter.setInsertionPointAfter(computeOp); return computeOp; } llvm::SmallVector sliceTensor(const mlir::Value& tensorToSlice, size_t axis, int64_t sliceSize, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); llvm::SmallVector sliceVector(const mlir::Value& vectorToSlice, int64_t sliceSize, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); llvm::DenseMap> sliceVectorPerCrossbarPerCore( const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); llvm::DenseMap>> tileMatrix(mlir::Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, mlir::ConversionPatternRewriter& rewriter, mlir::Location& loc); mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast, int64_t length, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); mlir::Value sumTensors(mlir::ArrayRef tensors, mlir::ConversionPatternRewriter& rewriter); }; // namespace onnx_mlir