#pragma once #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ValueRange.h" #include "mlir/Transforms/DialectConversion.h" #include #include #include #include #include #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" namespace onnx_mlir { namespace detail { inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); } inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) { return mlir::ValueRange(block->getArguments()).drop_front(weightCount); } template decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence) { return std::forward(fn)(block->getArgument(Is)...); } template decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence) { return std::forward(fn)(values[Is]...); } template using ValueArg = mlir::Value; template struct InvokeWithBlockArgsResult; template struct InvokeWithBlockArgsResult> { using type = std::invoke_result_t...>; }; template using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult::type; template using InvokeWithValueRangeResultT = std::invoke_result_t; struct SpatComputeBatchBodyArgs { mlir::Value lane; mlir::ValueRange weights; mlir::ValueRange inputs; mlir::ValueRange outputs; }; } // namespace detail template inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) { assert(!inputs.empty() && "spat.concat requires at least one input"); if (inputs.size() == 1) return inputs.front(); auto firstType = mlir::cast(inputs.front().getType()); auto outputShape = llvm::to_vector(firstType.getShape()); int64_t concatDimSize = 0; bool concatDimDynamic = false; for (mlir::Value input : inputs) { auto inputType = mlir::cast(input.getType()); assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs"); if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis))) concatDimDynamic = true; else concatDimSize += inputType.getDimSize(axis); } outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize; auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding()); return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput(); } /// Builds a `spat.graph_compute` with a fixed number of SSA inputs and erases it if /// the body callback reports failure. template auto createSpatGraphCompute(RewriterT& rewriter, mlir::Location loc, mlir::TypeRange resultTypes, mlir::ValueRange weights, mlir::ValueRange inputs, BodyFn&& body) { assert(inputs.size() == NumInputs && "NumInputs must match the number of input values"); auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); for (mlir::Value weight : weights) block->addArgument(weight.getType(), loc); for (mlir::Value input : inputs) block->addArgument(input.getType(), loc); computeOp.getBody().push_back(block); rewriter.setInsertionPointToStart(block); using BodyResult = detail::InvokeWithBlockArgsResultT, std::make_index_sequence>; if constexpr (std::is_same_v) { detail::invokeWithValues(std::forward(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence {}); rewriter.setInsertionPointAfter(computeOp); return computeOp; } else { auto bodyResult = detail::invokeWithValues(std::forward(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence {}); if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); return mlir::FailureOr(mlir::failure()); } rewriter.setInsertionPointAfter(computeOp); return mlir::FailureOr(computeOp); } } /// Builds a `spat.graph_compute` whose body consumes the block arguments as a single /// `ValueRange`, which is convenient for variadic reductions/concats. template auto createSpatGraphCompute(RewriterT& rewriter, mlir::Location loc, mlir::TypeRange resultTypes, mlir::ValueRange weights, mlir::ValueRange inputs, BodyFn&& body) { auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); for (mlir::Value weight : weights) block->addArgument(weight.getType(), loc); for (mlir::Value input : inputs) block->addArgument(input.getType(), loc); computeOp.getBody().push_back(block); rewriter.setInsertionPointToStart(block); using BodyResult = detail::InvokeWithValueRangeResultT>; if constexpr (std::is_same_v) { std::forward(body)(detail::getInputBlockArgs(block, weights.size())); rewriter.setInsertionPointAfter(computeOp); return computeOp; } else { auto bodyResult = std::forward(body)(detail::getInputBlockArgs(block, weights.size())); if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); return mlir::FailureOr(mlir::failure()); } rewriter.setInsertionPointAfter(computeOp); return mlir::FailureOr(computeOp); } } template auto createSpatGraphComputeBatch(RewriterT& rewriter, mlir::Location loc, mlir::TypeRange resultTypes, int64_t laneCount, mlir::ValueRange weights, mlir::ValueRange inputs, BodyFn&& body) { if (laneCount <= 0 || laneCount > std::numeric_limits::max()) return mlir::FailureOr(mlir::failure()); auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count"); if (mlir::failed(laneCountAttr)) return mlir::FailureOr(mlir::failure()); auto batchOp = spatial::SpatGraphComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs); mlir::SmallVector blockArgTypes {rewriter.getIndexType()}; mlir::SmallVector blockArgLocs {loc}; blockArgTypes.reserve(1 + weights.size() + inputs.size() + resultTypes.size()); blockArgLocs.reserve(1 + weights.size() + inputs.size() + resultTypes.size()); for (mlir::Value weight : weights) { blockArgTypes.push_back(weight.getType()); blockArgLocs.push_back(weight.getLoc()); } for (mlir::Value input : inputs) { blockArgTypes.push_back(input.getType()); blockArgLocs.push_back(input.getLoc()); } for (mlir::Type resultType : resultTypes) { blockArgTypes.push_back(resultType); blockArgLocs.push_back(loc); } auto* block = rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), mlir::TypeRange(blockArgTypes), blockArgLocs); rewriter.setInsertionPointToStart(block); detail::SpatComputeBatchBodyArgs args { block->getArgument(0), mlir::ValueRange(block->getArguments()).slice(1, weights.size()), mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()), mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())}; using BodyResult = std::invoke_result_t; if constexpr (std::is_same_v) { std::forward(body)(args); rewriter.setInsertionPointAfter(batchOp); return mlir::FailureOr(batchOp); } else { auto bodyResult = std::forward(body)(args); if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(batchOp); rewriter.eraseOp(batchOp); return mlir::FailureOr(mlir::failure()); } rewriter.setInsertionPointAfter(batchOp); return mlir::FailureOr(batchOp); } } template auto createSpatCompute(RewriterT& rewriter, mlir::Location loc, mlir::TypeRange resultTypes, mlir::ValueRange weights, mlir::ValueRange inputs, BodyFn&& body) { return createSpatGraphCompute( rewriter, loc, resultTypes, weights, inputs, std::forward(body)); } template auto createSpatCompute(RewriterT& rewriter, mlir::Location loc, mlir::TypeRange resultTypes, mlir::ValueRange weights, mlir::ValueRange inputs, BodyFn&& body) { return createSpatGraphCompute(rewriter, loc, resultTypes, weights, inputs, std::forward(body)); } template auto createSpatComputeBatch(RewriterT& rewriter, mlir::Location loc, mlir::TypeRange resultTypes, int64_t laneCount, mlir::ValueRange weights, mlir::ValueRange inputs, BodyFn&& body) { return createSpatGraphComputeBatch( rewriter, loc, resultTypes, laneCount, weights, inputs, std::forward(body)); } inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, mlir::Value dest, mlir::ArrayRef offsets, mlir::ArrayRef sizes, mlir::ArrayRef strides) { auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc); rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); mlir::tensor::ParallelInsertSliceOp::create(rewriter, loc, source, dest, offsets, sizes, strides); } template mlir::Value materializeOrComputeUnary(mlir::Value input, mlir::RankedTensorType resultType, mlir::PatternRewriter& rewriter, mlir::Location loc, BodyFn&& build) { auto&& buildFn = build; if (isCompileTimeComputable(input)) return buildFn(input); auto computeOp = createSpatCompute<1>( rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) { mlir::Value result = buildFn(computeInput); spatial::SpatYieldOp::create(rewriter, loc, result); }); return computeOp.getResult(0); } mlir::Value sumTensors(mlir::ArrayRef tensors, mlir::PatternRewriter& rewriter); } // namespace onnx_mlir