#pragma once #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ValueRange.h" #include "mlir/Transforms/DialectConversion.h" #include #include #include #include #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" namespace onnx_mlir { namespace detail { inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); } template decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence) { return std::forward(fn)(block->getArgument(Is)...); } template decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef values, std::index_sequence) { return std::forward(fn)(values[Is]...); } template using ValueArg = mlir::Value; template struct InvokeWithBlockArgsResult; template struct InvokeWithBlockArgsResult> { using type = std::invoke_result_t...>; }; template using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult::type; template using InvokeWithValueRangeResultT = std::invoke_result_t; } // namespace detail template inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) { assert(!inputs.empty() && "spat.concat requires at least one input"); if (inputs.size() == 1) return inputs.front(); auto firstType = mlir::cast(inputs.front().getType()); auto outputShape = llvm::to_vector(firstType.getShape()); int64_t concatDimSize = 0; bool concatDimDynamic = false; for (mlir::Value input : inputs) { auto inputType = mlir::cast(input.getType()); assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs"); if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis))) concatDimDynamic = true; else concatDimSize += inputType.getDimSize(axis); } outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize; auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding()); return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput(); } /// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if /// the body callback reports failure. template auto createSpatCompute(RewriterT& rewriter, mlir::Location loc, mlir::TypeRange resultTypes, mlir::ValueRange weights, mlir::ValueRange inputs, BodyFn&& body) { assert(inputs.size() == NumInputs && "NumInputs must match the number of input values"); auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); for (mlir::Value input : inputs) block->addArgument(input.getType(), loc); computeOp.getBody().push_back(block); rewriter.setInsertionPointToStart(block); using BodyResult = detail::InvokeWithBlockArgsResultT, std::make_index_sequence>; if constexpr (std::is_same_v) { detail::invokeWithBlockArgs(std::forward(body), block, std::make_index_sequence {}); rewriter.setInsertionPointAfter(computeOp); return computeOp; } else { auto bodyResult = detail::invokeWithBlockArgs(std::forward(body), block, std::make_index_sequence {}); if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); return mlir::FailureOr(mlir::failure()); } rewriter.setInsertionPointAfter(computeOp); return mlir::FailureOr(computeOp); } } /// Builds a `spat.compute` whose body consumes the block arguments as a single /// `ValueRange`, which is convenient for variadic reductions/concats. template auto createSpatCompute(RewriterT& rewriter, mlir::Location loc, mlir::TypeRange resultTypes, mlir::ValueRange weights, mlir::ValueRange inputs, BodyFn&& body) { auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); for (mlir::Value input : inputs) block->addArgument(input.getType(), loc); computeOp.getBody().push_back(block); rewriter.setInsertionPointToStart(block); using BodyResult = detail::InvokeWithValueRangeResultT>; if constexpr (std::is_same_v) { std::forward(body)(detail::getBlockArgs(block)); rewriter.setInsertionPointAfter(computeOp); return computeOp; } else { auto bodyResult = std::forward(body)(detail::getBlockArgs(block)); if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); return mlir::FailureOr(mlir::failure()); } rewriter.setInsertionPointAfter(computeOp); return mlir::FailureOr(computeOp); } } mlir::Value sumTensors(mlir::ArrayRef tensors, mlir::ConversionPatternRewriter& rewriter); } // namespace onnx_mlir