diff --git a/include/exec/when_all_range.hpp b/include/exec/when_all_range.hpp new file mode 100644 index 000000000..4b2d07336 --- /dev/null +++ b/include/exec/when_all_range.hpp @@ -0,0 +1,413 @@ +/* + * Copyright (c) 2024 Maikel Nadolski + * + * Licensed under the Apache License Version 2.0 with LLVM Exceptions + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://llvm.org/LICENSE.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace exec { +template class types; + +template class manual_alternative { + public: + manual_alternative(manual_alternative&& other) { emplace<0>(std::move(other.storage_.first)); } + + template explicit constexpr manual_alternative(std::in_place_t, Args&&... args) { + emplace<0>(std::forward(args)...); + } + + template constexpr void emplace(Args&&... args) { + if constexpr (Idx == 0) { + new (&storage_.first) T1(std::forward(args)...); + } else { + new (&storage_.second) T2(std::forward(args)...); + } + } + + template + constexpr void emplace_from(Fn&& fun, Args&&... args) { + if constexpr (Idx == 0) { + new (&storage_.first) T1(std::forward(fun)(std::forward(args)...)); + } else { + new (&storage_.second) T2(std::forward(fun)(std::forward(args)...)); + } + } + + template constexpr void destroy() noexcept { + if constexpr (Idx == 0) { + storage_.first.~T1(); + } else { + storage_.second.~T2(); + } + } + + template constexpr auto&& get(this Self&& self) noexcept { + if constexpr (Idx == 0) { + return std::forward(self).storage_.first; + } else { + return std::forward(self).storage_.second; + } + } + + private: + union storage_type { + constexpr storage_type() {} + + constexpr ~storage_type() noexcept {} + + T1 first; + T2 second; + }; + + storage_type storage_; +}; + +template class intrusive_queue; + +template class intrusive_queue { + public: + intrusive_queue() noexcept = default; + + void push(Tp& object) noexcept { + auto* item = std::addressof(object); + item->*Next = nullptr; + if (tail_ == nullptr) { + head_ = item; + } else { + tail_->*Next = item; + } + tail_ = item; + size_ += 1; + } + + [[nodiscard]] + Tp* pop() noexcept { + if (head_ == nullptr) { + return nullptr; + } + auto* item = head_; + head_ = head_->*Next; + if (head_ == nullptr) { + tail_ = nullptr; + } + size_ -= 1; + return item; + } + + [[nodiscard]] + std::size_t size() const noexcept { + return size_; + } + + private: + Tp* head_{nullptr}; + Tp* tail_{nullptr}; + std::size_t size_{0}; +}; + +namespace when_all_range_ { +struct unit {}; + +template struct local_operation_result { + local_operation_result* next_{nullptr}; + + using storage_type = std::conditional_t, unit, std::optional>; + + [[no_unique_address]] + storage_type result_{}; +}; + +template struct operation_base { + struct stop_callback_t { + operation_base& op_; + + void operator()() noexcept { op_.notify_stopped(); } + }; + + using stop_token = stdexec::stop_token_of_t>; + using stop_callback_type = stdexec::stop_callback_for_t; + + ErrorVariant error_{}; + Receiver receiver_; + stdexec::inplace_stop_source stop_source_{}; + intrusive_queue<&local_operation_result::next_> results_{}; + std::atomic count_{0}; + std::atomic disposition_{0}; + stdexec::__manual_lifetime stop_callback_{}; + + explicit operation_base(Receiver receiver) noexcept + : receiver_(std::move(receiver)) {} + + void do_start() noexcept { + stop_callback_.__construct(stdexec::get_stop_token(stdexec::get_env(receiver_)), + stop_callback_t{*this}); + } + + void notify() noexcept { + if (count_.fetch_sub(1) == 1) { + stop_callback_.__destroy(); + switch (disposition_) { + case 0: { + if constexpr (std::is_void_v) { + stdexec::set_value(std::move(receiver_)); + } else { + std::vector result; + result.reserve(results_.size()); + while (auto* item = results_.pop()) { + assert(item->result_); + result.push_back(*std::exchange(item->result_, std::nullopt)); + } + stdexec::set_value(std::move(receiver_), std::move(result)); + } + break; + } + case 1: { + std::visit( + [&](Err&& err) noexcept { + if constexpr (!std::same_as) { + stdexec::set_error(std::move(receiver_), std::move(err)); + } + }, + std::move(error_)); + break; + } + case 2: + stdexec::set_stopped(std::move(receiver_)); + break; + } + } + } + + template void notify_error(Error&& err) noexcept { + int expected_disposition = 0; + if (disposition_.compare_exchange_strong(expected_disposition, 1)) { + try { + error_.template emplace>(std::forward(err)); + } catch (...) { + if constexpr (!std::is_nothrow_constructible_v, Error>) { + error_.template emplace(std::current_exception()); + } + } + stop_source_.request_stop(); + } + this->notify(); + } + + void notify_stopped() noexcept { + int expected_disposition = 0; + if (disposition_.compare_exchange_strong(expected_disposition, 2)) { + stop_source_.request_stop(); + } + this->notify(); + } +}; + +template +struct local_operation_base : local_operation_result { + + explicit local_operation_base(operation_base& parent) noexcept + : local_operation_result{} + , parent_{parent} { + parent_.results_.push(*this); + } + + local_operation_base(const local_operation_base&) = delete; + local_operation_base& operator=(const local_operation_base&) = delete; + local_operation_base(local_operation_base&&) = delete; + local_operation_base& operator=(local_operation_base&&) = delete; + + operation_base& parent_; +}; + +template +using local_env_t = + exec::make_env_t, + exec::with_t>; + +template struct local_receiver { + using receiver_concept = stdexec::receiver_t; + + auto get_env() const noexcept -> local_env_t { + return exec::make_env( + stdexec::get_env(local_op_->parent_.receiver_), + exec::with(stdexec::get_stop_token, local_op_->parent_.stop_source_.get_token())); + } + + template void set_value(Args&&... result) && noexcept { + if constexpr (sizeof...(Args) > 0) { + local_op_->result_.emplace(std::forward(result)...); + } + local_op_->parent_.notify(); + } + + template void set_error(Error&& error) && noexcept { + local_op_->parent_.notify_error(std::forward(error)); + } + + void set_stopped() && noexcept { local_op_->parent_.notify_stopped(); } + + local_operation_base* local_op_; +}; + +template using nullable_std_variant = std::variant; + +template +using nullable_std_variant_for = + stdexec::__minvoke>, std::exception_ptr, + std::remove_cvref_t...>; + +template struct traits { + using Result = stdexec::__single_sender_value_t; + using ErrorVariant = stdexec::error_types_of_t; +}; + +template +struct local_operation + : local_operation_base>::Result, + typename traits>::ErrorVariant, + Receiver> { + using Result = typename traits>::Result; + using ErrorVariant = typename traits>::ErrorVariant; + + local_operation(operation_base& parent, Sender&& sndr) noexcept + : local_operation_base(parent) + , child_op_(stdexec::connect(std::forward(sndr), + local_receiver(this))) {} + + void start() noexcept { stdexec::start(child_op_); } + + stdexec::connect_result_t> child_op_; +}; + +template +class operation + : public operation_base, + stdexec::env_of_t>::Result, + typename traits, + stdexec::env_of_t>::ErrorVariant, + Receiver> { + public: + using operation_state_concept = stdexec::operation_state_t; + + using Sender = std::ranges::range_value_t; + + using Result = + typename traits, stdexec::env_of_t>::Result; + + using ErrorVariant = typename traits, + stdexec::env_of_t>::ErrorVariant; + + explicit operation(Range range, Receiver receiver) + : operation_base(std::move(receiver)) { + if constexpr (std::ranges::sized_range) { + children_.reserve(std::ranges::size(range)); + } + try { + for (auto&& sndr : range) { + children_.emplace_back(std::in_place, std::forward(sndr)); + } + } catch (...) { + for (auto& variant : children_) { + variant.template destroy<0>(); + } + throw; + } + std::size_t counter = 0; + for (auto& variant : children_) { + try { + auto sndr = std::move(variant).template get<0>(); + variant.template emplace_from<1>( + [&] { return local_operation(*this, std::move(sndr)); }); // NOLINT + } catch (...) { + std::size_t i = 0; + for (; i < counter; ++i) { + children_[i].template destroy<1>(); + } + for (; i < children_.size(); ++i) { + children_[i].template destroy<0>(); + } + throw; + } + counter += 1; + } + this->count_.store(static_cast(counter)); + } + + void start() noexcept { + this->do_start(); + for (auto& child : children_) { + stdexec::start(child.template get<1>()); + } + } + + std::vector>> children_; +}; + +template +using to_std_vector = + stdexec::completion_signatures>)...>; + +template class sender { + public: + using sender_concept = stdexec::sender_t; + + explicit sender(Range range) noexcept(std::is_nothrow_move_constructible_v) + : range_(std::move(range)) {} + + template + requires stdexec::__single_value_sender, + stdexec::env_of_t> + auto connect(this Self&& self, Receiver receiver) -> operation { + return operation(std::forward(self).range_, std::move(receiver)); + } + + template + requires stdexec::__single_value_sender, Env> + auto get_completion_signatures(Env&&) const noexcept + -> stdexec::transform_completion_signatures_of< + std::ranges::range_value_t, Env, + stdexec::completion_signatures> { + return {}; + } + + private: + Range range_; +}; +} // namespace when_all_range_ + +struct when_all_range_t { + template + auto operator()(Range range) const + noexcept(noexcept(when_all_range_::sender{std::move(range)})) + -> when_all_range_::sender { + return when_all_range_::sender{std::move(range)}; + } +}; + +inline constexpr when_all_range_t when_all_range{}; + +} // namespace exec \ No newline at end of file diff --git a/test/exec/CMakeLists.txt b/test/exec/CMakeLists.txt index 7389f2d64..259eb050e 100644 --- a/test/exec/CMakeLists.txt +++ b/test/exec/CMakeLists.txt @@ -38,6 +38,7 @@ set(exec_test_sources async_scope/test_empty.cpp async_scope/test_stop.cpp test_when_any.cpp + test_when_all_range.cpp test_at_coroutine_exit.cpp test_materialize.cpp $<$:test_io_uring_context.cpp> diff --git a/test/exec/test_when_all_range.cpp b/test/exec/test_when_all_range.cpp new file mode 100644 index 000000000..47ae88c07 --- /dev/null +++ b/test/exec/test_when_all_range.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) Maikel Nadolski + * Copyright (c) 2024 NVIDIA Corporation + * + * Licensed under the Apache License Version 2.0 with LLVM Exceptions + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://llvm.org/LICENSE.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +namespace { + + TEST_CASE("when_all_range - sum up an array", "[when_all_range]") { + std::array array{42, 43, 44}; + int sum = 0; + auto sum_up = std::ranges::views::transform(std::ranges::views::all(array), [&sum](int x) { + return stdexec::then(stdexec::just(), [&sum, x]() noexcept { sum += x; }); + }); + auto when_all = exec::when_all_range(sum_up); + stdexec::sync_wait(std::move(when_all)); + CHECK(sum == (42 + 43 + 44)); + } + +} \ No newline at end of file