diff --git a/concurrent.py b/concurrent.py index 0169c37..dfb0fbf 100644 --- a/concurrent.py +++ b/concurrent.py @@ -2,10 +2,11 @@ # pyre-strict +import os import threading import time from collections.abc import Iterator -from queue import Empty +from queue import Empty, Full try: from queue import ShutDown # type: ignore @@ -121,6 +122,10 @@ def iterator_local(self, max_key: int, clear: bool = True) -> Iterator[Any]: # class ConcurrentQueue: """ A thread-safe queue implementation that allows concurrent access and modification. + + Note: + ConcurrentQueue deliberately does not follow the same API as queue.Queue. To get a replacement + for queue.Queue use StdConcurrentQueue. """ _SHUTDOWN = 1 @@ -204,12 +209,16 @@ def pop(self, timeout: float | None = None) -> Any: # type: ignore """ Removes and returns an element from the front of the queue. Args: - timeout (float | None, optional): The maximum time to wait for an element to become available. Defaults to None. + timeout (float | None, optional): The maximum time to wait for an element to become available. + Defaults to None. Returns: Any: The removed element. Raises: Empty: If the queue is empty and the timeout expires. ShutDown: If the queue is shutting down - i.e. shutdown() has been called. + + Note: + Timeout can be 0 but this is not recommended; if you want non-blocking behaviour use StdConcurrentQueue. """ next_key = self._outkey.incr() _flags = LocalWrapper(self._flags) @@ -269,7 +278,7 @@ def pop(self, timeout: float | None = None) -> Any: # type: ignore raise RuntimeError("Queue failed") if timeout is None: _cond.wait() - elif not _cond.wait(timeout): + elif timeout == 0.0 or not _cond.wait(timeout): timed_out = True break if timed_out: @@ -370,14 +379,94 @@ def pop_local(self, timeout: float | None = None) -> LocalWrapper: """ return LocalWrapper(self.pop(timeout)) - def get(self, timeout: float | None = None) -> Any: # type: ignore - """ - An aliase for pop. See the docs for pop(). 
- """ - return self.pop(timeout) - def put(self, value: Any) -> None: # type: ignore - """ - An alias for push(value=Any). - """ - self.push(value) +class StdConcurrentQueue(ConcurrentQueue): + """ + A class which is a drop in replacement for queue.Queue and behaves as a lock free ConcurrentQueue but supports + the features of queue.Queue which ConcurrentQueue does not. These extra features may add some overhead to + operation and so this Queue is only preferred when an exact replacement for queue.Queue is required. + + Also note that there might be subtle differences in the way sequencing behaves in a multi-threaded environment + compared to queue.Queue simply because this is a (mainly) lock free algorithm. + """ + + def __init__(self, maxsize: int = 0) -> None: + osc = os.cpu_count() + if osc: + super().__init__(scaling=max(1, osc // 2), lock_free=True) + else: + super().__init__(lock_free=True) + + self._maxsize: int = max(maxsize, 0) + self._active_tasks = AtomicInt64(0) + + def qsize(self) -> int: + return self.size() + + def get(self, block: bool = True, timeout: float | None = None) -> Any: # type: ignore + if block and timeout != 0.0: + return self.pop(timeout=timeout) + else: + # Use this to attempt to avoid excessive placeholder creation. 
+ if self.size() > 0: + return self.pop(timeout=0.0) + else: + raise Empty + + def full(self) -> bool: + _maxsize = self._maxsize + return bool(_maxsize and self.size() >= _maxsize) + + def put(self, item: Any, block: bool = True, timeout: float | None = None) -> None: # type: ignore + + if block and self._maxsize and self.full(): + _flags = LocalWrapper(self._flags) + _shutdown = self._SHUTDOWN + _sleep = LocalWrapper(time.sleep) + _now = LocalWrapper(time.monotonic) + start = _now() + if timeout is not None: + end_time = start + timeout + else: + end_time = None + pause_time = start + 0.05 + while self.full(): + it_time = _now() + if _flags & _shutdown: + raise ShutDown + if end_time is not None and it_time > end_time: + raise Full + if it_time < pause_time: + _sleep(0) + else: + _sleep(0.05) + else: + if self.full(): + raise Full + + self.push(item) + # The push succeeded so we can do this here. + self._active_tasks.incr() + + def put_nowait(self, item: Any) -> None: # type: ignore + return self.put(item, block=False) + + def get_nowait(self) -> Any: # type: ignore + return self.get(block=False) + + def task_done(self) -> None: + self._active_tasks.decr() + + def join(self) -> None: + _sleep = LocalWrapper(time.sleep) + _now = LocalWrapper(time.monotonic) + _flags = LocalWrapper(self._flags) + _shut_now = self._SHUT_NOW + _active_tasks = LocalWrapper(self._active_tasks) + start = _now() + pause_time = start + 0.05 + while _active_tasks and not (_flags & _shut_now): + if _now() < pause_time: + _sleep(0) + else: + _sleep(0.05) diff --git a/concurrent_queue_bench.py b/concurrent_queue_bench.py index f99b6e4..476f79a 100644 --- a/concurrent_queue_bench.py +++ b/concurrent_queue_bench.py @@ -6,21 +6,26 @@ import queue from ft_utils.benchmark_utils import BenchmarkProvider, execute_benchmarks -from ft_utils.concurrent import ConcurrentQueue +from ft_utils.concurrent import ConcurrentQueue, StdConcurrentQueue from ft_utils.local import LocalWrapper +ConcurrentQueue.put 
= ConcurrentQueue.push # type: ignore +ConcurrentQueue.get = ConcurrentQueue.pop # type: ignore + class ConcurretQueueBenchmarkProvider(BenchmarkProvider): def __init__(self, operations: int) -> None: self._operations = operations self._queue: ConcurrentQueue | None = None self._queue_lf: ConcurrentQueue | None = None - self._queue_std: queue.Queue | None = None # type: ignore + self._queue_queue: queue.Queue | None = None # type: ignore + self._queue_std: StdConcurrentQueue | None = None # type: ignore def set_up(self) -> None: self._queue = ConcurrentQueue(os.cpu_count()) self._queue_lf = ConcurrentQueue(os.cpu_count(), lock_free=True) - self._queue_std = queue.Queue() + self._queue_queue = queue.Queue() + self._queue_std = StdConcurrentQueue() def benchmark_locked(self) -> None: lw = LocalWrapper(self._queue) @@ -34,6 +39,10 @@ def benchmark_std(self) -> None: lw = LocalWrapper(self._queue_std) self._bm(lw) + def benchmark_queue(self) -> None: + lw = LocalWrapper(self._queue_queue) + self._bm(lw) + def _bm(self, lw) -> None: # type: ignore for n in range(self._operations): lw.put(n) @@ -51,6 +60,10 @@ def benchmark_std_batch(self) -> None: lw = LocalWrapper(self._queue_std) self._bmb(lw) + def benchmark_queue_batch(self) -> None: + lw = LocalWrapper(self._queue_queue) + self._bmb(lw) + def _bmb(self, lw) -> None: # type: ignore for n in range(self._operations // 100): for _ in range(100): diff --git a/docs/concurrent_api.md b/docs/concurrent_api.md index 97c12ee..757994b 100644 --- a/docs/concurrent_api.md +++ b/docs/concurrent_api.md @@ -221,9 +221,7 @@ A concurrent queue that allows multiple threads to push and pop values. * `__init__(scaling=None, lock_free=False)`: Initializes a new ConcurrentQueue with the specified scaling factor. If `lock_free` is True, the queue will use a lock-free implementation, which can improve performance in certain scenarios. * `push(value)`: Pushes a value onto the queue. 
This method is thread-safe and can be called from multiple threads.
-* `put(value)`: An alias for `push(value)`.
* `pop(timeout=None)`: Pops a value from the queue. The method will block until a value is available. If `timeout` is specified, the method will raise an Empty exception if no value is available within the specified time.
-* `get(timeout=None)`: An alias for `pop(timeout)`.
* `pop_local(timeout=None)`: Returns a LocalWrapper object containing the popped value. The behavior is otherwise identical to `pop(timeout)`.
* `shutdown(immediate=False)`: Initiates shutdown of the queue. If `immediate` is True, the queue will shut down immediately, otherwise it will wait for any pending operations to complete.
* `size()`: Returns the number of elements currently in the queue.
@@ -275,3 +273,14 @@
queue.shutdown()
# Raises ShutDown
queue.pop()
```
+
+## StdConcurrentQueue
+
+This follows the same API as [queue.Queue](https://docs.python.org/3/library/queue.html#queue.Queue). For simple applications StdConcurrentQueue will work as a drop-in replacement for queue.Queue. However, there are subtle differences:
+
+* StdConcurrentQueue will use a very small amount of CPU time even when not processing elements.
+* This implementation has weaker FIFO guarantees than queue.Queue which might cause subtle issues in some applications.
+* StdConcurrentQueue will release memory in a different pattern than queue.Queue.
+* The maxsize is not as strictly guaranteed. If maxsize is set and a large number of threads attempt to fill the queue beyond maxsize then a small overfill might occur due to the lack of a lock to prevent this race condition.
+
+Therefore, in complex applications it may be a better approach to mindfully replace highly contended queue.Queue instances with StdConcurrentQueue. In this case it is also better to use the simpler ConcurrentQueue where possible. 
diff --git a/test_concurrent.py b/test_concurrent.py index bd3056a..200e15f 100644 --- a/test_concurrent.py +++ b/test_concurrent.py @@ -402,23 +402,6 @@ def worker(): t.start() self.assertEqual(q.pop(), 10) - def test_pop_timeout_expires(self): - q = self._get_queue() - f = concurrent.AtomicFlag(False) - - def worker(): - f.set(True) - time.sleep(1) - q.push(10) - - t = threading.Thread(target=worker) - t.start() - while not f: - pass - with self.assertRaises(queue.Empty): - q.pop(timeout=0.5) - t.join() - def test_pop(self): q = self._get_queue() @@ -431,12 +414,7 @@ def worker(): self.assertEqual(q.pop(), 10) t.join() - def test_get(self): - q = self._get_queue() - q.push(10) - self.assertEqual(q.get(), 10) - - def test_get_timeout(self): + def test_pop_timeout_sleep(self): q = self._get_queue() f = concurrent.AtomicFlag(False) @@ -449,10 +427,10 @@ def worker(): t.start() while not f: pass - self.assertEqual(q.get(timeout=1), 10) + self.assertEqual(q.pop(timeout=1), 10) t.join() - def test_get_timeout_expires(self): + def test_pop_timeout_expires(self): q = self._get_queue() f = concurrent.AtomicFlag(False) @@ -466,10 +444,10 @@ def worker(): while not f: pass with self.assertRaises(queue.Empty): - q.get(timeout=0.1) + q.pop(timeout=0.1) t.join() - def test_get_waiting(self): + def test_pop_waiting(self): q = self._get_queue() def worker(): @@ -478,7 +456,7 @@ def worker(): t = threading.Thread(target=worker) t.start() - self.assertEqual(q.get(), 10) + self.assertEqual(q.pop(), 10) t.join() def test_shutdown(self): @@ -487,7 +465,7 @@ def test_shutdown(self): q.shutdown() with self.assertRaises(concurrent.ShutDown): q.push(20) - self.assertEqual(q.get(), 10) + self.assertEqual(q.pop(), 10) with self.assertRaises(concurrent.ShutDown): q.pop() @@ -581,6 +559,174 @@ def _get_queue(self): return concurrent.ConcurrentQueue(lock_free=True) +class TestStdConcurrentQueue(unittest.TestCase): + + def _get_queue(self, maxsize=0): + return 
concurrent.StdConcurrentQueue(maxsize) + + def test_smoke(self): + q = self._get_queue() + q.put(10) + self.assertEqual(q.get(), 10) + + def test_multiple_put(self): + q = self._get_queue() + for i in range(10): + q.put(i) + for i in range(10): + self.assertEqual(q.get(), i) + + def test_multiple_threads(self): + q = self._get_queue() + flag = concurrent.AtomicFlag(False) + + def worker(n): + flag.set(True) + for i in range(n): + q.put(i) + + threads = [threading.Thread(target=worker, args=(10,)) for _ in range(10)] + for t in threads: + t.start() + while not flag: + pass + for t in threads: + t.join() + for _ in range(100): + x = q.get() + self.assertIn(x, list(range(10))) + + def test_get_timeout(self): + q = self._get_queue() + flag = concurrent.AtomicFlag(False) + + def worker(): + flag.set(True) + time.sleep(0.1) + q.put(10) + + t = threading.Thread(target=worker) + t.start() + while not flag: + pass + self.assertEqual(q.get(timeout=1), 10) + t.join() + + def test_get_timeout_expires(self): + q = self._get_queue() + flag = concurrent.AtomicFlag(False) + + def worker(): + flag.set(True) + time.sleep(0.5) + q.put(10) + + t = threading.Thread(target=worker) + t.start() + while not flag: + pass + with self.assertRaises(queue.Empty): + q.get(timeout=0.1) + t.join() + + def test_get_waiting(self): + q = self._get_queue() + flag = concurrent.AtomicFlag(False) + + def worker(): + flag.set(True) + time.sleep(0.1) + q.put(10) + + t = threading.Thread(target=worker) + t.start() + while not flag: + pass + self.assertEqual(q.get(), 10) + t.join() + + def test_put_nowait(self): + q = self._get_queue(maxsize=1) + q.put_nowait(10) + with self.assertRaises(queue.Full): + q.put_nowait(20) + + def test_get_nowait(self): + q = self._get_queue() + q.put(10) + self.assertEqual(q.get_nowait(), 10) + with self.assertRaises(queue.Empty): + q.get_nowait() + + def test_empty_queue(self): + q = self._get_queue() + flag = concurrent.AtomicFlag(False) + + def worker(): + flag.set(True) + 
time.sleep(0.1) + q.put(10) + + for _ in range(5): + t = threading.Thread(target=worker) + t.start() + while not flag: + pass + self.assertEqual(q.get(), 10) + + def test_qsize(self): + q = self._get_queue() + self.assertEqual(q.qsize(), 0) + q.put(10) + self.assertEqual(q.qsize(), 1) + q.get() + self.assertEqual(q.qsize(), 0) + + def test_full(self): + q = self._get_queue(maxsize=1) + self.assertFalse(q.full()) + q.put(10) + self.assertEqual(q.size(), 1) + self.assertEqual(q._maxsize, 1) + self.assertTrue(q.full()) + + def test_task_done(self): + q = self._get_queue() + q.put(10) + self.assertEqual(10, q.get()) + q.task_done() + self.assertEqual(int(q._active_tasks), 0) + q.join() + + def test_join(self): + q = self._get_queue() + + def worker(): + q.get() + q.task_done() + + ts = [threading.Thread(target=worker) for _ in range(10)] + for t in ts: + t.start() + q.put(10) + q.join() + t.join() + self.assertEqual(int(q._active_tasks), 0) + + def test_full_shutdown(self): + q = self._get_queue(1) + q.put(23) + + def worker(): + q.shutdown() + q.get() + + t = threading.Thread(target=worker) + t.start() + with self.assertRaises(concurrent.ShutDown): + q.put(32) + + class TestConcurrentGatheringIterator(unittest.TestCase): def test_smoke(self): iterator = concurrent.ConcurrentGatheringIterator()