diff --git a/examples/caching_requestor.py b/examples/caching_requestor.py index 6d0dc84..7d04ce3 100755 --- a/examples/caching_requestor.py +++ b/examples/caching_requestor.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -"""This example shows how simple in-memory caching can be used. +"""Example program that shows how simple in-memory caching can be used. Demonstrates the use of custom sessions with :class:`.Requestor`. It's an adaptation of ``read_only_auth_trophies.py``. @@ -28,10 +28,9 @@ class CachingSession(requests.Session): def request(self, method, url, params=None, **kwargs): """Perform a request, or return a cached response if available.""" params_key = tuple(params.items()) if params else () - if method.upper() == "GET": - if (url, params_key) in self.get_cache: - print("Returning cached response for:", method, url, params) - return self.get_cache[(url, params_key)] + if method.upper() == "GET" and (url, params_key) in self.get_cache: + print("Returning cached response for:", method, url, params) + return self.get_cache[(url, params_key)] result = super().request(method, url, params, **kwargs) if method.upper() == "GET": self.get_cache[(url, params_key)] = result @@ -45,9 +44,7 @@ def main(): print(f"Usage: {sys.argv[0]} USERNAME") return 1 - caching_requestor = prawcore.Requestor( - "prawcore_device_id_auth_example", session=CachingSession() - ) + caching_requestor = prawcore.Requestor("prawcore_device_id_auth_example", session=CachingSession()) authenticator = prawcore.TrustedAuthenticator( caching_requestor, os.environ["PRAWCORE_CLIENT_ID"], diff --git a/examples/device_id_auth_trophies.py b/examples/device_id_auth_trophies.py index 023ed2a..3110018 100755 --- a/examples/device_id_auth_trophies.py +++ b/examples/device_id_auth_trophies.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -"""This example outputs a user's list of trophies. +"""Example program that outputs a user's list of trophies. This program demonstrates the use of ``prawcore.DeviceIDAuthorizer``. diff --git a/examples/obtain_refresh_token.py b/examples/obtain_refresh_token.py index 0ae6f46..a0faa27 100755 --- a/examples/obtain_refresh_token.py +++ b/examples/obtain_refresh_token.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -"""This example demonstrates the flow for retrieving a refresh token. +"""Example program that demonstrates the flow for retrieving a refresh token. In order for this example to work your application's redirect URI must be set to http://localhost:8080. @@ -36,7 +36,7 @@ def receive_connection(): def send_message(client, message): """Send message to client and close the connection.""" print(message) - client.send(f"HTTP/1.1 200 OK\r\n\r\n{message}".encode("utf-8")) + client.send(f"HTTP/1.1 200 OK\r\n\r\n{message}".encode()) client.close() @@ -53,16 +53,14 @@ def main(): "http://localhost:8080", ) - state = str(random.randint(0, 65000)) + state = str(random.randint(0, 65000)) # noqa: S311 url = authenticator.authorize_url("permanent", sys.argv[1:], state) print(url) client = receive_connection() data = client.recv(1024).decode("utf-8") param_tokens = data.split(" ", 2)[1].split("?", 1)[1].split("&") - params = { - key: value for (key, value) in [token.split("=") for token in param_tokens] - } + params = dict([token.split("=") for token in param_tokens]) if state != params["state"]: send_message( @@ -70,7 +68,7 @@ def main(): f"State mismatch. Expected: {state} Received: {params['state']}", ) return 1 - elif "error" in params: + if "error" in params: send_message(client, params["error"]) return 1 diff --git a/examples/read_only_auth_trophies.py b/examples/read_only_auth_trophies.py index 3d71d96..8d2e6c5 100755 --- a/examples/read_only_auth_trophies.py +++ b/examples/read_only_auth_trophies.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -"""This example outputs a user's list of trophies. +"""Example program that outputs a user's list of trophies. This program demonstrates the use of ``prawcore.ReadOnlyAuthorizer`` that does not require an access token to make authenticated requests to Reddit. diff --git a/prawcore/auth.py b/prawcore/auth.py index 101daf4..e78e629 100644 --- a/prawcore/auth.py +++ b/prawcore/auth.py @@ -4,7 +4,7 @@ import time from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, cast from requests import Request from requests.status_codes import codes @@ -30,7 +30,7 @@ def __init__( requestor: Requestor, client_id: str, redirect_uri: str | None = None, - ): + ) -> None: """Represent a single authentication to Reddit's API. :param requestor: An instance of :class:`.Requestor`. @@ -46,9 +46,7 @@ def __init__( self.client_id = client_id self.redirect_uri = redirect_uri - def _post( - self, url: str, success_status: int = codes["ok"], **data: Any - ) -> Response: + def _post(self, *, url: str, **data: Any) -> Response: response = self._requestor.request( "post", url, @@ -56,13 +54,11 @@ def _post( data=sorted(data.items()), headers={"Connection": "close"}, ) - if response.status_code != success_status: + if response.status_code != codes["ok"]: raise ResponseException(response) return response - def authorize_url( - self, duration: str, scopes: list[str], state: str, implicit: bool = False - ) -> str: + def authorize_url(self, duration: str, scopes: list[str], state: str, implicit: bool = False) -> str: """Return the URL used out-of-band to grant access to your application. :param duration: Either ``"permanent"`` or ``"temporary"``. ``"temporary"`` @@ -90,9 +86,7 @@ def authorize_url( msg = "redirect URI not provided" raise InvalidInvocation(msg) if implicit and not isinstance(self, UntrustedAuthenticator): - msg = ( - "Only UntrustedAuthenticator instances can use the implicit grant flow." - ) + msg = "Only UntrustedAuthenticator instances can use the implicit grant flow." raise InvalidInvocation(msg) if implicit and duration != "temporary": msg = "The implicit grant flow only supports temporary access tokens." @@ -108,9 +102,9 @@ def authorize_url( } url = self._requestor.reddit_url + const.AUTHORIZATION_PATH request = Request("GET", url, params=params) - return request.prepare().url + return cast(str, request.prepare().url) - def revoke_token(self, token: str, token_type: str | None = None): + def revoke_token(self, token: str, token_type: str | None = None) -> None: """Ask Reddit to revoke the provided token. :param token: The access or refresh token to revoke. @@ -123,7 +117,7 @@ def revoke_token(self, token: str, token_type: str | None = None): if token_type is not None: data["token_type_hint"] = token_type url = self._requestor.reddit_url + const.REVOKE_TOKEN_PATH - self._post(url, **data) + self._post(url=url, **data) class BaseAuthorizer: @@ -131,7 +125,7 @@ class BaseAuthorizer: AUTHENTICATOR_CLASS: tuple | type = BaseAuthenticator - def __init__(self, authenticator: BaseAuthenticator): + def __init__(self, authenticator: BaseAuthenticator) -> None: """Represent a single authorization to Reddit's API. :param authenticator: An instance of :class:`.BaseAuthenticator`. @@ -152,13 +146,9 @@ def _request_token(self, **data: Any): response = self._authenticator._post(url=url, **data) payload = response.json() if "error" in payload: # Why are these OKAY responses? - raise OAuthException( - response, payload["error"], payload.get("error_description") - ) + raise OAuthException(response, payload["error"], payload.get("error_description")) - self._expiration_timestamp_ns = ( - pre_request_timestamp_ns + (payload["expires_in"] + 10) * const.NANOSECONDS - ) + self._expiration_timestamp_ns = pre_request_timestamp_ns + (payload["expires_in"] + 10) * const.NANOSECONDS self.access_token = payload["access_token"] if "refresh_token" in payload: self.refresh_token = payload["refresh_token"] @@ -170,9 +160,7 @@ def _validate_authenticator(self): if isinstance(self.AUTHENTICATOR_CLASS, type): msg += f" {self.AUTHENTICATOR_CLASS.__name__}." else: - msg += ( - f" {' or '.join([i.__name__ for i in self.AUTHENTICATOR_CLASS])}." - ) + msg += f" {' or '.join([i.__name__ for i in self.AUTHENTICATOR_CLASS])}." raise InvalidInvocation(msg) def is_valid(self) -> bool: @@ -182,12 +170,9 @@ def is_valid(self) -> bool: valid on the server side. """ - return ( - self.access_token is not None - and time.monotonic_ns() < self._expiration_timestamp_ns - ) + return self.access_token is not None and time.monotonic_ns() < self._expiration_timestamp_ns - def revoke(self): + def revoke(self) -> None: """Revoke the current Authorization.""" if self.access_token is None: msg = "no token available to revoke" @@ -208,7 +193,7 @@ def __init__( client_id: str, client_secret: str, redirect_uri: str | None = None, - ): + ) -> None: """Represent a single authentication to Reddit's API. :param requestor: An instance of :class:`.Requestor`. @@ -245,7 +230,7 @@ def __init__( post_refresh_callback: Callable[[Authorizer], None] | None = None, pre_refresh_callback: Callable[[Authorizer], None] | None = None, refresh_token: str | None = None, - ): + ) -> None: """Represent a single authorization to Reddit's API. :param authenticator: An instance of a subclass of :class:`.BaseAuthenticator`. @@ -267,7 +252,7 @@ def __init__( self._pre_refresh_callback = pre_refresh_callback self.refresh_token = refresh_token - def authorize(self, code: str): + def authorize(self, code: str) -> None: """Obtain and set authorization tokens based on ``code``. :param code: The code obtained by an out-of-band authorization request to @@ -283,20 +268,18 @@ def authorize(self, code: str): redirect_uri=self._authenticator.redirect_uri, ) - def refresh(self): + def refresh(self) -> None: """Obtain a new access token from the refresh_token.""" if self._pre_refresh_callback: self._pre_refresh_callback(self) if self.refresh_token is None: msg = "refresh token not provided" raise InvalidInvocation(msg) - self._request_token( - grant_type="refresh_token", refresh_token=self.refresh_token - ) + self._request_token(grant_type="refresh_token", refresh_token=self.refresh_token) if self._post_refresh_callback: self._post_refresh_callback(self) - def revoke(self, only_access: bool = False): + def revoke(self, only_access: bool = False) -> None: """Revoke the current Authorization. :param only_access: When explicitly set to ``True``, do not evict the refresh @@ -325,7 +308,7 @@ def __init__( access_token: str, expires_in: int, scope: str, - ): + ) -> None: """Represent a single implicit authorization to Reddit's API. :param authenticator: An instance of :class:`.UntrustedAuthenticator`. @@ -341,9 +324,7 @@ def __init__( """ super().__init__(authenticator) - self._expiration_timestamp_ns = ( - time.monotonic_ns() + expires_in * const.NANOSECONDS - ) + self._expiration_timestamp_ns = time.monotonic_ns() + expires_in * const.NANOSECONDS self.access_token = access_token self.scopes = set(scope.split(" ")) @@ -362,7 +343,7 @@ def __init__( self, authenticator: BaseAuthenticator, scopes: list[str] | None = None, - ): + ) -> None: """Represent a ReadOnly authorization to Reddit's API. :param scopes: A list of OAuth scopes to request authorization for (default: @@ -372,7 +353,7 @@ def __init__( super().__init__(authenticator) self._scopes = scopes - def refresh(self): + def refresh(self) -> None: """Obtain a new ReadOnly access token.""" additional_kwargs = {} if self._scopes: @@ -397,7 +378,7 @@ def __init__( password: str | None, two_factor_callback: Callable | None = None, scopes: list[str] | None = None, - ): + ) -> None: """Represent a single personal-use authorization to Reddit's API. :param authenticator: An instance of :class:`.TrustedAuthenticator`. @@ -416,7 +397,7 @@ def __init__( self._two_factor_callback = two_factor_callback self._username = username - def refresh(self): + def refresh(self) -> None: """Obtain a new personal-use script type access token.""" additional_kwargs = {} if self._scopes: @@ -447,7 +428,7 @@ def __init__( authenticator: BaseAuthenticator, device_id: str | None = None, scopes: list[str] | None = None, - ): + ) -> None: """Represent an app-only OAuth2 authorization for 'installed' apps. :param authenticator: An instance of :class:`.UntrustedAuthenticator` or @@ -466,7 +447,7 @@ def __init__( self._device_id = device_id self._scopes = scopes - def refresh(self): + def refresh(self) -> None: """Obtain a new access token.""" additional_kwargs = {} if self._scopes: diff --git a/prawcore/exceptions.py b/prawcore/exceptions.py index f29e153..6fc4256 100644 --- a/prawcore/exceptions.py +++ b/prawcore/exceptions.py @@ -20,7 +20,7 @@ class InvalidInvocation(PrawcoreException): class OAuthException(PrawcoreException): """Indicate that there was an OAuth2 related error with the request.""" - def __init__(self, response: Response, error: str, description: str | None = None): + def __init__(self, response: Response, error: str, description: str | None = None) -> None: """Initialize a OAuthException instance. :param response: A ``requests.response`` instance. @@ -44,10 +44,8 @@ def __init__( self, original_exception: Exception, request_args: tuple[Any, ...], - request_kwargs: dict[ - str, bool | (dict[str, int] | (dict[str, str] | str)) | None - ], - ): + request_kwargs: dict[str, bool | (dict[str, int] | (dict[str, str] | str)) | None], + ) -> None: """Initialize a RequestException instance. :param original_exception: The original exception that occurred. @@ -64,7 +62,7 @@ def __init__( class ResponseException(PrawcoreException): """Indicate that there was an error with the completed HTTP request.""" - def __init__(self, response: Response): + def __init__(self, response: Response) -> None: """Initialize a ResponseException instance. :param response: A ``requests.response`` instance. @@ -110,7 +108,7 @@ class Redirect(ResponseException): """ - def __init__(self, response: Response): + def __init__(self, response: Response) -> None: """Initialize a Redirect exception instance. :param response: A ``requests.response`` instance containing a location header. @@ -121,8 +119,7 @@ def __init__(self, response: Response): self.response = response msg = f"Redirect to {self.path}" msg += ( - " (You may be trying to perform a non-read-only action via a " - "read-only instance.)" + " (You may be trying to perform a non-read-only action via a read-only instance.)" if "/login/" in self.path else "" ) @@ -136,7 +133,7 @@ class ServerError(ResponseException): class SpecialError(ResponseException): """Indicate syntax or spam-prevention issues.""" - def __init__(self, response: Response): + def __init__(self, response: Response) -> None: """Initialize a SpecialError exception instance. :param response: A ``requests.response`` instance containing a message and a @@ -159,7 +156,7 @@ class TooLarge(ResponseException): class TooManyRequests(ResponseException): """Indicate that the user has sent too many requests in a given amount of time.""" - def __init__(self, response: Response): + def __init__(self, response: Response) -> None: """Initialize a TooManyRequests exception instance. :param response: A ``requests.response`` instance that may contain a retry-after @@ -172,10 +169,7 @@ def __init__(self, response: Response): msg = f"received {response.status_code} HTTP response" if self.retry_after: - msg += ( - f". Please wait at least {float(self.retry_after)} seconds before" - f" re-trying this request." - ) + msg += f". Please wait at least {float(self.retry_after)} seconds before re-trying this request." PrawcoreException.__init__(self, msg) diff --git a/prawcore/rate_limit.py b/prawcore/rate_limit.py index 20d9118..cf112ab 100644 --- a/prawcore/rate_limit.py +++ b/prawcore/rate_limit.py @@ -23,7 +23,7 @@ class RateLimiter: """ - def __init__(self, *, window_size: int): + def __init__(self, *, window_size: int) -> None: """Create an instance of the RateLimit class.""" self.remaining: int | None = None self.next_request_timestamp_ns: int | None = None @@ -52,20 +52,18 @@ def call( self.update(response.headers) return response - def delay(self): + def delay(self) -> None: """Sleep for an amount of time to remain under the rate limit.""" if self.next_request_timestamp_ns is None: return - sleep_seconds = ( - float(self.next_request_timestamp_ns - time.monotonic_ns()) / NANOSECONDS - ) + sleep_seconds = float(self.next_request_timestamp_ns - time.monotonic_ns()) / NANOSECONDS if sleep_seconds <= 0: return message = f"Sleeping: {sleep_seconds:0.2f} seconds prior to call" log.debug(message) time.sleep(sleep_seconds) - def update(self, response_headers: Mapping[str, str]): + def update(self, response_headers: Mapping[str, str]) -> None: """Update the state of the rate limiter based on the response headers. This method should only be called following an HTTP request to Reddit. @@ -76,7 +74,7 @@ def update(self, response_headers: Mapping[str, str]): """ if "x-ratelimit-remaining" not in response_headers: - if self.remaining is not None: + if self.remaining is not None and self.used is not None: self.remaining -= 1 self.used += 1 return @@ -88,23 +86,16 @@ def update(self, response_headers: Mapping[str, str]): seconds_to_reset = int(response_headers["x-ratelimit-reset"]) if self.remaining <= 0: - self.next_request_timestamp_ns = now_ns + max( - NANOSECONDS, seconds_to_reset * NANOSECONDS - ) + self.next_request_timestamp_ns = now_ns + max(NANOSECONDS, seconds_to_reset * NANOSECONDS) return - self.next_request_timestamp_ns = ( + self.next_request_timestamp_ns = int( now_ns + min( seconds_to_reset, max( seconds_to_reset - - ( - self.window_size - - self.window_size - / (float(self.remaining) + self.used) - * self.used - ), + - (self.window_size - self.window_size / (float(self.remaining) + self.used) * self.used), 0, ), 10, diff --git a/prawcore/requestor.py b/prawcore/requestor.py index eaf6269..39ea966 100644 --- a/prawcore/requestor.py +++ b/prawcore/requestor.py @@ -10,13 +10,15 @@ from .exceptions import InvalidInvocation, RequestException if TYPE_CHECKING: - from requests.models import Response, Session + from requests import Response, Session class Requestor: """Requestor provides an interface to HTTP requests.""" - def __getattr__(self, attribute: str) -> Any: + MIN_USER_AGENT_LENGTH = 7 + + def __getattr__(self, attribute: str) -> object: """Pass all undefined attributes to the ``_http`` attribute.""" if attribute.startswith("__"): raise AttributeError @@ -29,7 +31,7 @@ def __init__( reddit_url: str = "https://www.reddit.com", session: Session | None = None, timeout: float = TIMEOUT, - ): + ) -> None: """Create an instance of the Requestor class. :param user_agent: The user-agent for your application. Please follow Reddit's @@ -47,7 +49,7 @@ def __init__( # Imported locally to avoid an import cycle, with __init__ from . import __version__ - if user_agent is None or len(user_agent) < 7: + if user_agent is None or len(user_agent) < self.MIN_USER_AGENT_LENGTH: msg = "user_agent is not descriptive" raise InvalidInvocation(msg) @@ -58,13 +60,11 @@ def __init__( self.reddit_url = reddit_url self.timeout = timeout - def close(self): + def close(self) -> None: """Call close on the underlying session.""" self._http.close() - def request( - self, *args: Any, timeout: float | None = None, **kwargs: Any - ) -> Response: + def request(self, *args: Any, timeout: float | None = None, **kwargs: Any) -> Response: """Issue the HTTP request capturing any errors that may occur.""" try: return self._http.request(*args, timeout=timeout or self.timeout, **kwargs) diff --git a/prawcore/sessions.py b/prawcore/sessions.py index dca273d..1692916 100644 --- a/prawcore/sessions.py +++ b/prawcore/sessions.py @@ -7,8 +7,9 @@ import time from abc import ABC, abstractmethod from copy import deepcopy +from dataclasses import dataclass from pprint import pformat -from typing import TYPE_CHECKING, Any, BinaryIO, TextIO +from typing import TYPE_CHECKING, BinaryIO, TextIO from urllib.parse import urljoin from requests.exceptions import ChunkedEncodingError, ConnectionError, ReadTimeout @@ -24,6 +25,7 @@ NotFound, Redirect, RequestException, + ResponseException, ServerError, SpecialError, TooLarge, @@ -36,6 +38,7 @@ if TYPE_CHECKING: from requests.models import Response + from typing_extensions import Self from .auth import Authorizer from .requestor import Requestor @@ -56,7 +59,15 @@ class RetryStrategy(ABC): def _sleep_seconds(self) -> float | None: pass - def sleep(self): + @abstractmethod + def consume_available_retry(self) -> RetryStrategy: + """Allow one fewer retry.""" + + @abstractmethod + def should_retry_on_failure(self) -> bool: + """Return True when a retry should occur.""" + + def sleep(self) -> None: """Sleep until we are ready to attempt the request.""" sleep_seconds = self._sleep_seconds() if sleep_seconds is not None: @@ -65,6 +76,29 @@ def sleep(self): time.sleep(sleep_seconds) +@dataclass(frozen=True) +class FiniteRetryStrategy(RetryStrategy): + """A ``RetryStrategy`` that retries requests a finite number of times.""" + + DEFAULT_RETRIES = 2 + + retries: int = DEFAULT_RETRIES + + def _sleep_seconds(self) -> float | None: + if self.retries < self.DEFAULT_RETRIES: + base = 0 if self.retries > 0 else 2 + return base + 2 * random.random() # noqa: S311 + return None + + def consume_available_retry(self) -> FiniteRetryStrategy: + """Allow one fewer retry.""" + return type(self)(retries=self.retries - 1) + + def should_retry_on_failure(self) -> bool: + """Return ``True`` if and only if the strategy will allow another retry.""" + return self.retries > 0 + + class Session: """The low-level connection interface to Reddit's API.""" @@ -104,11 +138,12 @@ class Session: @staticmethod def _log_request( - data: list[tuple[str, str]] | None, + *, + data: list[tuple[str, object]] | None, method: str, - params: dict[str, int], + params: dict[str, object], url: str, - ): + ) -> None: log.debug("Fetching: %s %s at %s", method, url, time.monotonic()) log.debug("Data: %s", pformat(data)) log.debug("Params: %s", pformat(params)) @@ -117,11 +152,11 @@ def _log_request( def _requestor(self) -> Requestor: return self._authorizer._authenticator._requestor - def __enter__(self) -> Session: # noqa: PYI034 + def __enter__(self) -> Self: """Allow this object to be used as a context manager.""" return self - def __exit__(self, *_args): + def __exit__(self, *_args) -> None: """Allow this object to be used as a context manager.""" self.close() @@ -129,7 +164,7 @@ def __init__( self, authorizer: BaseAuthorizer | None, window_size: int = WINDOW_SIZE, - ): + ) -> None: """Prepare the connection to Reddit's API. :param authorizer: An instance of :class:`.Authorizer`. @@ -145,127 +180,134 @@ def __init__( def _do_retry( self, - data: list[tuple[str, Any]], - files: dict[str, BinaryIO | TextIO], - json: dict[str, Any], + *, + data: list[tuple[str, object]] | None, + files: dict[str, BinaryIO | TextIO] | None, + json: dict[str, object] | None, method: str, - params: dict[str, int], - response: Response | None, + params: dict[str, object], retry_strategy_state: FiniteRetryStrategy, - saved_exception: Exception | None, + status: str, timeout: float, url: str, - ) -> dict[str, Any] | str | None: - status = repr(saved_exception) if saved_exception else response.status_code - log.warning("Retrying due to %s status: %s %s", status, method, url) + ) -> dict[str, object] | str | None: + log.warning("Retrying due to %s: %s %s", status, method, url) return self._request_with_retries( data=data, files=files, json=json, method=method, params=params, + retry_strategy_state=retry_strategy_state.consume_available_retry(), timeout=timeout, url=url, - retry_strategy_state=retry_strategy_state.consume_available_retry(), # noqa: E501 ) def _make_request( self, - data: list[tuple[str, Any]], - files: dict[str, BinaryIO | TextIO], - json: dict[str, Any], + data: list[tuple[str, object]] | None, + files: dict[str, BinaryIO | TextIO] | None, + json: dict[str, object] | None, method: str, - params: dict[str, Any], - retry_strategy_state: FiniteRetryStrategy, + params: dict[str, object], timeout: float, url: str, - ) -> tuple[Response, None] | tuple[None, Exception]: - try: - response = self._rate_limiter.call( - self._requestor.request, - self._set_header_callback, - method, - url, - allow_redirects=False, - data=data, - files=files, - json=json, - params=params, - timeout=timeout, - ) - log.debug( - "Response: %s (%s bytes) (rst-%s:rem-%s:used-%s ratelimit) at %s", - response.status_code, - response.headers.get("content-length"), - response.headers.get("x-ratelimit-reset"), - response.headers.get("x-ratelimit-remaining"), - response.headers.get("x-ratelimit-used"), - time.monotonic(), - ) - return response, None - except RequestException as exception: - if not retry_strategy_state.should_retry_on_failure() or not isinstance( # noqa: E501 - exception.original_exception, self.RETRY_EXCEPTIONS - ): - raise - return None, exception.original_exception + ) -> Response: + response = self._rate_limiter.call( + self._requestor.request, + self._set_header_callback, + method, + url, + allow_redirects=False, + data=data, + files=files, + json=json, + params=params, + timeout=timeout, + ) + log.debug( + "Response: %s (%s bytes) (rst-%s:rem-%s:used-%s ratelimit) at %s", + response.status_code, + response.headers.get("content-length"), + response.headers.get("x-ratelimit-reset"), + response.headers.get("x-ratelimit-remaining"), + response.headers.get("x-ratelimit-used"), + time.monotonic(), + ) + return response def _request_with_retries( self, - data: list[tuple[str, Any]], - files: dict[str, BinaryIO | TextIO], - json: dict[str, Any], + *, + data: list[tuple[str, object]] | None, + files: dict[str, BinaryIO | TextIO] | None, + json: dict[str, object] | None, method: str, - params: dict[str, Any], + params: dict[str, object], + retry_strategy_state: FiniteRetryStrategy | None = None, timeout: float, url: str, - retry_strategy_state: FiniteRetryStrategy | None = None, - ) -> dict[str, Any] | str | None: + ) -> dict[str, object] | str | None: if retry_strategy_state is None: retry_strategy_state = self._retry_strategy_class() retry_strategy_state.sleep() - self._log_request(data, method, params, url) - response, saved_exception = self._make_request( - data, - files, - json, - method, - params, - retry_strategy_state, - timeout, - url, - ) + self._log_request(data=data, method=method, params=params, url=url) - do_retry = False - if response is not None and response.status_code == codes["unauthorized"]: + try: + response = self._make_request( + data=data, + files=files, + json=json, + method=method, + params=params, + timeout=timeout, + url=url, + ) + except RequestException as exception: + if retry_strategy_state.should_retry_on_failure() and isinstance( # noqa: E501 + exception.original_exception, self.RETRY_EXCEPTIONS + ): + return self._do_retry( + data=data, + files=files, + json=json, + method=method, + params=params, + retry_strategy_state=retry_strategy_state, + status=repr(exception.original_exception), + timeout=timeout, + url=url, + ) + raise + + retry_status = None + if response.status_code == codes["unauthorized"]: self._authorizer._clear_access_token() if hasattr(self._authorizer, "refresh"): - do_retry = True + retry_status = f"{response.status_code} status" + elif response.status_code in self.RETRY_STATUSES: + retry_status = f"{response.status_code} status" - if retry_strategy_state.should_retry_on_failure() and ( - do_retry or response is None or response.status_code in self.RETRY_STATUSES - ): + if retry_status is not None and retry_strategy_state.should_retry_on_failure(): return self._do_retry( - data, - files, - json, - method, - params, - response, - retry_strategy_state, - saved_exception, - timeout, - url, + data=data, + files=files, + json=json, + method=method, + params=params, + retry_strategy_state=retry_strategy_state, + status=retry_status, + timeout=timeout, + url=url, ) - if response.status_code in self.STATUS_EXCEPTIONS: - raise self.STATUS_EXCEPTIONS[response.status_code](response) if response.status_code == codes["no_content"]: return None - assert response.status_code in self.SUCCESS_STATUSES, ( - f"Unexpected status code: {response.status_code}" - ) + if response.status_code in self.STATUS_EXCEPTIONS: + raise self.STATUS_EXCEPTIONS[response.status_code](response) + if response.status_code not in self.SUCCESS_STATUSES: + raise ResponseException(response) if response.headers.get("content-length") == "0": return "" try: @@ -274,11 +316,12 @@ def _request_with_retries( raise BadJSON(response) from None def _set_header_callback(self) -> dict[str, str]: - if not self._authorizer.is_valid() and hasattr(self._authorizer, "refresh"): - self._authorizer.refresh() + refresh_method = getattr(self._authorizer, "refresh", None) + if not self._authorizer.is_valid() and refresh_method is not None: + refresh_method() return {"Authorization": f"bearer {self._authorizer.access_token}"} - def close(self): + def close(self) -> None: """Close the session and perform any clean up.""" self._requestor.close() @@ -286,12 +329,12 @@ def request( self, method: str, path: str, - data: dict[str, Any] | None = None, + data: dict[str, object] | None = None, files: dict[str, BinaryIO | TextIO] | None = None, - json: dict[str, Any] | None = None, - params: dict[str, Any] | None = None, + json: dict[str, object] | None = None, + params: dict[str, object] | None = None, timeout: float = TIMEOUT, - ) -> dict[str, Any] | str | None: + ) -> dict[str, object] | str | None: """Return the json content from the resource at ``path``. :param method: The request verb. E.g., ``"GET"``, ``"POST"``, ``"PUT"``. @@ -316,13 +359,15 @@ def request( if isinstance(data, dict): data = deepcopy(data) data["api_type"] = "json" - data = sorted(data.items()) + data_list = sorted(data.items()) + else: + data_list = data if isinstance(json, dict): json = deepcopy(json) json["api_type"] = "json" url = urljoin(self._requestor.oauth_url, path) return self._request_with_retries( - data=data, + data=data_list, files=files, json=json, method=method, @@ -333,7 +378,7 @@ def request( def session( - authorizer: Authorizer = None, + authorizer: Authorizer | None = None, window_size: int = WINDOW_SIZE, ) -> Session: """Return a :class:`.Session` instance. @@ -343,29 +388,3 @@ def session( """ return Session(authorizer=authorizer, window_size=window_size) - - -class FiniteRetryStrategy(RetryStrategy): - """A ``RetryStrategy`` that retries requests a finite number of times.""" - - def __init__(self, retries: int = 3): - """Initialize the strategy. - - :param retries: Number of times to attempt a request (default: ``3``). - - """ - self._retries = retries - - def _sleep_seconds(self) -> float | None: - if self._retries < 3: - base = 0 if self._retries == 2 else 2 - return base + 2 * random.random() # noqa: S311 - return None - - def consume_available_retry(self) -> FiniteRetryStrategy: - """Allow one fewer retry.""" - return type(self)(self._retries - 1) - - def should_retry_on_failure(self) -> bool: - """Return ``True`` if and only if the strategy will allow another retry.""" - return self._retries > 1 diff --git a/prawcore/util.py b/prawcore/util.py index ddd7338..fa0f0cb 100644 --- a/prawcore/util.py +++ b/prawcore/util.py @@ -26,8 +26,5 @@ def authorization_error_class( """ message = response.headers.get("www-authenticate") error: int | str - if message: - error = message.replace('"', "").rsplit("=", 1)[1] - else: - error = response.status_code + error = message.replace('"', "").rsplit("=", 1)[1] if message else response.status_code return _auth_error_mapping[error](response) diff --git a/pyproject.toml b/pyproject.toml index 2cec1c9..b13a921 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,6 @@ requires-python = "~=3.9" [project.optional-dependencies] dev = [ - "packaging", "prawcore[lint]", "prawcore[test]" ] @@ -43,7 +42,9 @@ lint = [ ] test = [ "betamax >=0.8, <0.9", + "pyright", "pytest ==7.*", + "types-requests", "urllib3 ==1.*" ] @@ -52,23 +53,22 @@ test = [ "Source Code" = "https://github.com/praw-dev/prawcore" [tool.ruff] -target-version = "py39" include = [ - "prawcore/*.py" + "examples/*.py", + "prawcore/*.py", + "tests/*.py" ] +line-length = 120 [tool.ruff.lint] ignore = [ "A002", # shadowing built-in name "A004", # shadowing built-in "ANN202", # missing return type for private method - "ANN401", # typing.Any usage "D203", # 1 blank line required before class docstring "D213", # Multi-line docstring summary should start at the second line "E501", # line-length - "PLR0913", # too many arguments - "PLR2004", # Magic value used in comparison, - "S101" # use of assert + "PLR0913" # too many arguments ] select = [ "A", # flake8-builtins @@ -113,9 +113,10 @@ select = [ [tool.ruff.lint.flake8-annotations] allow-star-arg-any = true -mypy-init-return = true suppress-dummy-args = true -suppress-none-returning = true [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] +"pre_push.py" = ["T201"] +"examples/*.py" = ["ANN", "PLR2004", "T201"] +"tests/**.py" = ["ANN", "D", "PLR2004", "S101", "S105", "S106", "S301"] diff --git a/tests/conftest.py b/tests/conftest.py index 49286a9..31d6b92 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import socket import time from base64 import b64encode +from pathlib import Path from sys import platform import pytest @@ -17,7 +18,6 @@ def patch_sleep(monkeypatch): def _sleep(*_, **__): """Dud sleep function.""" - pass monkeypatch.setattr(time, "sleep", value=_sleep) @@ -28,7 +28,7 @@ def image_path(): def _get_path(name): """Return path to image.""" - return os.path.join(os.path.dirname(__file__), "integration", "files", name) + return Path(__file__).parent / "integration" / "files" / name return _get_path @@ -65,20 +65,14 @@ def env_default(key): def pytest_configure(config): pytest.placeholders = Placeholders(placeholders) - config.addinivalue_line( - "markers", "add_placeholder: Define an additional placeholder for the cassette." - ) - config.addinivalue_line( - "markers", "cassette_name: Name of cassette to use for test." - ) - config.addinivalue_line( - "markers", "recorder_kwargs: Arguments to pass to the recorder." - ) + config.addinivalue_line("markers", "add_placeholder: Define an additional placeholder for the cassette.") + config.addinivalue_line("markers", "cassette_name: Name of cassette to use for test.") + config.addinivalue_line("markers", "recorder_kwargs: Arguments to pass to the recorder.") def two_factor_callback(): """Return an OTP code.""" - return None + return class Placeholders: @@ -94,18 +88,13 @@ def __init__(self, _dict): ).split() } -if ( - placeholders["client_id"] != "fake_client_id" - and placeholders["client_secret"] == "fake_client_secret" -): - placeholders["basic_auth"] = b64encode( - f"{placeholders['client_id']}:".encode("utf-8") - ).decode("utf-8") +if placeholders["client_id"] != "fake_client_id" and placeholders["client_secret"] == "fake_client_secret": + placeholders["basic_auth"] = b64encode(f"{placeholders['client_id']}:".encode()).decode("utf-8") else: placeholders["basic_auth"] = b64encode( - f"{placeholders['client_id']}:{placeholders['client_secret']}".encode("utf-8") + f"{placeholders['client_id']}:{placeholders['client_secret']}".encode() ).decode("utf-8") if platform == "darwin": # Work around issue with betamax on OS X - socket.gethostbyname = lambda x: "127.0.0.1" + socket.gethostbyname = lambda _: "127.0.0.1" diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index f876de9..b2c6447 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -24,20 +24,17 @@ class IntegrationTest: @pytest.fixture(autouse=True, scope="session") def cassette_tracker(self): """Track cassettes to ensure unused cassettes are not uploaded.""" - global existing_cassettes for cassette in os.listdir(CASSETTES_PATH): existing_cassettes.add(cassette[: cassette.rindex(".")]) yield unused_cassettes = existing_cassettes - used_cassettes if unused_cassettes and os.getenv("ENSURE_NO_UNUSED_CASSETTES", "0") == "1": - raise AssertionError( - f"The following cassettes are unused: {', '.join(unused_cassettes)}." - ) + msg = f"The following cassettes are unused: {', '.join(unused_cassettes)}." + raise AssertionError(msg) @pytest.fixture(autouse=True) def cassette(self, request, recorder, cassette_name): """Wrap a test in a Betamax cassette.""" - global used_cassettes kwargs = {} for marker in request.node.iter_markers("add_placeholder"): for key, value in marker.kwargs.items(): @@ -49,9 +46,9 @@ def cassette(self, request, recorder, cassette_name): # Don't overwrite existing values since function markers are provided # before class markers. kwargs.setdefault(key, value) - with recorder.use_cassette(cassette_name, **kwargs) as recorder: - cassette = recorder.current_cassette - yield recorder + with recorder.use_cassette(cassette_name, **kwargs) as recorder_context: + cassette = recorder_context.current_cassette + yield recorder_context ensure_integration_test(cassette) used_cassettes.add(cassette_name) @@ -66,7 +63,7 @@ def recorder(self, requestor): config.before_record(callback=filter_access_token) for key, value in pytest.placeholders.__dict__.items(): if key == "password": - value = quote_plus(value) + value = quote_plus(value) # noqa: PLW2901 config.define_cassette_placeholder(f"<{key.upper()}>", value) yield recorder # since placeholders persist between tests @@ -77,9 +74,5 @@ def cassette_name(self, request): """Return the name of the cassette to use.""" marker = request.node.get_closest_marker("cassette_name") if marker is None: - return ( - f"{request.cls.__name__}.{request.node.name}" - if request.cls - else request.node.name - ) + return f"{request.cls.__name__}.{request.node.name}" if request.cls else request.node.name return marker.args[0] diff --git a/tests/integration/cassettes/TestSession.test_request__unexpected_status_code.json b/tests/integration/cassettes/TestSession.test_request__unexpected_status_code.json new file mode 100644 index 0000000..6026d0b --- /dev/null +++ b/tests/integration/cassettes/TestSession.test_request__unexpected_status_code.json @@ -0,0 +1,102 @@ +{ + "http_interactions": [ + { + "recorded_at": "2016-07-09T23:02:18", + "request": { + "body": { + "encoding": "utf-8", + "string": "grant_type=password&password=&username=" + }, + "headers": { + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate", + "Authorization": "Basic ", + "Connection": "keep-alive", + "Content-Length": "57", + "Content-Type": "application/x-www-form-urlencoded", + "Cookie": "loid=S1UX7gEWPLHZFXDMKZ; __cfduid=dd96ea2981f3ad6ab5250a89b33e0ace81458613117", + "User-Agent": "prawcore:test (by /u/bboe) prawcore/0.0.9" + }, + "method": "POST", + "uri": "https://www.reddit.com/api/v1/access_token" + }, + "response": { + "body": { + "base64_string": "H4sIAAAAAAAAA6tWSkxOTi0uji/Jz07NU7JSUMopjPSKNDVLqkpPyfVwsihKNzbJ8DZzLTZJS1fSUVACq4svqSxIBSlOSk0sSi0CiadWFGQWpRbHZ4IMMTYzMNBRUCpOzoco01KqBQCnwDS+aQAAAA==", + "encoding": "UTF-8", + "string": "" + }, + "headers": { + "CF-RAY": "2bff671aee2f2969-DUB", + "Connection": "keep-alive", + "Content-Encoding": "gzip", + "Content-Type": "application/json; charset=UTF-8", + "Date": "Sat, 09 Jul 2016 23:02:18 GMT", + "Server": "cloudflare-nginx", + "Strict-Transport-Security": "max-age=15552000; includeSubDomains; preload", + "Transfer-Encoding": "chunked", + "X-Moose": "majestic", + "cache-control": "max-age=0, must-revalidate", + "x-content-type-options": "nosniff", + "x-frame-options": "SAMEORIGIN", + "x-xss-protection": "1; mode=block" + }, + "status": { + "code": 200, + "message": "OK" + }, + "url": "https://www.reddit.com/api/v1/access_token" + } + }, + { + "recorded_at": "2016-07-09T23:02:18", + "request": { + "body": { + "encoding": "utf-8", + "string": "" + }, + "headers": { + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate", + "Authorization": "bearer lqYJY56bzgdmHB8rg34hK6Es4fg", + "Connection": "keep-alive", + "Content-Length": "0", + "Cookie": "loid=S1UX7gEWPLHZFXDMKZ; __cfduid=dd96ea2981f3ad6ab5250a89b33e0ace81458613117", + "User-Agent": "prawcore:test (by /u/bboe) prawcore/0.0.9" + }, + "method": "DELETE", + "uri": "https://oauth.reddit.com/api/v1/me/friends/spez?raw_json=1" + }, + "response": { + "body": { + "encoding": "UTF-8", + "string": "" + }, + "headers": { + "CF-RAY": "2bff671f9154296f-DUB", + "Connection": "keep-alive", + "Content-Length": "0", + "Content-Type": "application/json; charset=UTF-8", + "Date": "Sat, 09 Jul 2016 23:02:18 GMT", + "Server": "cloudflare-nginx", + "Strict-Transport-Security": "max-age=15552000; includeSubDomains; preload", + "X-Moose": "majestic", + "cache-control": "private, s-maxage=0, max-age=0, must-revalidate", + "expires": "-1", + "x-content-type-options": "nosniff", + "x-frame-options": "SAMEORIGIN", + "x-ratelimit-remaining": "598.0", + "x-ratelimit-reset": "462", + "x-ratelimit-used": "2", + "x-xss-protection": "1; mode=block" + }, + "status": { + "code": 205, + "message": "Reset Content" + }, + "url": "https://oauth.reddit.com/api/v1/me/friends/spez?raw_json=1" + } + } + ], + "recorded_with": "betamax/0.7.1" +} diff --git a/tests/integration/test_authenticator.py b/tests/integration/test_authenticator.py index f71cbdd..253580d 100644 --- a/tests/integration/test_authenticator.py +++ b/tests/integration/test_authenticator.py @@ -35,7 +35,5 @@ def test_revoke_token__with_refresh_token_hint(self, requestor): class TestUntrustedAuthenticator(IntegrationTest): def test_revoke_token(self, requestor): - authenticator = prawcore.UntrustedAuthenticator( - requestor, pytest.placeholders.client_id - ) + authenticator = prawcore.UntrustedAuthenticator(requestor, pytest.placeholders.client_id) authenticator.revoke_token("dummy token") diff --git a/tests/integration/test_authorizer.py b/tests/integration/test_authorizer.py index 8bb5169..749df66 100644 --- a/tests/integration/test_authorizer.py +++ b/tests/integration/test_authorizer.py @@ -38,9 +38,7 @@ def test_authorize__with_temporary_grant(self, trusted_authenticator): assert authorizer.is_valid() def test_refresh(self, trusted_authenticator): - authorizer = prawcore.Authorizer( - trusted_authenticator, refresh_token=pytest.placeholders.refresh_token - ) + authorizer = prawcore.Authorizer(trusted_authenticator, refresh_token=pytest.placeholders.refresh_token) authorizer.refresh() assert authorizer.access_token is not None @@ -73,9 +71,7 @@ def callback(authorizer): assert authorizer.refresh_token is None authorizer.refresh_token = pytest.placeholders.refresh_token - authorizer = prawcore.Authorizer( - trusted_authenticator, pre_refresh_callback=callback - ) + authorizer = prawcore.Authorizer(trusted_authenticator, pre_refresh_callback=callback) authorizer.refresh() assert authorizer.access_token is not None @@ -84,17 +80,13 @@ def callback(authorizer): assert authorizer.is_valid() def test_refresh__with_invalid_token(self, trusted_authenticator): - authorizer = prawcore.Authorizer( - trusted_authenticator, refresh_token="INVALID_TOKEN" - ) + authorizer = prawcore.Authorizer(trusted_authenticator, refresh_token="INVALID_TOKEN") with pytest.raises(prawcore.ResponseException): authorizer.refresh() assert not authorizer.is_valid() def test_revoke__access_token_with_refresh_set(self, trusted_authenticator): - authorizer = prawcore.Authorizer( - trusted_authenticator, refresh_token=pytest.placeholders.refresh_token - ) + authorizer = prawcore.Authorizer(trusted_authenticator, refresh_token=pytest.placeholders.refresh_token) authorizer.refresh() authorizer.revoke(only_access=True) @@ -119,9 +111,7 @@ def test_revoke__access_token_without_refresh_set(self, trusted_authenticator): assert not authorizer.is_valid() def test_revoke__refresh_token_with_access_set(self, trusted_authenticator): - authorizer = prawcore.Authorizer( - trusted_authenticator, refresh_token=pytest.placeholders.refresh_token - ) + authorizer = prawcore.Authorizer(trusted_authenticator, refresh_token=pytest.placeholders.refresh_token) authorizer.refresh() authorizer.revoke() @@ -131,9 +121,7 @@ def test_revoke__refresh_token_with_access_set(self, trusted_authenticator): assert not authorizer.is_valid() def test_revoke__refresh_token_without_access_set(self, trusted_authenticator): - authorizer = prawcore.Authorizer( - trusted_authenticator, refresh_token=pytest.placeholders.refresh_token - ) + authorizer = prawcore.Authorizer(trusted_authenticator, refresh_token=pytest.placeholders.refresh_token) authorizer.revoke() assert authorizer.access_token is None @@ -151,9 +139,7 @@ def test_refresh(self, untrusted_authenticator): assert authorizer.scopes == {"*"} assert authorizer.is_valid() - def test_refresh__with_scopes_and_trusted_authenticator( - self, requestor, untrusted_authenticator - ): + def test_refresh__with_scopes_and_trusted_authenticator(self, requestor): scope_list = {"adsedit", "adsread", "creddits", "history"} authorizer = prawcore.DeviceIDAuthorizer( prawcore.TrustedAuthenticator( @@ -189,9 +175,7 @@ def test_refresh(self, trusted_authenticator): def test_refresh__with_scopes(self, trusted_authenticator): scope_list = {"adsedit", "adsread", "creddits", "history"} - authorizer = prawcore.ReadOnlyAuthorizer( - trusted_authenticator, scopes=scope_list - ) + authorizer = prawcore.ReadOnlyAuthorizer(trusted_authenticator, scopes=scope_list) assert authorizer.access_token is None assert authorizer.scopes is None assert not authorizer.is_valid() @@ -230,9 +214,7 @@ def test_refresh__with_invalid_otp(self, trusted_authenticator): assert not authorizer.is_valid() def test_refresh__with_invalid_username_or_password(self, trusted_authenticator): - authorizer = prawcore.ScriptAuthorizer( - trusted_authenticator, pytest.placeholders.username, "invalidpassword" - ) + authorizer = prawcore.ScriptAuthorizer(trusted_authenticator, pytest.placeholders.username, "invalidpassword") with pytest.raises(prawcore.OAuthException): authorizer.refresh() assert not authorizer.is_valid() diff --git a/tests/integration/test_sessions.py b/tests/integration/test_sessions.py index a3e84f4..110e3e3 100644 --- a/tests/integration/test_sessions.py +++ b/tests/integration/test_sessions.py @@ -2,6 +2,7 @@ import logging from json import dumps +from pathlib import Path import pytest @@ -33,11 +34,7 @@ def test_request__accepted(self, script_authorizer, caplog): session.request("POST", "api/read_all_messages") found_message = False for package, level, message in caplog.record_tuples: - if ( - package == "prawcore" - and level == logging.DEBUG - and "Response: 202 (2 bytes)" in message - ): + if package == "prawcore" and level == logging.DEBUG and "Response: 202 (2 bytes)" in message: found_message = True assert found_message, f"'Response: 202 (2 bytes)' in {caplog.record_tuples}" @@ -56,25 +53,19 @@ def test_request__bad_json(self, script_authorizer): def test_request__bad_request(self, script_authorizer): session = prawcore.Session(script_authorizer) with pytest.raises(prawcore.BadRequest) as exception_info: - session.request( - "PUT", "/api/v1/me/friends/spez", data='{"note": "prawcore"}' - ) + session.request("PUT", "/api/v1/me/friends/spez", data='{"note": "prawcore"}') assert "reason" in exception_info.value.response.json() def test_request__cloudflare_connection_timed_out(self, readonly_authorizer): session = prawcore.Session(readonly_authorizer) with pytest.raises(prawcore.ServerError) as exception_info: session.request("GET", "/") - session.request("GET", "/") - session.request("GET", "/") assert exception_info.value.response.status_code == 522 def test_request__cloudflare_unknown_error(self, readonly_authorizer): session = prawcore.Session(readonly_authorizer) with pytest.raises(prawcore.ServerError) as exception_info: session.request("GET", "/") - session.request("GET", "/") - session.request("GET", "/") assert exception_info.value.response.status_code == 520 def test_request__conflict(self, script_authorizer): @@ -164,7 +155,7 @@ def test_request__post(self, script_authorizer): def test_request__post__with_files(self, script_authorizer): session = prawcore.Session(script_authorizer) data = {"upload_type": "header"} - with open("tests/integration/files/white-square.png", "rb") as fp: + with Path("tests/integration/files/white-square.png").open("rb") as fp: files = {"file": fp} response = session.request( "POST", @@ -180,10 +171,7 @@ def test_request__raw_json(self, readonly_authorizer): "GET", "/r/reddit_api_test/comments/45xjdr/want_raw_json_test/", ) - assert ( - "WANT_RAW_JSON test: < > &" - == response[0]["data"]["children"][0]["data"]["title"] - ) + assert response[0]["data"]["children"][0]["data"]["title"] == "WANT_RAW_JSON test: < > &" def test_request__redirect(self, readonly_authorizer): session = prawcore.Session(readonly_authorizer) @@ -201,23 +189,17 @@ def test_request__service_unavailable(self, readonly_authorizer): session = prawcore.Session(readonly_authorizer) with pytest.raises(prawcore.ServerError) as exception_info: session.request("GET", "/") - session.request("GET", "/") - session.request("GET", "/") assert exception_info.value.response.status_code == 503 def test_request__too__many_requests__with_retry_headers(self, readonly_authorizer): session = prawcore.Session(readonly_authorizer) - session._requestor._http.headers.update( - {"User-Agent": "python-requests/2.25.1"} - ) + session._requestor._http.headers.update({"User-Agent": "python-requests/2.25.1"}) with pytest.raises(prawcore.TooManyRequests) as exception_info: session.request("GET", "/api/v1/me") assert exception_info.value.response.status_code == 429 assert exception_info.value.response.headers.get("retry-after") assert exception_info.value.response.reason == "Too Many Requests" - assert str(exception_info.value).startswith( - "received 429 HTTP response. Please wait at least" - ) + assert str(exception_info.value).startswith("received 429 HTTP response. Please wait at least") assert exception_info.value.message.startswith("\n") def test_request__too__many_requests__without_retry_headers(self, requestor): @@ -243,7 +225,7 @@ def test_request__too__many_requests__without_retry_headers(self, requestor): def test_request__too_large(self, script_authorizer): session = prawcore.Session(script_authorizer) data = {"upload_type": "header"} - with open("tests/integration/files/too_large.jpg", "rb") as fp: + with Path("tests/integration/files/too_large.jpg").open("rb") as fp: files = {"file": fp} with pytest.raises(prawcore.TooLarge) as exception_info: session.request( @@ -256,35 +238,37 @@ def test_request__too_large(self, script_authorizer): def test_request__unavailable_for_legal_reasons(self, readonly_authorizer): session = prawcore.Session(readonly_authorizer) - exception_class = prawcore.UnavailableForLegalReasons - with pytest.raises(exception_class) as exception_info: + with pytest.raises(prawcore.UnavailableForLegalReasons) as exception_info: session.request("GET", "/") assert exception_info.value.response.status_code == 451 + def test_request__unexpected_status_code(self, script_authorizer): + session = prawcore.Session(script_authorizer) + with pytest.raises(prawcore.ResponseException) as exception_info: + session.request("DELETE", "/api/v1/me/friends/spez") + assert exception_info.value.response.status_code == 205 + def test_request__unsupported_media_type(self, script_authorizer): session = prawcore.Session(script_authorizer) - exception_class = prawcore.SpecialError data = { "content": "type: submission\naction: upvote", "page": "config/automoderator", } - with pytest.raises(exception_class) as exception_info: + with pytest.raises(prawcore.SpecialError) as exception_info: session.request("POST", "r/ttft/api/wiki/edit/", data=data) assert exception_info.value.response.status_code == 415 def test_request__uri_too_long(self, readonly_authorizer): session = prawcore.Session(readonly_authorizer) path_start = "/api/morechildren?link_id=t3_n7r3uz&children=" - with open("tests/integration/files/comment_ids.txt") as fp: + with Path("tests/integration/files/comment_ids.txt").open() as fp: ids = fp.read() with pytest.raises(prawcore.URITooLong) as exception_info: session.request("GET", (path_start + ids)[:9996]) assert exception_info.value.response.status_code == 414 def test_request__with_insufficient_scope(self, trusted_authenticator): - authorizer = prawcore.Authorizer( - trusted_authenticator, refresh_token=pytest.placeholders.refresh_token - ) + authorizer = prawcore.Authorizer(trusted_authenticator, refresh_token=pytest.placeholders.refresh_token) authorizer.refresh() session = prawcore.Session(authorizer) with pytest.raises(prawcore.InsufficientScope): diff --git a/tests/unit/test_authenticator.py b/tests/unit/test_authenticator.py index fcf2e06..6d93380 100644 --- a/tests/unit/test_authenticator.py +++ b/tests/unit/test_authenticator.py @@ -14,9 +14,7 @@ def trusted_authenticator(self, trusted_authenticator): return trusted_authenticator def test_authorize_url(self, trusted_authenticator): - url = trusted_authenticator.authorize_url( - "permanent", ["identity", "read"], "a_state" - ) + url = trusted_authenticator.authorize_url("permanent", ["identity", "read"], "a_state") assert f"client_id={pytest.placeholders.client_id}" in url assert "duration=permanent" in url assert "response_type=code" in url @@ -25,9 +23,7 @@ def test_authorize_url(self, trusted_authenticator): def test_authorize_url__fail_with_implicit(self, trusted_authenticator): with pytest.raises(prawcore.InvalidInvocation): - trusted_authenticator.authorize_url( - "temporary", ["identity", "read"], "a_state", implicit=True - ) + trusted_authenticator.authorize_url("temporary", ["identity", "read"], "a_state", implicit=True) def test_authorize_url__fail_without_redirect_uri(self, trusted_authenticator): trusted_authenticator.redirect_uri = None @@ -46,18 +42,14 @@ def untrusted_authenticator(self, untrusted_authenticator): return untrusted_authenticator def test_authorize_url__code(self, untrusted_authenticator): - url = untrusted_authenticator.authorize_url( - "permanent", ["identity", "read"], "a_state" - ) + url = untrusted_authenticator.authorize_url("permanent", ["identity", "read"], "a_state") assert f"client_id={pytest.placeholders.client_id}" in url assert "duration=permanent" in url assert "response_type=code" in url assert "scope=identity+read" in url assert "state=a_state" in url - def test_authorize_url__fail_with_token_and_permanent( - self, untrusted_authenticator - ): + def test_authorize_url__fail_with_token_and_permanent(self, untrusted_authenticator): with pytest.raises(prawcore.InvalidInvocation): untrusted_authenticator.authorize_url( "permanent", @@ -76,9 +68,7 @@ def test_authorize_url__fail_without_redirect_uri(self, untrusted_authenticator) ) def test_authorize_url__token(self, untrusted_authenticator): - url = untrusted_authenticator.authorize_url( - "temporary", ["identity", "read"], "a_state", implicit=True - ) + url = untrusted_authenticator.authorize_url("temporary", ["identity", "read"], "a_state", implicit=True) assert f"client_id={pytest.placeholders.client_id}" in url assert "duration=temporary" in url assert "response_type=token" in url diff --git a/tests/unit/test_authorizer.py b/tests/unit/test_authorizer.py index 02e1004..a8858a0 100644 --- a/tests/unit/test_authorizer.py +++ b/tests/unit/test_authorizer.py @@ -27,9 +27,7 @@ def test_initialize(self, trusted_authenticator): assert not authorizer.is_valid() def test_initialize__with_refresh_token(self, trusted_authenticator): - authorizer = prawcore.Authorizer( - trusted_authenticator, refresh_token=pytest.placeholders.refresh_token - ) + authorizer = prawcore.Authorizer(trusted_authenticator, refresh_token=pytest.placeholders.refresh_token) assert authorizer.access_token is None assert authorizer.scopes is None assert pytest.placeholders.refresh_token == authorizer.refresh_token @@ -50,9 +48,7 @@ def test_refresh__without_refresh_token(self, trusted_authenticator): assert not authorizer.is_valid() def test_revoke__without_access_token(self, trusted_authenticator): - authorizer = prawcore.Authorizer( - trusted_authenticator, refresh_token=pytest.placeholders.refresh_token - ) + authorizer = prawcore.Authorizer(trusted_authenticator, refresh_token=pytest.placeholders.refresh_token) with pytest.raises(prawcore.InvalidInvocation): authorizer.revoke(only_access=True) @@ -77,9 +73,7 @@ def test_initialize__with_invalid_authenticator(self): class TestImplicitAuthorizer(UnitTest): def test_initialize(self, untrusted_authenticator): - authorizer = prawcore.ImplicitAuthorizer( - untrusted_authenticator, "fake token", 1, "modposts read" - ) + authorizer = prawcore.ImplicitAuthorizer(untrusted_authenticator, "fake token", 1, "modposts read") assert authorizer.access_token == "fake token" assert authorizer.scopes == {"modposts", "read"} assert authorizer.is_valid() diff --git a/tests/unit/test_rate_limit.py b/tests/unit/test_rate_limit.py index 4bf93e1..c9ebf1c 100644 --- a/tests/unit/test_rate_limit.py +++ b/tests/unit/test_rate_limit.py @@ -36,9 +36,7 @@ def test_delay(self, mock_sleep, mock_monotonic_ns, rate_limiter): @patch("time.monotonic_ns") @patch("time.sleep") - def test_delay__no_sleep_when_time_in_past( - self, mock_sleep, mock_monotonic_ns, rate_limiter - ): + def test_delay__no_sleep_when_time_in_past(self, mock_sleep, mock_monotonic_ns, rate_limiter): mock_monotonic_ns.return_value = 101 * NANOSECONDS rate_limiter.delay() assert mock_monotonic_ns.called @@ -52,18 +50,14 @@ def test_delay__no_sleep_when_time_is_not_set(self, mock_sleep, rate_limiter): @patch("time.monotonic_ns") @patch("time.sleep") - def test_delay__no_sleep_when_times_match( - self, mock_sleep, mock_monotonic_ns, rate_limiter - ): + def test_delay__no_sleep_when_times_match(self, mock_sleep, mock_monotonic_ns, rate_limiter): mock_monotonic_ns.return_value = 100 * NANOSECONDS rate_limiter.delay() assert mock_monotonic_ns.called assert not mock_sleep.called @patch("time.monotonic_ns") - def test_update__compute_delay_with_no_previous_info( - self, mock_monotonic_ns, rate_limiter - ): + def test_update__compute_delay_with_no_previous_info(self, mock_monotonic_ns, rate_limiter): mock_monotonic_ns.return_value = 100 * NANOSECONDS rate_limiter.update(self._headers(60, 100, 60)) assert rate_limiter.remaining == 60 @@ -71,9 +65,7 @@ def test_update__compute_delay_with_no_previous_info( assert rate_limiter.next_request_timestamp_ns == 100 * NANOSECONDS @patch("time.monotonic_ns") - def test_update__compute_delay_with_single_client( - self, mock_monotonic_ns, rate_limiter - ): + def test_update__compute_delay_with_single_client(self, mock_monotonic_ns, rate_limiter): rate_limiter.window_size = 150 mock_monotonic_ns.return_value = 100 * NANOSECONDS rate_limiter.update(self._headers(50, 100, 60)) @@ -82,9 +74,7 @@ def test_update__compute_delay_with_single_client( assert rate_limiter.next_request_timestamp_ns == 110 * NANOSECONDS @patch("time.monotonic_ns") - def test_update__compute_delay_with_six_clients( - self, mock_monotonic_ns, rate_limiter - ): + def test_update__compute_delay_with_six_clients(self, mock_monotonic_ns, rate_limiter): rate_limiter.remaining = 66 rate_limiter.window_size = 180 mock_monotonic_ns.return_value = 100 * NANOSECONDS @@ -94,9 +84,7 @@ def test_update__compute_delay_with_six_clients( assert rate_limiter.next_request_timestamp_ns == 104.5 * NANOSECONDS @patch("time.monotonic_ns") - def test_update__delay_full_time_with_negative_remaining( - self, mock_monotonic_ns, rate_limiter - ): + def test_update__delay_full_time_with_negative_remaining(self, mock_monotonic_ns, rate_limiter): mock_monotonic_ns.return_value = 37 * NANOSECONDS rate_limiter.update(self._headers(0, 100, 13)) assert rate_limiter.remaining == 0 @@ -104,9 +92,7 @@ def test_update__delay_full_time_with_negative_remaining( assert rate_limiter.next_request_timestamp_ns == 50 * NANOSECONDS @patch("time.monotonic_ns") - def test_update__delay_full_time_with_zero_remaining( - self, mock_monotonic_ns, rate_limiter - ): + def test_update__delay_full_time_with_zero_remaining(self, mock_monotonic_ns, rate_limiter): mock_monotonic_ns.return_value = 37 * NANOSECONDS rate_limiter.update(self._headers(0, 100, 13)) assert rate_limiter.remaining == 0 @@ -114,9 +100,7 @@ def test_update__delay_full_time_with_zero_remaining( assert rate_limiter.next_request_timestamp_ns == 50 * NANOSECONDS @patch("time.monotonic_ns") - def test_update__delay_full_time_with_zero_remaining_and_no_sleep_time( - self, mock_monotonic_ns, rate_limiter - ): + def test_update__delay_full_time_with_zero_remaining_and_no_sleep_time(self, mock_monotonic_ns, rate_limiter): mock_monotonic_ns.return_value = 37 * NANOSECONDS rate_limiter.update(self._headers(0, 100, 0)) assert rate_limiter.remaining == 0 diff --git a/tests/unit/test_requestor.py b/tests/unit/test_requestor.py index 56628ef..51a1c22 100644 --- a/tests/unit/test_requestor.py +++ b/tests/unit/test_requestor.py @@ -14,10 +14,7 @@ class TestRequestor(UnitTest): def test_initialize(self, requestor): - assert ( - requestor._http.headers["User-Agent"] - == f"prawcore:test (by /u/bboe) prawcore/{prawcore.__version__}" - ) + assert requestor._http.headers["User-Agent"] == f"prawcore:test (by /u/bboe) prawcore/{prawcore.__version__}" def test_initialize__failures(self): for agent in [None, "shorty"]: @@ -41,10 +38,7 @@ def test_request__use_custom_session(self): requestor = prawcore.Requestor("prawcore:test (by /u/bboe)", session=session) - assert ( - requestor._http.headers["User-Agent"] - == f"prawcore:test (by /u/bboe) prawcore/{prawcore.__version__}" - ) + assert requestor._http.headers["User-Agent"] == f"prawcore:test (by /u/bboe) prawcore/{prawcore.__version__}" assert requestor._http.headers["session_header"] == custom_header assert requestor.request("https://reddit.com") == override diff --git a/tests/unit/test_sessions.py b/tests/unit/test_sessions.py index 2a4cb8b..cfb96d7 100644 --- a/tests/unit/test_sessions.py +++ b/tests/unit/test_sessions.py @@ -8,13 +8,14 @@ import prawcore from prawcore.exceptions import RequestException +from prawcore.sessions import FiniteRetryStrategy from . import UnitTest class InvalidAuthorizer(prawcore.Authorizer): def __init__(self, requestor): - super(InvalidAuthorizer, self).__init__( + super().__init__( prawcore.TrustedAuthenticator( requestor, pytest.placeholders.client_id, @@ -61,9 +62,7 @@ def test_request__retry(self, mock_session, exception, caplog): session_instance = mock_session.return_value # Handle Auth response_dict = {"access_token": "", "expires_in": 99, "scope": ""} - session_instance.request.return_value = Mock( - headers={}, json=lambda: response_dict, status_code=200 - ) + session_instance.request.return_value = Mock(headers={}, json=lambda: response_dict, status_code=200) requestor = prawcore.Requestor("prawcore:test (by /u/bboe)") authenticator = prawcore.TrustedAuthenticator( requestor, @@ -81,8 +80,7 @@ def test_request__retry(self, mock_session, exception, caplog): assert ( "prawcore", logging.WARNING, - f"Retrying due to {exception.__class__.__name__}() status: GET " - "https://oauth.reddit.com/", + f"Retrying due to {exception.__class__.__name__}(): GET https://oauth.reddit.com/", ) in caplog.record_tuples assert isinstance(exception_info.value, RequestException) assert exception is exception_info.value.original_exception @@ -96,6 +94,26 @@ def test_request__with_invalid_authorizer(self, requestor): class TestSessionFunction(UnitTest): def test_session(self, requestor): - assert isinstance( - prawcore.session(InvalidAuthorizer(requestor)), prawcore.Session - ) + assert isinstance(prawcore.session(InvalidAuthorizer(requestor)), prawcore.Session) + + +class TestFiniteRetryStrategy(UnitTest): + @patch("time.sleep") + def test_strategy(self, mock_sleep): + strategy = FiniteRetryStrategy() + assert strategy.should_retry_on_failure() + strategy.sleep() + mock_sleep.assert_not_called() + + strategy = strategy.consume_available_retry() + assert strategy.should_retry_on_failure() + strategy.sleep() + assert len(calls := mock_sleep.mock_calls) == 1 + assert 0 < calls[0].args[0] < 2 + mock_sleep.reset_mock() + + strategy = strategy.consume_available_retry() + assert not strategy.should_retry_on_failure() + strategy.sleep() + assert len(calls := mock_sleep.mock_calls) == 1 + assert 2 < calls[0].args[0] < 4 diff --git a/tests/utils.py b/tests/utils.py index b7c2f71..2c393b2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,7 +3,6 @@ import json import betamax -import pytest from betamax.serializers import JSONSerializer @@ -12,9 +11,7 @@ def ensure_integration_test(cassette): is_integration_test = bool(cassette.interactions) action = "record" else: - is_integration_test = any( - interaction.used for interaction in cassette.interactions - ) + is_integration_test = any(interaction.used for interaction in cassette.interactions) action = "play back" message = f"Cassette did not {action} any requests. This test can be a unit test." assert is_integration_test, message @@ -33,9 +30,7 @@ def filter_access_token(interaction, current_cassette): except (KeyError, TypeError, ValueError): continue current_cassette.placeholders.append( - betamax.cassette.cassette.Placeholder( - placeholder=f"<{token_key.upper()}_TOKEN>", replace=token - ) + betamax.cassette.cassette.Placeholder(placeholder=f"<{token_key.upper()}_TOKEN>", replace=token) ) diff --git a/tools/__init__.py b/tools/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tools/bump_version.py b/tools/bump_version.py deleted file mode 100755 index 5f4c3ee..0000000 --- a/tools/bump_version.py +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env python3 -import sys - -COMMIT_PREFIX = "Bump to v" - - -def main(): - line = sys.stdin.readline() - if not line.startswith(COMMIT_PREFIX): - sys.stderr.write( - f"Commit message does not begin with `{COMMIT_PREFIX}`.\nMessage:\n\n{line}" - ) - return 1 - print(line[len(COMMIT_PREFIX) : -1]) - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/tools/extract_log_entry.py b/tools/extract_log_entry.py deleted file mode 100755 index 2a95e35..0000000 --- a/tools/extract_log_entry.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 -import sys - -import docutils.nodes -import docutils.parsers.rst -import docutils.utils - - -def get_entry_slice(doc): - current_version = sys.stdin.readline().strip() - start_line = None - end_line = None - for section in doc.children[0].children: - if start_line: - end_line = section.children[0].line - 2 - break - header = section.children[0] - if current_version in getattr(header, "rawsource", ""): - start_line = header.line - 2 - return slice(start_line, end_line) - - -def parse_rst(text: str) -> docutils.nodes.document: - parser = docutils.parsers.rst.Parser() - components = (docutils.parsers.rst.Parser,) - settings = docutils.frontend.OptionParser( - components=components - ).get_default_values() - settings.report_level = 4 - document = docutils.utils.new_document("", settings=settings) - parser.parse(text, document) - return document - - -with open("CHANGES.rst") as f: - source = f.read() - document = parse_rst(source) - -sys.stdout.write("\n".join(source.splitlines()[get_entry_slice(document)])) diff --git a/tools/set_version.py b/tools/set_version.py deleted file mode 100755 index 5780f36..0000000 --- a/tools/set_version.py +++ /dev/null @@ -1,116 +0,0 @@ -#!/usr/bin/env python3 -import re -import sys -from datetime import date - -import packaging.version - -CHANGELOG_HEADER = ( - "Change Log\n==========\n\n" - "prawcore follows `semantic versioning `_.\n\n" -) -UNRELEASED_HEADER = "Unreleased\n----------\n\n" - - -def add_unreleased_to_changelog(): - with open("CHANGES.rst") as fp: - content = fp.read() - - if not content.startswith(CHANGELOG_HEADER): - sys.stderr.write("Unexpected CHANGES.rst header\n") - return False - new_header = f"{CHANGELOG_HEADER}{UNRELEASED_HEADER}" - if content.startswith(new_header): - sys.stderr.write("CHANGES.rst already contains Unreleased header\n") - return False - - with open("CHANGES.rst", "w") as fp: - fp.write(f"{new_header}{content[len(CHANGELOG_HEADER) :]}") - return True - - -def handle_unreleased(): - return add_unreleased_to_changelog() and increment_development_version() - - -def handle_version(version): - version = valid_version(version) - if not version: - return False - return update_changelog(version) and update_package(version) - - -def increment_development_version(): - with open("prawcore/__init__.py") as fp: - version = re.search('__version__ = "([^"]+)"', fp.read()).group(1) - - parsed_version = valid_version(version) - if not parsed_version: - return False - - if parsed_version.is_devrelease: - pre = "".join(str(x) for x in parsed_version.pre) if parsed_version.pre else "" - new_version = f"{parsed_version.base_version}{pre}.dev{parsed_version.dev + 1}" - elif parsed_version.is_prerelease: - new_version = f"{parsed_version}.dev0" - else: - assert parsed_version.base_version == version - new_version = f"{parsed_version.major}.{parsed_version.minor}.{parsed_version.micro + 1}.dev0" - - assert valid_version(new_version) - return update_package(new_version) - - -def main(): - if len(sys.argv) != 2: - sys.stderr.write(f"Usage: {sys.argv[0]} VERSION\n") - return 1 - if sys.argv[1] == "Unreleased": - return not handle_unreleased() - return not handle_version(sys.argv[1]) - - -def update_changelog(version): - with open("CHANGES.rst") as fp: - content = fp.read() - - expected_header = f"{CHANGELOG_HEADER}{UNRELEASED_HEADER}" - if not content.startswith(expected_header): - sys.stderr.write("CHANGES.rst does not contain Unreleased header.\n") - return False - - date_string = date.today().strftime("%Y/%m/%d") - version_line = f"{version} ({date_string})\n" - version_header = f"{version_line}{'-' * len(version_line[:-1])}\n\n" - - with open("CHANGES.rst", "w") as fp: - fp.write(f"{CHANGELOG_HEADER}{version_header}{content[len(expected_header) :]}") - return True - - -def update_package(version): - with open("prawcore/__init__.py") as fp: - content = fp.read() - - updated = re.sub('__version__ = "([^"]+)"', f'__version__ = "{version}"', content) - if content == updated: - sys.stderr.write("Package version string not changed\n") - return False - - with open("prawcore/__init__.py", "w") as fp: - fp.write(updated) - - print(version) - return True - - -def valid_version(version): - parsed_version = packaging.version.parse(version) - if parsed_version.local or parsed_version.is_postrelease or parsed_version.epoch: - sys.stderr.write("epoch, local postrelease version parts are not supported") - return False - return parsed_version - - -if __name__ == "__main__": - sys.exit(main())