add better createSpatCompute helper
This commit is contained in:
@@ -104,20 +104,39 @@ inline auto getTensorShape(mlir::Value tensor) {
|
||||
|
||||
namespace detail {
|
||||
|
||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||
std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||
}
|
||||
|
||||
template <size_t>
|
||||
using ValueArg = mlir::Value;
|
||||
|
||||
template <typename Fn, typename Seq>
|
||||
struct InvokeWithBlockArgsResult;
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
|
||||
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
|
||||
};
|
||||
|
||||
template <typename Fn, typename Seq>
|
||||
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
|
||||
|
||||
template <typename Fn>
|
||||
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <size_t NumInputs, typename BodyFn>
|
||||
spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
@@ -128,10 +147,61 @@ spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter&
|
||||
computeOp.getBody().push_back(block);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
|
||||
auto bodyResult =
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
|
||||
}
|
||||
else {
|
||||
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
computeOp.getBody().push_back(block);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
|
||||
}
|
||||
else {
|
||||
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
||||
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||
|
||||
Reference in New Issue
Block a user