301 lines
12 KiB
C++
301 lines
12 KiB
C++
#pragma once
|
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/Block.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/ValueRange.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include <cassert>
|
|
#include <cstddef>
|
|
#include <limits>
|
|
#include <type_traits>
|
|
#include <utility>
|
|
|
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace detail {
|
|
|
|
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
|
|
|
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
|
|
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
|
|
}
|
|
|
|
template <typename Fn, size_t... Is>
|
|
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
|
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
|
}
|
|
|
|
template <typename Fn, size_t... Is>
|
|
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
|
|
return std::forward<Fn>(fn)(values[Is]...);
|
|
}
|
|
|
|
template <size_t>
|
|
using ValueArg = mlir::Value;
|
|
|
|
template <typename Fn, typename Seq>
|
|
struct InvokeWithBlockArgsResult;
|
|
|
|
template <typename Fn, size_t... Is>
|
|
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
|
|
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
|
|
};
|
|
|
|
template <typename Fn, typename Seq>
|
|
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
|
|
|
|
template <typename Fn>
|
|
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
|
|
|
struct SpatComputeBatchBodyArgs {
|
|
mlir::Value lane;
|
|
mlir::ValueRange weights;
|
|
mlir::ValueRange inputs;
|
|
mlir::ValueRange outputs;
|
|
};
|
|
|
|
} // namespace detail
|
|
|
|
template <typename RewriterT>
|
|
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
|
|
assert(!inputs.empty() && "spat.concat requires at least one input");
|
|
if (inputs.size() == 1)
|
|
return inputs.front();
|
|
|
|
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
|
|
auto outputShape = llvm::to_vector(firstType.getShape());
|
|
int64_t concatDimSize = 0;
|
|
bool concatDimDynamic = false;
|
|
|
|
for (mlir::Value input : inputs) {
|
|
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
|
|
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
|
|
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
|
|
concatDimDynamic = true;
|
|
else
|
|
concatDimSize += inputType.getDimSize(axis);
|
|
}
|
|
|
|
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
|
|
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
|
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
|
}
|
|
|
|
/// Builds a `spat.graph_compute` with a fixed number of SSA inputs and erases it if
|
|
/// the body callback reports failure.
|
|
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
|
auto createSpatGraphCompute(RewriterT& rewriter,
|
|
mlir::Location loc,
|
|
mlir::TypeRange resultTypes,
|
|
mlir::ValueRange weights,
|
|
mlir::ValueRange inputs,
|
|
BodyFn&& body) {
|
|
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
|
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
|
|
|
auto* block = new mlir::Block();
|
|
for (mlir::Value weight : weights)
|
|
block->addArgument(weight.getType(), loc);
|
|
for (mlir::Value input : inputs)
|
|
block->addArgument(input.getType(), loc);
|
|
|
|
computeOp.getBody().push_back(block);
|
|
rewriter.setInsertionPointToStart(block);
|
|
|
|
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
|
detail::invokeWithValues(std::forward<BodyFn>(body),
|
|
detail::getInputBlockArgs(block, weights.size()),
|
|
std::make_index_sequence<NumInputs> {});
|
|
|
|
rewriter.setInsertionPointAfter(computeOp);
|
|
return computeOp;
|
|
}
|
|
else {
|
|
auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
|
|
detail::getInputBlockArgs(block, weights.size()),
|
|
std::make_index_sequence<NumInputs> {});
|
|
if (mlir::failed(bodyResult)) {
|
|
rewriter.setInsertionPointAfter(computeOp);
|
|
rewriter.eraseOp(computeOp);
|
|
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
|
|
}
|
|
rewriter.setInsertionPointAfter(computeOp);
|
|
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
|
|
}
|
|
}
|
|
|
|
/// Builds a `spat.graph_compute` whose body consumes the block arguments as a single
|
|
/// `ValueRange`, which is convenient for variadic reductions/concats.
|
|
template <typename RewriterT, typename BodyFn>
|
|
auto createSpatGraphCompute(RewriterT& rewriter,
|
|
mlir::Location loc,
|
|
mlir::TypeRange resultTypes,
|
|
mlir::ValueRange weights,
|
|
mlir::ValueRange inputs,
|
|
BodyFn&& body) {
|
|
auto computeOp = spatial::SpatGraphCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
|
|
|
auto* block = new mlir::Block();
|
|
for (mlir::Value weight : weights)
|
|
block->addArgument(weight.getType(), loc);
|
|
for (mlir::Value input : inputs)
|
|
block->addArgument(input.getType(), loc);
|
|
|
|
computeOp.getBody().push_back(block);
|
|
rewriter.setInsertionPointToStart(block);
|
|
|
|
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
|
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
|
|
|
rewriter.setInsertionPointAfter(computeOp);
|
|
return computeOp;
|
|
}
|
|
else {
|
|
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
|
if (mlir::failed(bodyResult)) {
|
|
rewriter.setInsertionPointAfter(computeOp);
|
|
rewriter.eraseOp(computeOp);
|
|
return mlir::FailureOr<spatial::SpatGraphCompute>(mlir::failure());
|
|
}
|
|
rewriter.setInsertionPointAfter(computeOp);
|
|
return mlir::FailureOr<spatial::SpatGraphCompute>(computeOp);
|
|
}
|
|
}
|
|
|
|
template <typename RewriterT, typename BodyFn>
|
|
auto createSpatGraphComputeBatch(RewriterT& rewriter,
|
|
mlir::Location loc,
|
|
mlir::TypeRange resultTypes,
|
|
int64_t laneCount,
|
|
mlir::ValueRange weights,
|
|
mlir::ValueRange inputs,
|
|
BodyFn&& body) {
|
|
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
|
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
|
|
|
|
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
|
|
if (mlir::failed(laneCountAttr))
|
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
|
|
|
|
auto batchOp = spatial::SpatGraphComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
|
|
|
|
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
|
|
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
|
|
blockArgTypes.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
|
|
blockArgLocs.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
|
|
for (mlir::Value weight : weights) {
|
|
blockArgTypes.push_back(weight.getType());
|
|
blockArgLocs.push_back(weight.getLoc());
|
|
}
|
|
for (mlir::Value input : inputs) {
|
|
blockArgTypes.push_back(input.getType());
|
|
blockArgLocs.push_back(input.getLoc());
|
|
}
|
|
for (mlir::Type resultType : resultTypes) {
|
|
blockArgTypes.push_back(resultType);
|
|
blockArgLocs.push_back(loc);
|
|
}
|
|
|
|
auto* block =
|
|
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), mlir::TypeRange(blockArgTypes), blockArgLocs);
|
|
rewriter.setInsertionPointToStart(block);
|
|
|
|
detail::SpatComputeBatchBodyArgs args {
|
|
block->getArgument(0),
|
|
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
|
|
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
|
|
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())};
|
|
|
|
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
|
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
|
std::forward<BodyFn>(body)(args);
|
|
rewriter.setInsertionPointAfter(batchOp);
|
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
|
|
}
|
|
else {
|
|
auto bodyResult = std::forward<BodyFn>(body)(args);
|
|
if (mlir::failed(bodyResult)) {
|
|
rewriter.setInsertionPointAfter(batchOp);
|
|
rewriter.eraseOp(batchOp);
|
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(mlir::failure());
|
|
}
|
|
rewriter.setInsertionPointAfter(batchOp);
|
|
return mlir::FailureOr<spatial::SpatGraphComputeBatch>(batchOp);
|
|
}
|
|
}
|
|
|
|
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
|
auto createSpatCompute(RewriterT& rewriter,
|
|
mlir::Location loc,
|
|
mlir::TypeRange resultTypes,
|
|
mlir::ValueRange weights,
|
|
mlir::ValueRange inputs,
|
|
BodyFn&& body) {
|
|
return createSpatGraphCompute<NumInputs>(
|
|
rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
|
|
}
|
|
|
|
template <typename RewriterT, typename BodyFn>
|
|
auto createSpatCompute(RewriterT& rewriter,
|
|
mlir::Location loc,
|
|
mlir::TypeRange resultTypes,
|
|
mlir::ValueRange weights,
|
|
mlir::ValueRange inputs,
|
|
BodyFn&& body) {
|
|
return createSpatGraphCompute(rewriter, loc, resultTypes, weights, inputs, std::forward<BodyFn>(body));
|
|
}
|
|
|
|
template <typename RewriterT, typename BodyFn>
|
|
auto createSpatComputeBatch(RewriterT& rewriter,
|
|
mlir::Location loc,
|
|
mlir::TypeRange resultTypes,
|
|
int64_t laneCount,
|
|
mlir::ValueRange weights,
|
|
mlir::ValueRange inputs,
|
|
BodyFn&& body) {
|
|
return createSpatGraphComputeBatch(
|
|
rewriter, loc, resultTypes, laneCount, weights, inputs, std::forward<BodyFn>(body));
|
|
}
|
|
|
|
/// Emits a `spat.in_parallel` op at the current insertion point, then a
/// `tensor.parallel_insert_slice` of `source` into `dest` (with the given offsets,
/// sizes and strides) at the start of the first block of its region.
///
/// NOTE: the insertion point is NOT restored. On return it still points into the
/// in_parallel block. The op's builder is expected to have created that block, since
/// `front()` requires one to exist.
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
                                                     mlir::Location loc,
                                                     mlir::Value source,
                                                     mlir::Value dest,
                                                     mlir::ArrayRef<mlir::OpFoldResult> offsets,
                                                     mlir::ArrayRef<mlir::OpFoldResult> sizes,
                                                     mlir::ArrayRef<mlir::OpFoldResult> strides) {
  auto inParallel = spatial::SpatInParallelOp::create(rewriter, loc);
  mlir::Block& parallelBody = inParallel.getRegion().front();
  rewriter.setInsertionPointToStart(&parallelBody);
  mlir::tensor::ParallelInsertSliceOp::create(rewriter, loc, source, dest, offsets, sizes, strides);
}
|
|
|
|
template <typename BodyFn>
|
|
mlir::Value materializeOrComputeUnary(mlir::Value input,
|
|
mlir::RankedTensorType resultType,
|
|
mlir::PatternRewriter& rewriter,
|
|
mlir::Location loc,
|
|
BodyFn&& build) {
|
|
auto&& buildFn = build;
|
|
if (isCompileTimeComputable(input))
|
|
return buildFn(input);
|
|
|
|
auto computeOp = createSpatCompute<1>(
|
|
rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
|
|
mlir::Value result = buildFn(computeInput);
|
|
spatial::SpatYieldOp::create(rewriter, loc, result);
|
|
});
|
|
return computeOp.getResult(0);
|
|
}
|
|
|
|
/// Declared here and defined out of line; the definition is not visible in this header.
/// NOTE(review): judging by the name, this presumably adds `tensors` together and
/// returns the sum, creating ops with `rewriter`. Confirm against the definition
/// (element-wise semantics, empty/single-element handling).
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::PatternRewriter& rewriter);
|
|
|
|
} // namespace onnx_mlir
|