diff --git a/strawberry/promise/__init__.py b/strawberry/promise/__init__.py new file mode 100644 index 0000000000..ea5cfe9752 --- /dev/null +++ b/strawberry/promise/__init__.py @@ -0,0 +1,7 @@ +from .promise import Promise, is_thenable + + +__all__ = [ + "Promise", + "is_thenable", +] diff --git a/strawberry/promise/async_.py b/strawberry/promise/async_.py new file mode 100644 index 0000000000..aec516676e --- /dev/null +++ b/strawberry/promise/async_.py @@ -0,0 +1,159 @@ +# Copy of the Promise library (https://github.com/syrusakbary/promise) with some +# modifications. +# +# Promise is licensed under the terms of the MIT license, reproduced below. +# +# = = = = = +# +# The MIT License (MIT) +# +# Copyright (c) 2016 Syrus Akbary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Based on https://github.com/petkaantonov/bluebird/blob/master/src/promise.js +from collections import deque +from threading import local +from typing import Any, Callable, Optional + + +if False: + from .promise import Promise + + +class Async(local): + def __init__(self, trampoline_enabled=True): + self.is_tick_used = False + self.late_queue = deque() # type: ignore + self.normal_queue = deque() # type: ignore + self.have_drained_queues = False + self.trampoline_enabled = trampoline_enabled + + def enable_trampoline(self): + self.trampoline_enabled = True + + def disable_trampoline(self): + self.trampoline_enabled = False + + def have_items_queued(self): + return self.is_tick_used or self.have_drained_queues + + def _async_invoke_later(self, fn, scheduler): + self.late_queue.append(fn) + self.queue_tick(scheduler) + + def _async_invoke(self, fn, scheduler): + # type: (Callable, Any) -> None + self.normal_queue.append(fn) + self.queue_tick(scheduler) + + def _async_settle_promise(self, promise): + # type: (Promise) -> None + self.normal_queue.append(promise) + self.queue_tick(promise.scheduler) + + def invoke(self, fn, scheduler): + # type: (Callable, Any) -> None + if self.trampoline_enabled: + self._async_invoke(fn, scheduler) + else: + scheduler.call(fn) + + def settle_promises(self, promise): + # type: (Promise) -> None + if self.trampoline_enabled: + self._async_settle_promise(promise) + else: + promise.scheduler.call(promise._settle_promises) + + def throw_later(self, reason, scheduler): + # type: (Exception, Any) -> None + def fn(): + # type: () -> None + raise reason + + scheduler.call(fn) + + fatal_error = throw_later + + def drain_queue(self, queue): + # type: (deque) -> None + from .promise import Promise + + while queue: + fn = queue.popleft() + if isinstance(fn, Promise): + fn._settle_promises() + continue + fn() + + def drain_queue_until_resolved(self, promise): + # type: (Promise) -> None + from .promise import Promise + + queue = self.normal_queue + while queue: + if not promise.is_pending: + return + fn = queue.popleft() + if isinstance(fn, Promise): + fn._settle_promises() + continue + fn() + + self.reset() + self.have_drained_queues = True + self.drain_queue(self.late_queue) + + def wait(self, promise, timeout=None): + # type: (Promise, Optional[float]) -> None + if not promise.is_pending: + # We return if the promise is already + # fulfilled or rejected + return + + target = promise._target() + + if self.trampoline_enabled: + if self.is_tick_used: + self.drain_queue_until_resolved(target) + + if not promise.is_pending: + # We return if the promise is already + # fulfilled or rejected + return + target.scheduler.wait(target, timeout) + + def drain_queues(self): + # type: () -> None + assert self.is_tick_used + self.drain_queue(self.normal_queue) + self.reset() + self.have_drained_queues = True + self.drain_queue(self.late_queue) + + def queue_tick(self, scheduler): + # type: (Any) -> None + if not self.is_tick_used: + self.is_tick_used = True + scheduler.call(self.drain_queues) + + def reset(self): + # type: () -> None + self.is_tick_used = False diff --git a/strawberry/promise/promise.py b/strawberry/promise/promise.py new file mode 100644 index 0000000000..aaa3050375 --- /dev/null +++ b/strawberry/promise/promise.py @@ -0,0 +1,819 @@ +# Copy of the Promise library (https://github.com/syrusakbary/promise) with some +# modifications. +# +# Promise is licensed under the terms of the MIT license, reproduced below. +# +# = = = = = +# +# The MIT License (MIT) +# +# Copyright (c) 2016 Syrus Akbary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from asyncio import Future, ensure_future +from functools import partial, wraps +from inspect import iscoroutine +from sys import exc_info +from threading import Event, Lock +from types import TracebackType +from typing import ( + Any, + Callable, + Dict, + Generic, + Hashable, + Iterator, + List, + MutableMapping, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) +from weakref import WeakKeyDictionary + +from six import reraise + +from .async_ import Async +from .promise_list import PromiseList + + +class ImmediateScheduler(object): + def call(self, fn): + # type: (Callable) -> None + try: + fn() + except Exception: + pass + + def wait(self, promise, timeout=None): + # type: (Promise, Optional[float]) -> None + e = Event() + + def on_resolve_or_reject(_): + # type: (Any) -> None + e.set() + + promise._then(on_resolve_or_reject, on_resolve_or_reject) + waited = e.wait(timeout) + if not waited: + raise Exception("Timeout") + + +def iterate_promise(promise: "Promise") -> Iterator: + if not promise.is_fulfilled: + yield from promise.future # type: ignore + assert promise.is_fulfilled + return promise.get() + + +default_scheduler = ImmediateScheduler() + +async_instance = Async() + +_state_lock = Lock() + +DEFAULT_TIMEOUT = None # type: Optional[float] + +MAX_LENGTH = 0xFFFF | 0 +CALLBACK_SIZE = 3 + +CALLBACK_FULFILL_OFFSET = 0 +CALLBACK_REJECT_OFFSET = 1 +CALLBACK_PROMISE_OFFSET = 2 + +BASE_TYPES = (str, int, bool, float, complex, tuple, list, dict, bytes) + +# These are the potential states of a promise +STATE_PENDING = -1 +STATE_REJECTED = 0 +STATE_FULFILLED = 1 + + +def make_self_resolution_error() -> TypeError: + return TypeError("Promise is self") + + +def try_catch( + handler: Callable, *args: Any, **kwargs: Any +) -> Union[Tuple[Any, None], Tuple[None, Tuple[Exception, Optional[TracebackType]]]]: + try: + return (handler(*args, **kwargs), None) + except Exception as e: + tb = exc_info()[2] + return (None, (e, tb)) + + +T = TypeVar("T") +S = TypeVar("S", contravariant=True) + + +class Promise(Generic[T]): + """ + This is the Promise class that complies + Promises/A+ specification. + """ + + _state = STATE_PENDING # type: int + _is_final = False + _is_bound = False + _is_following = False + _is_async_guaranteed = False + _length = 0 + _handlers = None # type: Dict[int, Union[Callable, Promise, None]] + _fulfillment_handler0 = None # type: Any + _rejection_handler0 = None # type: Any + _promise0 = None # type: Optional[Promise] + _future = None # type: Future + _traceback = None # type: Optional[TracebackType] + _is_waiting = False + _scheduler = None + + def __init__(self, executor=None, scheduler=None): + # type: (Optional[Callable[[Callable[[T], None], Callable[[Exception], None]], None]], Any) -> None + """ + Initialize the Promise into a pending state. + """ + self._scheduler = scheduler + + if executor is not None: + self._resolve_from_executor(executor) + + @property + def scheduler(self) -> ImmediateScheduler: + return self._scheduler or default_scheduler + + @property + def future(self) -> Future: + if not self._future: + self._future = Future() # type: ignore + self._then( # type: ignore + self._future.set_result, self._future.set_exception + ) + return self._future + + def __iter__(self) -> Iterator: + return iterate_promise(self._target()) + + __await__ = __iter__ + + def _resolve_callback(self, value): + # type: (T) -> None + if value is self: + return self._reject_callback(make_self_resolution_error(), False) + + if not self.is_thenable(value): + return self._fulfill(value) + + promise = self._try_convert_to_promise(value)._target() + if promise == self: + self._reject(make_self_resolution_error()) + return + + if promise._state == STATE_PENDING: + len = self._length + if len > 0: + promise._migrate_callback0(self) + for i in range(1, len): + promise._migrate_callback_at(self, i) + + self._is_following = True + self._length = 0 + self._set_followee(promise) + elif promise._state == STATE_FULFILLED: + self._fulfill(promise._value()) + elif promise._state == STATE_REJECTED: + self._reject(promise._reason(), promise._target()._traceback) + + def _settled_value(self, _raise=False): + # type: (bool) -> Any + assert not self._is_following + + if self._state == STATE_FULFILLED: + return self._rejection_handler0 + elif self._state == STATE_REJECTED: + if _raise: + raise_val = self._fulfillment_handler0 + reraise(type(raise_val), raise_val, self._traceback) + return self._fulfillment_handler0 + + def _fulfill(self, value: T) -> None: + if value is self: + err = make_self_resolution_error() + return self._reject(err) + with _state_lock: + self._state = STATE_FULFILLED + self._rejection_handler0 = value + + if self._length > 0: + if self._is_async_guaranteed: + self._settle_promises() + else: + async_instance.settle_promises(self) + + def _reject(self, reason, traceback=None): + # type: (Exception, Optional[TracebackType]) -> None + with _state_lock: + self._state = STATE_REJECTED + self._fulfillment_handler0 = reason + self._traceback = traceback + + if self._is_final: + assert self._length == 0 + async_instance.fatal_error(reason, self.scheduler) + return + + if self._length > 0: + async_instance.settle_promises(self) + else: + self._ensure_possible_rejection_handled() + + if self._is_async_guaranteed: + self._settle_promises() + else: + async_instance.settle_promises(self) + + def _ensure_possible_rejection_handled(self) -> None: + pass + + def _reject_callback( + self, + reason: Exception, + synchronous: bool = False, + traceback: Optional[TracebackType] = None, + ) -> None: + assert isinstance( + reason, Exception + ), "A promise was rejected with a non-error: {}".format(reason) + self._reject(reason, traceback) + + def _clear_callback_data_index_at(self, index: int) -> None: + assert not self._is_following + assert index > 0 + base = index * CALLBACK_SIZE - CALLBACK_SIZE + self._handlers[base + CALLBACK_PROMISE_OFFSET] = None + self._handlers[base + CALLBACK_FULFILL_OFFSET] = None + self._handlers[base + CALLBACK_REJECT_OFFSET] = None + + def _fulfill_promises(self, length: int, value: T) -> None: + for i in range(1, length): + handler = self._fulfillment_handler_at(i) + promise = self._promise_at(i) + self._clear_callback_data_index_at(i) + self._settle_promise(promise, handler, value, None) + + def _reject_promises(self, length: int, reason: Exception) -> None: + for i in range(1, length): + handler = self._rejection_handler_at(i) + promise = self._promise_at(i) + self._clear_callback_data_index_at(i) + self._settle_promise(promise, handler, reason, None) + + def _settle_promise( + self, + promise: Optional["Promise"], + handler: Optional[Callable], + value: Union[T, Exception], + traceback: Optional[TracebackType], + ) -> None: + assert not self._is_following + is_promise = isinstance(promise, self.__class__) + async_guaranteed = self._is_async_guaranteed + if callable(handler): + if promise is None or not is_promise: + handler(value) + else: + if async_guaranteed: + promise._is_async_guaranteed = True + self._settle_promise_from_handler(handler, value, promise) + elif promise is not None and is_promise: + if async_guaranteed: + promise._is_async_guaranteed = True + if self._state == STATE_FULFILLED: + promise._fulfill(value) + else: + promise._reject(cast(Exception, value), self._traceback) + + def _settle_promise0( + self, + handler: Optional[Callable], + value: Any, + traceback: Optional[TracebackType], + ) -> None: + promise = self._promise0 + self._promise0 = None + self._settle_promise(promise, handler, value, traceback) # type: ignore + + def _settle_promise_from_handler(self, handler, value, promise): + # type: (Callable, Any, Promise) -> None + value, error_with_tb = try_catch(handler, value) # , promise + + if error_with_tb: + error, tb = error_with_tb + promise._reject_callback(error, False, tb) + else: + promise._resolve_callback(value) + + def _promise_at(self, index): + # type: (int) -> Optional[Promise] + assert index > 0 + assert not self._is_following + return self._handlers.get( # type: ignore + index * CALLBACK_SIZE - CALLBACK_SIZE + CALLBACK_PROMISE_OFFSET + ) + + def _fulfillment_handler_at(self, index): + # type: (int) -> Optional[Callable] + assert not self._is_following + assert index > 0 + return self._handlers.get( # type: ignore + index * CALLBACK_SIZE - CALLBACK_SIZE + CALLBACK_FULFILL_OFFSET + ) + + def _rejection_handler_at(self, index): + # type: (int) -> Optional[Callable] + assert not self._is_following + assert index > 0 + return self._handlers.get( # type: ignore + index * CALLBACK_SIZE - CALLBACK_SIZE + CALLBACK_REJECT_OFFSET + ) + + def _migrate_callback0(self, follower): + # type: (Promise) -> None + self._add_callbacks( + follower._fulfillment_handler0, + follower._rejection_handler0, + follower._promise0, + ) + + def _migrate_callback_at(self, follower, index): + self._add_callbacks( + follower._fulfillment_handler_at(index), + follower._rejection_handler_at(index), + follower._promise_at(index), + ) + + def _add_callbacks( + self, + fulfill, # type: Optional[Callable] + reject, # type: Optional[Callable] + promise, # type: Optional[Promise] + ): + # type: (...) -> int + assert not self._is_following + + if self._handlers is None: + self._handlers = {} + + index = self._length + if index > MAX_LENGTH - CALLBACK_SIZE: + index = 0 + self._length = 0 + + if index == 0: + assert not self._promise0 + assert not self._fulfillment_handler0 + assert not self._rejection_handler0 + + self._promise0 = promise + if callable(fulfill): + self._fulfillment_handler0 = fulfill + if callable(reject): + self._rejection_handler0 = reject + + else: + base = index * CALLBACK_SIZE - CALLBACK_SIZE + + assert (base + CALLBACK_PROMISE_OFFSET) not in self._handlers + assert (base + CALLBACK_FULFILL_OFFSET) not in self._handlers + assert (base + CALLBACK_REJECT_OFFSET) not in self._handlers + + self._handlers[base + CALLBACK_PROMISE_OFFSET] = promise + if callable(fulfill): + self._handlers[base + CALLBACK_FULFILL_OFFSET] = fulfill + if callable(reject): + self._handlers[base + CALLBACK_REJECT_OFFSET] = reject + + self._length = index + 1 + return index + + def _target(self): + # type: () -> Promise + ret = self + while ret._is_following: + ret = ret._followee() + return ret + + def _followee(self): + # type: () -> Promise + assert self._is_following + assert isinstance(self._rejection_handler0, Promise) + return self._rejection_handler0 + + def _set_followee(self, promise): + # type: (Promise) -> None + assert self._is_following + assert not isinstance(self._rejection_handler0, Promise) + self._rejection_handler0 = promise + + def _settle_promises(self): + # type: () -> None + length = self._length + if length > 0: + if self._state == STATE_REJECTED: + reason = self._fulfillment_handler0 + traceback = self._traceback + self._settle_promise0(self._rejection_handler0, reason, traceback) + self._reject_promises(length, reason) + else: + value = self._rejection_handler0 + self._settle_promise0(self._fulfillment_handler0, value, None) + self._fulfill_promises(length, value) + + self._length = 0 + + def _resolve_from_executor(self, executor): + # type: (Callable[[Callable[[T], None], Callable[[Exception], None]], None]) -> None + synchronous = True + + def resolve(value: T) -> None: + self._resolve_callback(value) + + def reject( + reason: Exception, traceback: Optional[TracebackType] = None + ) -> None: + self._reject_callback(reason, synchronous, traceback) + + error = None + traceback = None + try: + executor(resolve, reject) + except Exception as e: + traceback = exc_info()[2] + error = e + + synchronous = False + + if error is not None: + self._reject_callback(error, True, traceback) + + @classmethod + def wait(cls, promise, timeout=None): + # type: (Promise, Optional[float]) -> None + async_instance.wait(promise, timeout) + + def _wait(self, timeout=None): + # type: (Optional[float]) -> None + self.wait(self, timeout) + + def get(self, timeout=None): + # type: (Optional[float]) -> T + self._wait(timeout or DEFAULT_TIMEOUT) + return self._target_settled_value(_raise=True) + + def _target_settled_value(self, _raise=False): + # type: (bool) -> Any + with _state_lock: + return self._target()._settled_value(_raise) + + _value = _reason = _target_settled_value + value = reason = property(_target_settled_value) + + def __repr__(self): + # type: () -> str + hex_id = hex(id(self)) + if self._is_following: + return "".format(hex_id, self._target()) + state = self._state + if state == STATE_PENDING: + return "".format(hex_id) + elif state == STATE_FULFILLED: + return "".format( + hex_id, repr(self._rejection_handler0) + ) + elif state == STATE_REJECTED: + return "".format( + hex_id, repr(self._fulfillment_handler0) + ) + + return "" + + @property + def is_pending(self): + # type: (Promise) -> bool + """Indicate whether the Promise is still pending. Could be wrong the moment the function returns.""" + return self._target()._state == STATE_PENDING + + @property + def is_fulfilled(self): + # type: (Promise) -> bool + """Indicate whether the Promise has been fulfilled. Could be wrong the moment the function returns.""" + return self._target()._state == STATE_FULFILLED + + @property + def is_rejected(self): + # type: (Promise) -> bool + """Indicate whether the Promise has been rejected. Could be wrong the moment the function returns.""" + return self._target()._state == STATE_REJECTED + + def catch(self, on_rejection): + # type: (Promise, Callable[[Exception], Any]) -> Promise + """ + This method returns a Promise and deals with rejected cases only. + It behaves the same as calling Promise.then(None, on_rejection). + """ + return self.then(None, on_rejection) + + def _then( + self, + did_fulfill=None, # type: Optional[Callable[[T], S]] + did_reject=None, # type: Optional[Callable[[Exception], S]] + ) -> "Promise[S]": + promise = self.__class__() # type: Promise + with _state_lock: + target = self._target() + state = target._state + if state == STATE_PENDING: + target._add_callbacks(did_fulfill, did_reject, promise) + + if state != STATE_PENDING: + traceback = None + if state == STATE_FULFILLED: + value = target._rejection_handler0 + handler = did_fulfill + elif state == STATE_REJECTED: + value = target._fulfillment_handler0 + traceback = target._traceback + handler = did_reject # type: ignore + async_instance.invoke( + partial(target._settle_promise, promise, handler, value, traceback), + promise.scheduler + # target._settle_promise instead? + # settler, + # target, + ) + + return promise + + fulfill = _resolve_callback + do_resolve = _resolve_callback + do_reject = _reject_callback + + def then(self, did_fulfill=None, did_reject=None): + # type: (Promise, Callable[[T], S], Optional[Callable[[Exception], S]]) -> Promise[S] + """ + This method takes two optional arguments. The first argument + is used if the "self promise" is fulfilled and the other is + used if the "self promise" is rejected. In either case, this + method returns another promise that effectively represents + the result of either the first of the second argument (in the + case that the "self promise" is fulfilled or rejected, + respectively). + Each argument can be either: + * None - Meaning no action is taken + * A function - which will be called with either the value + of the "self promise" or the reason for rejection of + the "self promise". The function may return: + * A value - which will be used to fulfill the promise + returned by this method. + * A promise - which, when fulfilled or rejected, will + cascade its value or reason to the promise returned + by this method. + * A value - which will be assigned as either the value + or the reason for the promise returned by this method + when the "self promise" is either fulfilled or rejected, + respectively. + :type success: (Any) -> object + :type failure: (Any) -> object + :rtype : Promise + """ + return self._then(did_fulfill, did_reject) + + def done(self, did_fulfill=None, did_reject=None): + # type: (Optional[Callable], Optional[Callable]) -> None + promise = self._then(did_fulfill, did_reject) + promise._is_final = True + + def done_all(self, handlers=None): + # type: (Promise, Optional[List[Union[Dict[str, Optional[Callable]], Tuple[Callable, Callable], Callable]]]) -> None + """ + :type handlers: list[(Any) -> object] | list[((Any) -> object, (Any) -> object)] + """ + if not handlers: + return + + for handler in handlers: + if isinstance(handler, tuple): + s, f = handler + self.done(s, f) + elif isinstance(handler, dict): + s = handler.get("success") # type: ignore + f = handler.get("failure") # type: ignore + + self.done(s, f) + else: + self.done(handler) + + def then_all(self, handlers=None): + # type: (Promise, List[Callable]) -> List[Promise] + """ + Utility function which calls 'then' for each handler provided. Handler can either + be a function in which case it is used as success handler, or a tuple containing + the success and the failure handler, where each of them could be None. + :type handlers: list[(Any) -> object] | list[((Any) -> object, (Any) -> object)] + :param handlers + :rtype : list[Promise] + """ + if not handlers: + return [] + + promises = [] # type: List[Promise] + + for handler in handlers: + if isinstance(handler, tuple): + s, f = handler + + promises.append(self.then(s, f)) + elif isinstance(handler, dict): + s = handler.get("success") + f = handler.get("failure") + + promises.append(self.then(s, f)) + else: + promises.append(self.then(handler)) + + return promises + + @classmethod + def _try_convert_to_promise(cls, obj): + # type: (Any) -> Promise + _type = obj.__class__ + if issubclass(_type, Promise): + if cls is not Promise: + return cls(obj.then, obj._scheduler) + return obj + + if iscoroutine(obj): # type: ignore + obj = ensure_future(obj) # type: ignore + _type = obj.__class__ + + if is_future_like(_type): + + def executor(resolve, reject): + # type: (Callable, Callable) -> None + if obj.done(): + _process_future_result(resolve, reject)(obj) + else: + obj.add_done_callback(_process_future_result(resolve, reject)) + + promise = cls(executor) # type: Promise + promise._future = obj + return promise + + return obj + + @classmethod + def reject(cls, reason): + # type: (Exception) -> Promise + ret = cls() # type: Promise + ret._reject_callback(reason, True) + return ret + + rejected = reject + + @classmethod + def resolve(cls, obj): + # type: (T) -> Promise[T] + if not cls.is_thenable(obj): + ret = cls() # type: Promise + ret._state = STATE_FULFILLED + ret._rejection_handler0 = obj + return ret + + return cls._try_convert_to_promise(obj) + + cast = resolve + fulfilled = cast + + @classmethod + def promisify(cls, f): + # type: (Callable) -> Callable[..., Promise] + + @wraps(f) + def wrapper(*args, **kwargs): + # type: (*Any, **Any) -> Promise + def executor(resolve, reject): + # type: (Callable, Callable) -> Optional[Any] + return resolve(f(*args, **kwargs)) + + return cls(executor) + + return wrapper + + _safe_resolved_promise = None # type: Promise + + @classmethod + def safe(cls, fn): + # type: (Callable) -> Callable + from functools import wraps + + if not cls._safe_resolved_promise: + cls._safe_resolved_promise = Promise.resolve(None) + + @wraps(fn) + def wrapper(*args, **kwargs): + # type: (*Any, **Any) -> Promise + return cls._safe_resolved_promise.then(lambda v: fn(*args, **kwargs)) + + return wrapper + + @classmethod + def all(cls, promises): + # type: (Any) -> Promise + return PromiseList(promises, promise_class=cls).promise + + @classmethod + def for_dict(cls, m): + # type: (Dict[Hashable, Promise[S]]) -> Promise[Dict[Hashable, S]] + """ + A special function that takes a dictionary of promises + and turns them into a promise for a dictionary of values. + In other words, this turns an dictionary of promises for values + into a promise for a dictionary of values. + """ + dict_type = type(m) # type: Type[Dict] + + if not m: + return cls.resolve(dict_type()) # type: ignore + + def handle_success(resolved_values): + # type: (List[S]) -> Dict[Hashable, S] + return dict_type(zip(m.keys(), resolved_values)) + + return cls.all(m.values()).then(handle_success) + + @classmethod + def is_thenable(cls, obj): + # type: (Any) -> bool + """ + A utility function to determine if the specified + object is a promise using "duck typing". + """ + _type = obj.__class__ + if obj is None or _type in BASE_TYPES: + return False + + return ( + issubclass(_type, Promise) + or iscoroutine(obj) # type: ignore + or is_future_like(_type) + ) + + +_type_done_callbacks = WeakKeyDictionary() # type: MutableMapping[type, bool] + + +def is_future_like(_type): + # type: (type) -> bool + if _type not in _type_done_callbacks: + _type_done_callbacks[_type] = callable( + getattr(_type, "add_done_callback", None) + ) + return _type_done_callbacks[_type] + + +promisify = Promise.promisify +promise_for_dict = Promise.for_dict +is_thenable = Promise.is_thenable + + +def _process_future_result(resolve, reject): + # type: (Callable, Callable) -> Callable + def handle_future_result(future): + # type: (Any) -> None + try: + resolve(future.result()) + except Exception as e: + tb = exc_info()[2] + reject(e, tb) + + return handle_future_result diff --git a/strawberry/promise/promise_list.py b/strawberry/promise/promise_list.py new file mode 100644 index 0000000000..01b7b2f1f5 --- /dev/null +++ b/strawberry/promise/promise_list.py @@ -0,0 +1,168 @@ +# Copy of the Promise library (https://github.com/syrusakbary/promise) with some +# modifications. +# +# Promise is licensed under the terms of the MIT license, reproduced below. +# +# = = = = = +# +# The MIT License (MIT) +# +# Copyright (c) 2016 Syrus Akbary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from collections.abc import Iterable +from functools import partial +from types import TracebackType +from typing import Any, Collection, Optional, Type, Union + + +if False: + from .promise import Promise + + +class PromiseList(object): + + __slots__ = ("_values", "_length", "_total_resolved", "promise", "_promise_class") + + def __init__(self, values, promise_class): + # type: (Union[Collection, Promise[Collection]], Type[Promise]) -> None + self._promise_class = promise_class + self.promise = self._promise_class() + + self._length = 0 + self._total_resolved = 0 + self._values = None # type: Optional[Collection] + Promise = self._promise_class + if Promise.is_thenable(values): + values_as_promise = Promise._try_convert_to_promise( + values + )._target() # type: ignore + self._init_promise(values_as_promise) + else: + self._init(values) # type: ignore + + def __len__(self): + # type: () -> int + return self._length + + def _init_promise(self, values): + # type: (Promise[Collection]) -> None + if values.is_fulfilled: + values = values._value() + elif values.is_rejected: + self._reject(values._reason()) + return + + self.promise._is_async_guaranteed = True + values._then(self._init, self._reject) + return + + def _init(self, values): + # type: (Collection) -> None + self._values = values + if not isinstance(values, Iterable): + err = Exception( + "PromiseList requires an iterable. Received {}.".format(repr(values)) + ) + self.promise._reject_callback(err, False) + return + + if not values: + self._resolve([]) + return + + self._iterate(values) + return + + def _iterate(self, values): + # type: (Collection[Any]) -> None + Promise = self._promise_class + is_resolved = False + + self._length = len(values) + self._values = [None] * self._length + + result = self.promise + + for i, val in enumerate(values): + if Promise.is_thenable(val): + maybe_promise = Promise._try_convert_to_promise(val)._target() + # if is_resolved: + # # maybe_promise.suppressUnhandledRejections + # pass + if maybe_promise.is_pending: + maybe_promise._add_callbacks( + partial(self._promise_fulfilled, i=i), + partial(self._promise_rejected, promise=maybe_promise), + None, + ) + self._values[i] = maybe_promise + elif maybe_promise.is_fulfilled: + is_resolved = self._promise_fulfilled(maybe_promise._value(), i) + elif maybe_promise.is_rejected: + is_resolved = self._promise_rejected( + maybe_promise._reason(), promise=maybe_promise + ) + + else: + is_resolved = self._promise_fulfilled(val, i) + + if is_resolved: + break + + if not is_resolved: + result._is_async_guaranteed = True + + def _promise_fulfilled(self, value, i): + # type: (Any, int) -> bool + if self.is_resolved: + return False + self._values[i] = value # type: ignore + self._total_resolved += 1 + if self._total_resolved >= self._length: + self._resolve(self._values) # type: ignore + return True + return False + + def _promise_rejected(self, reason, promise): + # type: (Exception, Promise) -> bool + if self.is_resolved: + return False + self._total_resolved += 1 + self._reject(reason, traceback=promise._target()._traceback) + return True + + @property + def is_resolved(self): + # type: () -> bool + return self._values is None + + def _resolve(self, value): + # type: (Collection[Any]) -> None + assert not self.is_resolved + assert not isinstance(value, self._promise_class) + self._values = None + self.promise._fulfill(value) + + def _reject(self, reason, traceback=None): + # type: (Exception, Optional[TracebackType]) -> None + assert not self.is_resolved + self._values = None + self.promise._reject_callback(reason, False, traceback=traceback) diff --git a/tests/promise/__init__.py b/tests/promise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/promise/test_awaitable.py b/tests/promise/test_awaitable.py new file mode 100644 index 0000000000..90ad1ba810 --- /dev/null +++ b/tests/promise/test_awaitable.py @@ -0,0 +1,78 @@ +# Copy of the Promise library (https://github.com/syrusakbary/promise) with some +# modifications. +# +# Promise is licensed under the terms of the MIT license, reproduced below. +# +# = = = = = +# +# The MIT License (MIT) +# +# Copyright (c) 2016 Syrus Akbary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from asyncio import Future, sleep + +from pytest import mark + +from strawberry.promise import Promise, is_thenable + + +@mark.asyncio +async def test_await(): + assert await Promise.resolve(True) + + +@mark.asyncio +async def test_promisify_coroutine(): + async def my_coroutine(): + await sleep(0.01) + return True + + assert await Promise.resolve(my_coroutine()) + + +@mark.asyncio +async def test_coroutine_is_thenable(): + async def my_coroutine(): + await sleep(0.01) + return True + + assert is_thenable(my_coroutine()) + + +@mark.asyncio +async def test_promisify_future(): + future = Future() + future.set_result(True) + assert await Promise.resolve(future) + + +@mark.asyncio +async def test_await_in_safe_promise(): + async def inner(): + @Promise.safe + def x(): + promise = Promise.resolve(True).then(lambda x: x) + return promise + + return await x() + + result = await inner() + assert result is True diff --git a/tests/promise/test_promise_list.py b/tests/promise/test_promise_list.py new file mode 100644 index 0000000000..1ea042a13d --- /dev/null +++ b/tests/promise/test_promise_list.py @@ -0,0 +1,98 @@ +# Copy of the Promise library (https://github.com/syrusakbary/promise) with some +# modifications. +# +# Promise is licensed under the terms of the MIT license, reproduced below. +# +# = = = = = +# +# The MIT License (MIT) +# +# Copyright (c) 2016 Syrus Akbary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from pytest import raises + +from strawberry.promise import Promise +from strawberry.promise.promise_list import PromiseList + + +def all(promises): + return PromiseList(promises, Promise).promise + + +def test_empty_promises(): + all_promises = all([]) + assert all_promises.get() == [] + + +def test_bad_promises(): + all_promises = all(None) + + with raises(Exception) as exc_info: + all_promises.get() + + assert str(exc_info.value) == "PromiseList requires an iterable. Received None." + + +def test_promise_basic(): + all_promises = all([1, 2]) + assert all_promises.get() == [1, 2] + + +def test_promise_mixed(): + all_promises = all([1, 2, Promise.resolve(3)]) + assert all_promises.get() == [1, 2, 3] + + +def test_promise_rejected(): + e = Exception("Error") + all_promises = all([1, 2, Promise.reject(e)]) + + with raises(Exception) as exc_info: + all_promises.get() + + assert str(exc_info.value) == "Error" + + +def test_promise_reject_skip_all_other_values(): + e1 = Exception("Error1") + e2 = Exception("Error2") + all_promises = all([1, Promise.reject(e1), Promise.reject(e2)]) + + with raises(Exception) as exc_info: + all_promises.get() + + assert str(exc_info.value) == "Error1" + + +def test_promise_lazy_promise(): + p = Promise() + all_promises = all([1, 2, p]) + assert not all_promises.is_fulfilled + p.do_resolve(3) + assert all_promises.get() == [1, 2, 3] + + +def test_promise_contained_promise(): + p = Promise() + all_promises = all([1, 2, Promise.resolve(None).then(lambda v: p)]) + assert not all_promises.is_fulfilled + p.do_resolve(3) + assert all_promises.get() == [1, 2, 3] diff --git a/tests/promise/test_spec.py b/tests/promise/test_spec.py new file mode 100644 index 0000000000..78c5a2554e --- /dev/null +++ b/tests/promise/test_spec.py @@ -0,0 +1,611 @@ +# Copy of the Promise library (https://github.com/syrusakbary/promise) with some +# modifications. +# +# Promise is licensed under the terms of the MIT license, reproduced below. +# +# = = = = = +# +# The MIT License (MIT) +# +# Copyright (c) 2016 Syrus Akbary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Tests the spec based on: +# https://github.com/promises-aplus/promises-tests + +from strawberry.promise import Promise + +from .utils import assert_exception + + +class Counter: + """ + A helper class with some side effects + we can test. + """ + + def __init__(self): + self.count = 0 + + def tick(self): + self.count += 1 + + def value(self): + return self.count + + +def test_3_2_1(): + """ + Test that the arguments to 'then' are optional. + """ + + p1 = Promise() + p1.then() + p3 = Promise() + p3.then() + p1.do_resolve(5) + p3.do_reject(Exception("How dare you!")) + + +def test_3_2_1_1(): + """ + That that the first argument to 'then' is ignored if it + is not a function. + """ + results = {} + nonFunctions = [None, False, 5, {}, []] + + def testNonFunction(nonFunction): + def foo(k, r): + results[k] = r + + p1 = Promise.reject(Exception("Error: " + str(nonFunction))) + p2 = p1.then(nonFunction, lambda r: foo(str(nonFunction), r)) + p2._wait() + + for v in nonFunctions: + testNonFunction(v) + + for v in nonFunctions: + assert_exception(results[str(v)], Exception, "Error: " + str(v)) + + +def test_3_2_1_2(): + """ + That that the second argument to 'then' is ignored if it + is not a function. + """ + results = {} + nonFunctions = [None, False, 5, {}, []] + + def testNonFunction(nonFunction): + def foo(k, r): + results[k] = r + + p1 = Promise.resolve("Error: " + str(nonFunction)) + p2 = p1.then(lambda r: foo(str(nonFunction), r), nonFunction) + p2._wait() + + for v in nonFunctions: + testNonFunction(v) + + for v in nonFunctions: + assert "Error: " + str(v) == results[str(v)] + + +def test_3_2_2_1(): + """ + The first argument to 'then' must be called when a promise is + fulfilled. + """ + + c = Counter() + + def check(v, c): + assert v == 5 + c.tick() + + p1 = Promise.resolve(5) + p2 = p1.then(lambda v: check(v, c)) + p2._wait() + assert 1 == c.value() + + +def test_3_2_2_2(): + """ + Make sure callbacks are never called more than once. + """ + + c = Counter() + p1 = Promise.resolve(5) + p2 = p1.then(lambda v: c.tick()) + p2._wait() + try: + # I throw an exception + p1.do_resolve(5) + raise AssertionError # Should not get here! + except AssertionError: + # This is expected + pass + assert 1 == c.value() + + +def test_3_2_2_3(): + """ + Make sure fulfilled callback never called if promise is rejected + """ + + cf = Counter() + cr = Counter() + p1 = Promise.reject(Exception("Error")) + p2 = p1.then(lambda v: cf.tick(), lambda r: cr.tick()) + p2._wait() + assert 0 == cf.value() + assert 1 == cr.value() + + +def test_3_2_3_1(): + """ + The second argument to 'then' must be called when a promise is + rejected. + """ + + c = Counter() + + def check(r, c): + assert_exception(r, Exception, "Error") + c.tick() + + p1 = Promise.reject(Exception("Error")) + p2 = p1.then(None, lambda r: check(r, c)) + p2._wait() + assert 1 == c.value() + + +def test_3_2_3_2(): + """ + Make sure callbacks are never called more than once. + """ + + c = Counter() + p1 = Promise.reject(Exception("Error")) + p2 = p1.then(None, lambda v: c.tick()) + p2._wait() + try: + # I throw an exception + p1.do_reject(Exception("Error")) + raise AssertionError # Should not get here! + except AssertionError: + # This is expected + pass + assert 1 == c.value() + + +def test_3_2_3_3(): + """ + Make sure rejected callback never called if promise is fulfilled + """ + + cf = Counter() + cr = Counter() + p1 = Promise.resolve(5) + p2 = p1.then(lambda v: cf.tick(), lambda r: cr.tick()) + p2._wait() + assert 0 == cr.value() + assert 1 == cf.value() + + +def test_3_2_5_1_when(): + """ + Then can be called multiple times on the same promise + and callbacks must be called in the order of the + then calls. + """ + + def add(ls, v): + ls.append(v) + + p1 = Promise.resolve(2) + order = [] + p2 = p1.then(lambda v: add(order, "p2")) + p3 = p1.then(lambda v: add(order, "p3")) + p2._wait() + p3._wait() + assert 2 == len(order) + assert "p2" == order[0] + assert "p3" == order[1] + + +def test_3_2_5_1_if(): + """ + Then can be called multiple times on the same promise + and callbacks must be called in the order of the + then calls. + """ + + def add(ls, v): + ls.append(v) + + p1 = Promise.resolve(2) + order = [] + p2 = p1.then(lambda v: add(order, "p2")) + p3 = p1.then(lambda v: add(order, "p3")) + p2._wait() + p3._wait() + assert 2 == len(order) + assert "p2" == order[0] + assert "p3" == order[1] + + +def test_3_2_5_2_when(): + """ + Then can be called multiple times on the same promise + and callbacks must be called in the order of the + then calls. + """ + + def add(ls, v): + ls.append(v) + + p1 = Promise.reject(Exception("Error")) + order = [] + p2 = p1.then(None, lambda v: add(order, "p2")) + p3 = p1.then(None, lambda v: add(order, "p3")) + p2._wait() + p3._wait() + assert 2 == len(order) + assert "p2" == order[0] + assert "p3" == order[1] + + +def test_3_2_5_2_if(): + """ + Then can be called multiple times on the same promise + and callbacks must be called in the order of the + then calls. + """ + + def add(ls, v): + ls.append(v) + + p1 = Promise.reject(Exception("Error")) + order = [] + p2 = p1.then(None, lambda v: add(order, "p2")) + p3 = p1.then(None, lambda v: add(order, "p3")) + p2._wait() + p3._wait() + assert 2 == len(order) + assert "p2" == order[0] + assert "p3" == order[1] + + +def test_3_2_6_1(): + """ + Promises returned by then must be fulfilled when the promise + they are chained from is fulfilled IF the fulfillment value + is not a promise. + """ + + p1 = Promise.resolve(5) + pf = p1.then(lambda v: v * v) + assert pf.get() == 25 + + p2 = Promise.reject(Exception("Error")) + pr = p2.then(None, lambda r: 5) + assert 5 == pr.get() + + +def test_3_2_6_2_when(): + """ + Promises returned by then must be rejected when any of their + callbacks throw an exception. + """ + + def fail(v): + raise AssertionError("Exception Message") + + p1 = Promise.resolve(5) + pf = p1.then(fail) + pf._wait() + assert pf.is_rejected + assert_exception(pf.reason, AssertionError, "Exception Message") + + p2 = Promise.reject(Exception("Error")) + pr = p2.then(None, fail) + pr._wait() + assert pr.is_rejected + assert_exception(pr.reason, AssertionError, "Exception Message") + + +def test_3_2_6_2_if(): + """ + Promises returned by then must be rejected when any of their + callbacks throw an exception. + """ + + def fail(v): + raise AssertionError("Exception Message") + + p1 = Promise.resolve(5) + pf = p1.then(fail) + pf._wait() + assert pf.is_rejected + assert_exception(pf.reason, AssertionError, "Exception Message") + + p2 = Promise.reject(Exception("Error")) + pr = p2.then(None, fail) + pr._wait() + assert pr.is_rejected + assert_exception(pr.reason, AssertionError, "Exception Message") + + +def test_3_2_6_3_when_fulfilled(): + """ + Testing return of pending promises to make + sure they are properly chained. + This covers the case where the root promise + is fulfilled after the chaining is defined. + """ + + p1 = Promise() + pending = Promise() + + def p1_resolved(v): + return pending + + pf = p1.then(p1_resolved) + + assert pending.is_pending + assert pf.is_pending + p1.do_resolve(10) + pending.do_resolve(5) + pending._wait() + assert pending.is_fulfilled + assert 5 == pending.get() + pf._wait() + assert pf.is_fulfilled + assert 5 == pf.get() + + p2 = Promise() + bad = Promise() + pr = p2.then(lambda r: bad) + assert bad.is_pending + assert pr.is_pending + p2.do_resolve(10) + bad._reject_callback(Exception("Error")) + bad._wait() + assert bad.is_rejected + assert_exception(bad.reason, Exception, "Error") + pr._wait() + assert pr.is_rejected + assert_exception(pr.reason, Exception, "Error") + + +def test_3_2_6_3_if_fulfilled(): + """ + Testing return of pending promises to make + sure they are properly chained. + This covers the case where the root promise + is fulfilled before the chaining is defined. + """ + + p1 = Promise() + p1.do_resolve(10) + pending = Promise() + pending.do_resolve(5) + pf = p1.then(lambda r: pending) + pending._wait() + assert pending.is_fulfilled + assert 5 == pending.get() + pf._wait() + assert pf.is_fulfilled + assert 5 == pf.get() + + p2 = Promise() + p2.do_resolve(10) + bad = Promise() + bad.do_reject(Exception("Error")) + pr = p2.then(lambda r: bad) + bad._wait() + assert_exception(bad.reason, Exception, "Error") + pr._wait() + assert pr.is_rejected + assert_exception(pr.reason, Exception, "Error") + + +def test_3_2_6_3_when_rejected(): + """ + Testing return of pending promises to make + sure they are properly chained. + This covers the case where the root promise + is rejected after the chaining is defined. + """ + + p1 = Promise() + pending = Promise() + pr = p1.then(None, lambda r: pending) + assert pending.is_pending + assert pr.is_pending + p1.do_reject(Exception("Error")) + pending.do_resolve(10) + pending._wait() + assert pending.is_fulfilled + assert 10 == pending.get() + assert 10 == pr.get() + + p2 = Promise() + bad = Promise() + pr = p2.then(None, lambda r: bad) + assert bad.is_pending + assert pr.is_pending + p2.do_reject(Exception("Error")) + bad.do_reject(Exception("Assertion")) + bad._wait() + assert bad.is_rejected + assert_exception(bad.reason, Exception, "Assertion") + pr._wait() + assert pr.is_rejected + assert_exception(pr.reason, Exception, "Assertion") + + +def test_3_2_6_3_if_rejected(): + """ + Testing return of pending promises to make + sure they are properly chained. + This covers the case where the root promise + is rejected before the chaining is defined. + """ + + p1 = Promise() + p1.do_reject(Exception("Error")) + pending = Promise() + pending.do_resolve(10) + pr = p1.then(None, lambda r: pending) + pending._wait() + assert pending.is_fulfilled + assert 10 == pending.get() + pr._wait() + assert pr.is_fulfilled + assert 10 == pr.get() + + p2 = Promise() + p2.do_reject(Exception("Error")) + bad = Promise() + bad.do_reject(Exception("Assertion")) + pr = p2.then(None, lambda r: bad) + bad._wait() + assert bad.is_rejected + assert_exception(bad.reason, Exception, "Assertion") + pr._wait() + assert pr.is_rejected + assert_exception(pr.reason, Exception, "Assertion") + + +def test_3_2_6_4_pending(): + """ + Handles the case where the arguments to then + are not functions or promises. + """ + p1 = Promise() + p2 = p1.then(5) + p1.do_resolve(10) + assert 10 == p1.get() + p2._wait() + assert p2.is_fulfilled + assert 10 == p2.get() + + +def test_3_2_6_4_fulfilled(): + """ + Handles the case where the arguments to then + are values, not functions or promises. + """ + p1 = Promise() + p1.do_resolve(10) + p2 = p1.then(5) + assert 10 == p1.get() + p2._wait() + assert p2.is_fulfilled + assert 10 == p2.get() + + +def test_3_2_6_5_pending(): + """ + Handles the case where the arguments to then + are values, not functions or promises. + """ + p1 = Promise() + p2 = p1.then(None, 5) + p1.do_reject(Exception("Error")) + assert_exception(p1.reason, Exception, "Error") + p2._wait() + assert p2.is_rejected + assert_exception(p2.reason, Exception, "Error") + + +def test_3_2_6_5_rejected(): + """ + Handles the case where the arguments to then + are values, not functions or promises. + """ + p1 = Promise() + p1.do_reject(Exception("Error")) + p2 = p1.then(None, 5) + assert_exception(p1.reason, Exception, "Error") + p2._wait() + assert p2.is_rejected + assert_exception(p2.reason, Exception, "Error") + + +def test_chained_promises(): + """ + Handles the case where the arguments to then + are values, not functions or promises. + """ + p1 = Promise(lambda resolve, reject: resolve(Promise.resolve(True))) + assert p1.get() is True + + +def test_promise_resolved_after(): + """ + The first argument to 'then' must be called when a promise is + fulfilled. + """ + + c = Counter() + + def check(v, c): + assert v == 5 + c.tick() + + p1 = Promise() + p2 = p1.then(lambda v: check(v, c)) + p1.do_resolve(5) + Promise.wait(p2) + + assert 1 == c.value() + + +def test_promise_follows_indifentely(): + a = Promise.resolve(None) + b = a.then(lambda x: Promise.resolve("X")) + + def b_then(v): + + c = Promise.resolve(None) + d = c.then(lambda v: Promise.resolve("B")) + return d + + promise = b.then(b_then) + + assert promise.get() == "B" + + +def test_promise_all_follows_indifentely(): + promises = Promise.all( + [ + Promise.resolve("A"), + Promise.resolve(None) + .then(Promise.resolve) + .then(lambda v: Promise.resolve(None).then(lambda v: Promise.resolve("B"))), + ] + ) + + assert promises.get() == ["A", "B"] diff --git a/tests/promise/test_thread_safety.py b/tests/promise/test_thread_safety.py new file mode 100644 index 0000000000..cd5da214bc --- /dev/null +++ b/tests/promise/test_thread_safety.py @@ -0,0 +1,240 @@ +# Copy of the Promise library (https://github.com/syrusakbary/promise) with some +# modifications. +# +# Promise is licensed under the terms of the MIT license, reproduced below. +# +# = = = = = +# +# The MIT License (MIT) +# +# Copyright (c) 2016 Syrus Akbary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import threading + +import pytest + +from strawberry.promise import Promise + + +def test_promise_thread_safety(): + """ + Promise tasks should never be executed in a different thread from the one + they are scheduled from, unless the ThreadPoolExecutor is used. + + Here we assert that the pending promise tasks on thread 1 are not executed on + thread 2 as thread 2 resolves its own promise tasks. + """ + event_1 = threading.Event() + event_2 = threading.Event() + + assert_object = {"is_same_thread": True} + + def task_1(): + thread_name = threading.current_thread().getName() + + def then_1(value): + # Enqueue tasks to run later. + # This relies on the fact that `then` does not execute the function + # synchronously when called from within another `then` callback function. + promise = Promise.resolve(None).then(then_2) + assert promise.is_pending + event_1.set() # Unblock main thread + event_2.wait() # Wait for thread 2 + + def then_2(value): + assert_object["is_same_thread"] = ( + thread_name == threading.current_thread().getName() + ) + + Promise.resolve(None).then(then_1) + + def task_2(): + promise = Promise.resolve(None).then(lambda v: None) + promise.get() # Drain task queue + event_2.set() # Unblock thread 1 + + thread_1 = threading.Thread(target=task_1) + thread_1.start() + + event_1.wait() # Wait for Thread 1 to enqueue promise tasks + + thread_2 = threading.Thread(target=task_2) + thread_2.start() + + for thread in (thread_1, thread_2): + thread.join() + + assert assert_object["is_same_thread"] + + +# def test_dataloader_thread_safety(): +# """ +# Dataloader should only batch `load` calls that happened on the same thread. + +# Here we assert that `load` calls on thread 2 are not batched on thread 1 as +# thread 1 batches its own `load` calls. +# """ +# def load_many(keys): +# thead_name = threading.current_thread().getName() +# return Promise.resolve([thead_name for key in keys]) + +# thread_name_loader = DataLoader(load_many) + +# event_1 = threading.Event() +# event_2 = threading.Event() +# event_3 = threading.Event() + +# assert_object = { +# 'is_same_thread_1': True, +# 'is_same_thread_2': True, +# } + +# def task_1(): +# @Promise.safe +# def do(): +# promise = thread_name_loader.load(1) +# event_1.set() +# event_2.wait() # Wait for thread 2 to call `load` +# assert_object['is_same_thread_1'] = ( +# promise.get() == threading.current_thread().getName() +# ) +# event_3.set() # Unblock thread 2 + +# do().get() + +# def task_2(): +# @Promise.safe +# def do(): +# promise = thread_name_loader.load(2) +# event_2.set() +# event_3.wait() # Wait for thread 1 to run `dispatch_queue_batch` +# assert_object['is_same_thread_2'] = ( +# promise.get() == threading.current_thread().getName() +# ) + +# do().get() + +# thread_1 = threading.Thread(target=task_1) +# thread_1.start() + +# event_1.wait() # Wait for thread 1 to call `load` + +# thread_2 = threading.Thread(target=task_2) +# thread_2.start() + +# for thread in (thread_1, thread_2): +# thread.join() + +# assert assert_object['is_same_thread_1'] +# assert assert_object['is_same_thread_2'] + + +@pytest.mark.parametrize("num_threads", [1]) +@pytest.mark.parametrize("count", [10000]) +def test_with_process_loop(num_threads, count): + """ + Start a Promise in one thread, but resolve it in another. + """ + import queue + from random import randint + from sys import setswitchinterval + from threading import Barrier, Thread + from traceback import format_exc, print_exc + + test_with_process_loop._force_stop = False + items = queue.Queue() + barrier = Barrier(num_threads) + + asserts = [] + timeouts = [] + + def event_loop(): + stop_count = num_threads + while True: + item = items.get() + if item[0] == "STOP": + stop_count -= 1 + if stop_count == 0: + break + if item[0] == "ABORT": + break + if item[0] == "ITEM": + (_, resolve, reject, i) = item + random_integer = randint(0, 100) + # 25% chances per each + if 0 <= random_integer < 25: + # nested rejected promise + resolve(Promise.rejected(ZeroDivisionError(i))) + elif 25 <= random_integer < 50: + # nested resolved promise + resolve(Promise.resolve(i)) + elif 50 <= random_integer < 75: + # plain resolve + resolve(i) + else: + # plain reject + reject(ZeroDivisionError(i)) + + def worker(): + barrier.wait() + # Force fast switching of threads, this is NOT used in real world case. + # However without this I was unable to reproduce the issue. + setswitchinterval(0.000001) + for i in range(0, count): + if test_with_process_loop._force_stop: + break + + def do(resolve, reject): + items.put(("ITEM", resolve, reject, i)) + + p = Promise(do) + try: + p.get(timeout=1) + except ZeroDivisionError: + pass + except AssertionError as e: + print("ASSERT", e) + print_exc() + test_with_process_loop._force_stop = True + items.put(("ABORT",)) + asserts.append(format_exc()) + except Exception as e: + print("Timeout", e) + print_exc() + test_with_process_loop._force_stop = True + items.put(("ABORT",)) + timeouts.append(format_exc()) + + items.put(("STOP",)) + + loop_thread = Thread(target=event_loop) + loop_thread.start() + + worker_threads = [Thread(target=worker) for i in range(0, num_threads)] + for t in worker_threads: + t.start() + + loop_thread.join() + for t in worker_threads: + t.join() + + assert asserts == [] + assert timeouts == [] diff --git a/tests/promise/utils.py b/tests/promise/utils.py new file mode 100644 index 0000000000..73cd18c663 --- /dev/null +++ b/tests/promise/utils.py @@ -0,0 +1,32 @@ +# Copy of the Promise library (https://github.com/syrusakbary/promise) with some +# modifications. +# +# Promise is licensed under the terms of the MIT license, reproduced below. +# +# = = = = = +# +# The MIT License (MIT) +# +# Copyright (c) 2016 Syrus Akbary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +def assert_exception(exception, expected_exception_cls, expected_message): + assert isinstance(exception, expected_exception_cls) + assert str(exception) == expected_message