#pragma once #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ValueRange.h" #include "mlir/Transforms/DialectConversion.h" #include #include #include #include "llvm/ADT/SmallPtrSet.h" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" namespace onnx_mlir { template inline auto getImageWidth(const ShapedType& shapedType) { return shapedType.getDimSize(2); } template inline auto getImageHeight(const ShapedType& shapedType) { return shapedType.getDimSize(3); } template inline auto getImageChannel(const ShapedType& shapedType) { return shapedType.getDimSize(1); } template inline auto getImageN(const ShapedType& shapedType) { return shapedType.getDimSize(0); } template inline auto getKernelWidth(const ShapedType& shapedType) { return shapedType.getDimSize(2); } template inline auto getKernelHeight(const ShapedType& shapedType) { return shapedType.getDimSize(3); } template inline auto getFilterCount(const ShapedType& shapedType) { return shapedType.getDimSize(0); } using HSliceId = size_t; using CoreId = size_t; template > constexpr C ceilIntegerDivide(A a, B b) { static_assert(std::is_integral_v, "A must be an integer type"); static_assert(std::is_integral_v, "B must be an integer type"); C ac = static_cast(a); C bc = static_cast(b); return 1 + (ac - 1) / bc; } template > constexpr std::pair ceilIntegerDivideWithRemainder(A a, B b) { static_assert(std::is_integral_v, "A must be an integer type"); static_assert(std::is_integral_v, "B must be an integer type"); C ac = static_cast(a); C bc = static_cast(b); return {ceilIntegerDivide(ac, bc), ac % bc}; } template bool isVectorShape(mlir::ArrayRef shape) { return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1); } template bool isMatrixShape(mlir::ArrayRef shape) { return shape.size() == 2; } template bool isHVectorShape(mlir::ArrayRef shape) { return shape.size() == 2 && shape[0] == 1; } template bool isVVectorShape(mlir::ArrayRef shape) { return shape.size() == 2 && shape[1] == 1; } template T getVectorLength(mlir::ArrayRef shape) { assert(isVectorShape(shape)); return shape[0] != 1 ? shape[0] : shape[1]; } inline auto getTensorShape(mlir::Value tensor) { return mlir::cast(tensor.getType()).getShape(); } inline bool isWeightLikeComputeOperand(mlir::Value value) { auto rankedType = mlir::dyn_cast(value.getType()); if (!rankedType || !isMatrixShape(rankedType.getShape())) return false; llvm::SmallPtrSet visited; while (auto* definingOp = value.getDefiningOp()) { if (!visited.insert(definingOp).second) return false; if (hasWeightAlways(definingOp)) return true; if (auto extractSliceOp = mlir::dyn_cast(definingOp)) { value = extractSliceOp.getSource(); continue; } if (auto expandShapeOp = mlir::dyn_cast(definingOp)) { value = expandShapeOp.getSrc(); continue; } if (auto collapseShapeOp = mlir::dyn_cast(definingOp)) { value = collapseShapeOp.getSrc(); continue; } if (auto transposeOp = mlir::dyn_cast(definingOp)) { value = transposeOp.getData(); continue; } return false; } return false; } namespace detail { inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); } template decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence) { return std::forward(fn)(block->getArgument(Is)...); } template decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef values, std::index_sequence) { return std::forward(fn)(values[Is]...); } template using ValueArg = mlir::Value; template struct InvokeWithBlockArgsResult; template struct InvokeWithBlockArgsResult> { using type = std::invoke_result_t...>; }; template using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult::type; template using InvokeWithValueRangeResultT = std::invoke_result_t; } // namespace detail template auto createSpatCompute(RewriterT& rewriter, mlir::Location loc, mlir::TypeRange resultTypes, mlir::ValueRange weights, mlir::ValueRange inputs, BodyFn&& body) { assert(inputs.size() == NumInputs && "NumInputs must match the number of input values"); auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); for (mlir::Value input : inputs) block->addArgument(input.getType(), loc); computeOp.getBody().push_back(block); rewriter.setInsertionPointToStart(block); using BodyResult = detail::InvokeWithBlockArgsResultT, std::make_index_sequence>; if constexpr (std::is_same_v) { auto bodyResult = detail::invokeWithBlockArgs(std::forward(body), block, std::make_index_sequence {}); if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); return mlir::FailureOr(mlir::failure()); } rewriter.setInsertionPointAfter(computeOp); return mlir::FailureOr(computeOp); } else { static_assert(std::is_same_v, "createSpatCompute body must return void or mlir::LogicalResult"); detail::invokeWithBlockArgs(std::forward(body), block, std::make_index_sequence {}); rewriter.setInsertionPointAfter(computeOp); return computeOp; } } template auto createSpatCompute(RewriterT& rewriter, mlir::Location loc, mlir::TypeRange resultTypes, mlir::ValueRange weights, mlir::ValueRange inputs, BodyFn&& body) { auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); for (mlir::Value input : inputs) block->addArgument(input.getType(), loc); computeOp.getBody().push_back(block); rewriter.setInsertionPointToStart(block); using BodyResult = detail::InvokeWithValueRangeResultT>; if constexpr (std::is_same_v) { auto bodyResult = std::forward(body)(detail::getBlockArgs(block)); if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); return mlir::FailureOr(mlir::failure()); } rewriter.setInsertionPointAfter(computeOp); return mlir::FailureOr(computeOp); } else { static_assert(std::is_same_v, "createSpatCompute body must return void or mlir::LogicalResult"); std::forward(body)(detail::getBlockArgs(block)); rewriter.setInsertionPointAfter(computeOp); return computeOp; } } llvm::SmallVector sliceTensor(const mlir::Value& tensorToSlice, size_t axis, int64_t sliceSize, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); llvm::SmallVector sliceVector(const mlir::Value& vectorToSlice, int64_t sliceSize, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); llvm::DenseMap> sliceVectorPerCrossbarPerCore( const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); llvm::DenseMap>> tileMatrix(mlir::Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, mlir::ConversionPatternRewriter& rewriter, mlir::Location& loc); mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast, int64_t length, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); mlir::Value sumTensors(mlir::ArrayRef tensors, mlir::ConversionPatternRewriter& rewriter); }; // namespace onnx_mlir