add better createSpatCompute helper

This commit is contained in:
NiccoloN
2026-03-30 16:14:26 +02:00
parent 39830be888
commit 3625edc80a
5 changed files with 259 additions and 239 deletions

View File

@@ -104,20 +104,39 @@ inline auto getTensorShape(mlir::Value tensor) {
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is>
void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
std::forward<Fn>(fn)(block->getArgument(Is)...);
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail
template <size_t NumInputs, typename BodyFn>
spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
@@ -128,10 +147,61 @@ spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter&
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,