diff --git a/.gitignore b/.gitignore index 4581ef2..b3f92b8 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,6 @@ *.exe *.out *.app + +# Vim +*.swp diff --git a/LICENSE b/LICENSE index 8dada3e..29c69bc 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright {yyyy} {name of copyright owner} + Copyright 2016 Regents of the University of Michigan Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/channel.cc b/channel.cc new file mode 100644 index 0000000..aef0608 --- /dev/null +++ b/channel.cc @@ -0,0 +1,40 @@ +// Copyright 2016 Regents of the University of Michigan +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "channel.h" + +using namespace std; + +namespace dadrian { + +WaitGroup::WaitGroup() : m_counter(0) {} + +void WaitGroup::add(size_t delta) { + m_counter += delta; +} + +void WaitGroup::done() { + auto previous = m_counter.fetch_sub(1); + if (previous == 1) { + // Current is now zero, wake everything up + m_cv.notify_all(); + } +} + +void WaitGroup::wait() { + unique_lock lock(m_mutex); + m_cv.wait(lock, [&]() { return m_counter == 0; }); +} + +} // namespace dadrian diff --git a/channel.h b/channel.h new file mode 100644 index 0000000..264b9dc --- /dev/null +++ b/channel.h @@ -0,0 +1,157 @@ +// Copyright 2016 Regents of the University of Michigan +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DADRIAN_CHANNEL_H +#define DADRIAN_CHANNEL_H + +#include +#include +#include +#include +#include + +namespace dadrian { + +template +class Channel { + private: + std::queue m_queue; + mutable std::mutex m_queue_mutex; + mutable std::condition_variable m_queue_not_full; + mutable std::condition_variable m_queue_not_empty; + size_t m_max_size; + + std::atomic m_closed; + + friend class ChannelIterator; + + public: + class ChannelIterator { + public: + using value_type = T; + + private: + bool m_valid; + Channel& m_channel; + value_type m_value; + + friend class Channel; + + ChannelIterator(Channel& channel) : m_channel(channel), m_valid(true) { + m_channel.recv(*this); + } + + public: + ChannelIterator(const ChannelIterator& other) = delete; + ChannelIterator(ChannelIterator&& other) + : m_channel(other.m_channel), + m_value(std::move(other.m_value)), + m_valid(other.m_valid) { + other.m_valid = false; + } + + bool operator==(const ChannelIterator& other) const { + return this == &other; + } + + bool operator!=(const ChannelIterator& other) const { + return !(*this == other); + } + + const value_type& operator*() const { return m_value; } + + value_type& operator*() { return m_value; } + + value_type* operator->() { return &m_value; } + + const value_type* operator->() const { return &m_value; } + + ChannelIterator& operator++() { + m_channel.recv(*this); + return *this; + } + + ChannelIterator operator++(int) = delete; + + inline bool valid() const { return m_valid; } + }; + + using iterator = ChannelIterator; + + Channel() { + m_max_size = 1024; + m_closed.store(false); + } + + void close() { + std::lock_guard lock(m_queue_mutex); + m_closed.store(true); + m_queue_not_full.notify_all(); + m_queue_not_empty.notify_all(); + } + + void send(T&& elt) { + std::unique_lock lock(m_queue_mutex); + m_queue_not_full.wait(lock, [&]() { + return (m_queue.size() < m_max_size) || m_closed; + }); + std::lock_guard guard(*lock.release(), std::adopt_lock); + if (m_closed) { + throw std::bad_function_call(); + } + m_queue.push(std::move(elt)); + m_queue_not_empty.notify_one(); + } + + iterator range() { + ChannelIterator it(*this); + return it; + } + + private: + void recv(ChannelIterator& it) { + std::unique_lock lock(m_queue_mutex); + m_queue_not_empty.wait( + lock, [&]() { return (m_queue.size() > 0) || m_closed; }); + std::lock_guard guard(*lock.release(), std::adopt_lock); + if (m_queue.size() > 0) { + it.m_value = std::move(m_queue.front()); + m_queue.pop(); + m_queue_not_full.notify_one(); + } else if (m_closed) { + it.m_valid = false; + } else { + assert(false); + } + } +}; + +class WaitGroup { + private: + std::atomic_uint_fast64_t m_counter; + std::mutex m_mutex; + std::condition_variable m_cv; + + public: + WaitGroup(); + WaitGroup(const WaitGroup&) = delete; + WaitGroup(WaitGroup&&) = delete; + void add(size_t delta); + void done(); + void wait(); +}; + +} // namespace dadrian + +#endif /* DADRIAN_CHANNEL_H */