This commit is contained in:
@@ -18,13 +18,17 @@ namespace detail {
|
||||
|
||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||
|
||||
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
|
||||
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(values[Is]...);
|
||||
}
|
||||
|
||||
@@ -85,6 +89,8 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
block->addArgument(weight.getType(), loc);
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
@@ -93,14 +99,15 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
|
||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
detail::invokeWithValues(
|
||||
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult =
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
auto bodyResult = detail::invokeWithValues(
|
||||
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
@@ -123,6 +130,8 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
block->addArgument(weight.getType(), loc);
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
@@ -131,13 +140,13 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
|
||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
|
||||
Reference in New Issue
Block a user