diff --git a/docs/api/fetchers.md b/docs/api/fetchers.md new file mode 100644 index 0000000..d071d45 --- /dev/null +++ b/docs/api/fetchers.md @@ -0,0 +1,12 @@ +```{caution} +This API is not finalized, and may change in a patch version. +``` + +# `unearth.fetchers` + +```{eval-rst} +.. automodule:: unearth.fetchers + +.. autoclass:: unearth.fetchers.PyPIClient + :members: +``` diff --git a/docs/api/session.md b/docs/api/session.md deleted file mode 100644 index 17c1a33..0000000 --- a/docs/api/session.md +++ /dev/null @@ -1,14 +0,0 @@ -```{caution} -This API is not finalized, and may change in a patch version. -``` - -# `unearth.session` - -```{eval-rst} -.. automodule:: unearth.session - -.. autoclass:: unearth.session.PyPISession - :members: - -.. autoclass:: unearth.session.InsecureHTTPAdapter -``` diff --git a/noxfile.py b/noxfile.py index 7d568c6..96c5ca8 100644 --- a/noxfile.py +++ b/noxfile.py @@ -13,7 +13,7 @@ def test(session): @nox.session(python="3.11") def docs(session): - session.install("-r", "docs/requirements.txt") + session.run("pdm", "install", "-Gdoc", external=True) # Generate documentation into `build/docs` session.run("sphinx-build", "-n", "-W", "-b", "html", "docs/", "build/docs") @@ -21,8 +21,7 @@ def docs(session): @nox.session(name="docs-live", python="3.11") def docs_live(session): - session.install("-r", "docs/requirements.txt") - session.install("-e", ".") + session.run("pdm", "install", "-Gdoc", external=True) session.install("sphinx-autobuild") session.run( diff --git a/pdm.lock b/pdm.lock index 9e2c0c2..bc2753f 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,10 +2,10 @@ # It is not intended for manual editing. [metadata] -groups = ["default", "keyring", "test", "doc"] +groups = ["default", "keyring", "test", "doc", "legacy"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:e58b37951cde629cefc10b0fe8158a638bedd2d7a090dbe832e4aa243761bc70" +content_hash = "sha256:18a70748f97dab1d43d0a3f8555f961ea5d27ac88297ec280fbadcb16df3fcab" [[package]] name = "alabaster" @@ -18,6 +18,23 @@ files = [ {file = "alabaster-0.7.13.tar.gz", hash = "sha256:a27a4a084d5e690e16e01e03ad2b2e552c61a65469419b907243193de1a84ae2"}, ] +[[package]] +name = "anyio" +version = "4.3.0" +requires_python = ">=3.8" +summary = "High level compatibility layer for multiple asynchronous event loop implementations" +groups = ["default"] +dependencies = [ + "exceptiongroup>=1.0.2; python_version < \"3.11\"", + "idna>=2.8", + "sniffio>=1.1", + "typing-extensions>=4.1; python_version < \"3.11\"", +] +files = [ + {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"}, + {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"}, +] + [[package]] name = "babel" version = "2.14.0" @@ -62,7 +79,7 @@ name = "certifi" version = "2023.11.17" requires_python = ">=3.6" summary = "Python package for providing Mozilla's CA Bundle." -groups = ["default", "doc", "test"] +groups = ["default", "doc", "legacy", "test"] files = [ {file = "certifi-2023.11.17-py3-none-any.whl", hash = "sha256:e036ab49d5b79556f99cfc2d9320b34cfbe5be05c5871b51de9329f0603b0474"}, {file = "certifi-2023.11.17.tar.gz", hash = "sha256:9b469f3a900bf28dc19b8cfbf8019bf47f7fdd1a65a1d4ffb98fc14166beb4d1"}, @@ -137,7 +154,7 @@ name = "charset-normalizer" version = "3.3.2" requires_python = ">=3.7.0" summary = "The Real First Universal Charset Detector. 
Open, modern and actively maintained alternative to Chardet." -groups = ["default", "doc", "test"] +groups = ["doc", "legacy", "test"] files = [ {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"}, @@ -295,7 +312,7 @@ name = "exceptiongroup" version = "1.2.0" requires_python = ">=3.7" summary = "Backport of PEP 654 (exception groups)" -groups = ["test"] +groups = ["default", "test"] marker = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"}, @@ -338,12 +355,56 @@ files = [ {file = "furo-2024.1.29.tar.gz", hash = "sha256:4d6b2fe3f10a6e36eb9cc24c1e7beb38d7a23fc7b3c382867503b7fcac8a1e02"}, ] +[[package]] +name = "h11" +version = "0.14.0" +requires_python = ">=3.7" +summary = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +groups = ["default"] +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "1.0.4" +requires_python = ">=3.8" +summary = "A minimal low-level HTTP client." +groups = ["default"] +dependencies = [ + "certifi", + "h11<0.15,>=0.13", +] +files = [ + {file = "httpcore-1.0.4-py3-none-any.whl", hash = "sha256:ac418c1db41bade2ad53ae2f3834a3a0f5ae76b56cf5aa497d2d033384fc7d73"}, + {file = "httpcore-1.0.4.tar.gz", hash = "sha256:cb2839ccfcba0d2d3c1131d3c3e26dfc327326fbe7a5dc0dbfe9f6c9151bb022"}, +] + +[[package]] +name = "httpx" +version = "0.27.0" +requires_python = ">=3.8" +summary = "The next generation HTTP client." 
+groups = ["default"] +dependencies = [ + "anyio", + "certifi", + "httpcore==1.*", + "idna", + "sniffio", +] +files = [ + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, +] + [[package]] name = "idna" version = "3.6" requires_python = ">=3.5" summary = "Internationalized Domain Names in Applications (IDNA)" -groups = ["default", "doc", "test"] +groups = ["default", "doc", "legacy", "test"] files = [ {file = "idna-3.6-py3-none-any.whl", hash = "sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f"}, {file = "idna-3.6.tar.gz", hash = "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca"}, @@ -677,6 +738,20 @@ files = [ {file = "pytest_httpserver-1.0.10.tar.gz", hash = "sha256:77b9fbc2eb0a129cfbbacc8fe57e8cafe071d506489f31fe31e62f1b332d9905"}, ] +[[package]] +name = "pytest-mock" +version = "3.12.0" +requires_python = ">=3.8" +summary = "Thin-wrapper around the mock package for easier use with pytest" +groups = ["test"] +dependencies = [ + "pytest>=5.0", +] +files = [ + {file = "pytest-mock-3.12.0.tar.gz", hash = "sha256:31a40f038c22cad32287bb43932054451ff5583ff094bca6f675df2f8bc1a6e9"}, + {file = "pytest_mock-3.12.0-py3-none-any.whl", hash = "sha256:0972719a7263072da3a21c7f4773069bcc7486027d7e8e1f81d98a47e701bc4f"}, +] + [[package]] name = "pytz" version = "2023.3.post1" @@ -752,7 +827,7 @@ name = "requests" version = "2.31.0" requires_python = ">=3.7" summary = "Python HTTP for Humans." -groups = ["default", "doc", "test"] +groups = ["doc", "legacy", "test"] dependencies = [ "certifi>=2017.4.17", "charset-normalizer<4,>=2", @@ -792,6 +867,17 @@ files = [ {file = "SecretStorage-3.3.3.tar.gz", hash = "sha256:2403533ef369eca6d2ba81718576c5e0f564d5cca1b58f73a8b23e7d4eeebd77"}, ] +[[package]] +name = "sniffio" +version = "1.3.1" +requires_python = ">=3.7" +summary = "Sniff out which async library your code is running under" +groups = ["default"] +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + [[package]] name = "snowballstemmer" version = "2.2.0" @@ -978,12 +1064,24 @@ files = [ {file = "trustme-1.1.0.tar.gz", hash = "sha256:5375ad7fb427074bec956592e0d4ee2a4cf4da68934e1ba4bcf4217126bc45e6"}, ] +[[package]] +name = "typing-extensions" +version = "4.10.0" +requires_python = ">=3.8" +summary = "Backported and Experimental Type Hints for Python 3.8+" +groups = ["default"] +marker = "python_version < \"3.11\"" +files = [ + {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, + {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, +] + [[package]] name = "urllib3" version = "2.1.0" requires_python = ">=3.8" summary = "HTTP library with thread-safe connection pooling, file post, and more." 
-groups = ["default", "doc", "test"] +groups = ["doc", "legacy", "test"] files = [ {file = "urllib3-2.1.0-py3-none-any.whl", hash = "sha256:55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3"}, {file = "urllib3-2.1.0.tar.gz", hash = "sha256:df7aa8afb0148fa78488e7899b2c59b5f4ffcfa82e6c54ccb9dd37c1d7b52d54"}, diff --git a/pyproject.toml b/pyproject.toml index d091d7c..724bfc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ readme = "README.md" requires-python = ">=3.8" dependencies = [ "packaging>=20", - "requests>=2.25", + "httpx>=0.27.0,<1", ] dynamic = ["version"] @@ -38,6 +38,9 @@ Changelog = "https://github.com/frostming/unearth/releases" keyring = [ "keyring", ] +legacy = [ + "requests>=2.25", +] [project.scripts] unearth = "unearth.__main__:cli" @@ -55,6 +58,7 @@ test = [ "flask>=2.1.2", "requests-wsgi-adapter>=0.4.1", "trustme>=0.9.0", + "pytest-mock>=3.12.0", ] doc = [ "furo", diff --git a/src/unearth/auth.py b/src/unearth/auth.py index fa29175..8fc97ef 100644 --- a/src/unearth/auth.py +++ b/src/unearth/auth.py @@ -6,15 +6,19 @@ import os import shutil import subprocess -from typing import Any, Callable, Iterable, Optional, Tuple, cast +from typing import TYPE_CHECKING, Optional, Tuple, cast from urllib.parse import SplitResult, urlparse, urlsplit -from requests import Response -from requests.auth import AuthBase, HTTPBasicAuth -from requests.models import PreparedRequest -from requests.utils import get_netrc_auth +from httpx import URL, Auth, BasicAuth -from unearth.utils import commonprefix, split_auth_from_url +from unearth.utils import commonprefix, get_netrc_auth, split_auth_from_url + +if TYPE_CHECKING: + from typing import Any, Callable, Generator, Iterable + + from httpx import Request, Response + from requests import Response as RequestsResponse + from requests.models import PreparedRequest KEYRING_DISABLED = False @@ -155,7 +159,9 @@ def get_keyring_auth(url: str | None, username: str | None) -> AuthInfo | None: return None -class MultiDomainBasicAuth(AuthBase): +class MultiDomainBasicAuth(Auth): + """A multi-domain HTTP basic authentication handler supporting both requests and httpx""" + def __init__(self, prompting: bool = True, index_urls: Iterable[str] = ()) -> None: self.prompting = prompting self.index_urls = list(index_urls) @@ -286,6 +292,8 @@ def _get_url_and_credentials( def __call__(self, req: PreparedRequest) -> PreparedRequest: # Get credentials for this request + from requests.auth import HTTPBasicAuth + url, username, password = self._get_url_and_credentials(cast(str, req.url)) req.url = url @@ -297,6 +305,56 @@ def __call__(self, req: PreparedRequest) -> PreparedRequest: return req + def auth_flow(self, request: Request) -> Generator[Request, Response, None]: + url, username, password = self._get_url_and_credentials(str(request.url)) + request.url = URL(url) + + if username is not None and password is not None: + basic_auth = BasicAuth(username, password) + request = next(basic_auth.auth_flow(request)) + + response = yield request + + if response.status_code != 401: + return + + # Query the keyring for credentials: + username, password = self._get_new_credentials( + url, allow_netrc=False, allow_keyring=True + ) + + # Prompt the user for a new username and password + save = False + netloc = response.url.netloc.decode() + if not username and not password: + # We are not able to prompt the user so simply return the response + if not self.prompting: + return + + username, password, save = self._prompt_for_password(netloc) + + # 
Store the new username and password to use for future requests
+        self._credentials_to_save = None
+        if username is not None and password is not None:
+            self._cached_passwords[netloc] = (username, password)
+
+            # Prompt to save the password to keyring
+            if save and self._should_save_password_to_keyring():
+                self._credentials_to_save = (netloc, username, password)
+
+        # Add our new username and password to the request
+        basic_auth = BasicAuth(username or "", password or "")
+        request = next(basic_auth.auth_flow(request))
+
+        response = yield request
+        self.warn_on_401(response)
+
+        # On successful request, save the credentials that were used to
+        # keyring. (Note that if the user responded "no" above, this member
+        # is not set and nothing will be saved.)
+        if self._credentials_to_save:
+            self.save_credentials(response)
+
     # Factored out to allow for easy patching in tests
     def _prompt_for_password(self, netloc: str) -> tuple[str | None, str | None, bool]:
         username = input(f"User for {netloc}: ")
@@ -314,9 +372,11 @@ def _should_save_password_to_keyring(self) -> bool:
             return False
         return input("Save credentials to keyring [y/N]: ") == "y"
 
-    def handle_401(self, resp: Response, **kwargs: Any) -> Response:
+    def handle_401(self, resp: RequestsResponse, **kwargs: Any) -> RequestsResponse:
         # We only care about 401 response, anything else we want to just
         # pass through the actual response
+        from requests.auth import HTTPBasicAuth
+
         if resp.status_code != 401:
             return resp
 
@@ -370,7 +430,7 @@
 
         return new_resp
 
-    def warn_on_401(self, resp: Response, **kwargs: Any) -> None:
+    def warn_on_401(self, resp: Response | RequestsResponse, **kwargs: Any) -> None:
         """Response callback to warn about incorrect credentials."""
         if resp.status_code == 401:
             logger.warning(
@@ -379,7 +439,9 @@
                 resp.request.url,
             )
 
-    def save_credentials(self, resp: Response, **kwargs: Any) -> None:
+    def save_credentials(
+        self, resp: Response | RequestsResponse, **kwargs: Any
+    ) -> None:
         """Response callback to save credentials on success."""
         keyring = get_keyring_provider()
         assert keyring is not None, "should never reach here without keyring"
diff --git a/src/unearth/collector.py b/src/unearth/collector.py
index 1b564be..73acacb 100644
--- a/src/unearth/collector.py
+++ b/src/unearth/collector.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import functools
+import ipaddress
 import json
 import logging
 import mimetypes
@@ -11,10 +12,8 @@
 from typing import Iterable, NamedTuple
 from urllib import parse
 
-from requests.models import Response
-
+from unearth.fetchers import Fetcher, Response
 from unearth.link import Link
-from unearth.session import PyPISession
 from unearth.utils import is_archive_file, path_to_url
 
 SUPPORTED_CONTENT_TYPES = (
@@ -51,6 +50,45 @@ def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None
         self.anchors.append(dict(attrs))
 
 
+def _compare_origin_part(allowed: str, actual: str) -> bool:
+    return allowed == "*" or allowed == actual
+
+
+def is_secure_origin(fetcher: Fetcher, location: Link) -> bool:
+    """
+    Determine if the origin is a trusted host.
+
+    Args:
+        fetcher (Fetcher): The fetcher whose trusted origins are consulted.
+        location (Link): The location to check.
+ """ + _, _, scheme = location.parsed.scheme.rpartition("+") + host, port = location.parsed.hostname or "", location.parsed.port + for secure_scheme, secure_host, secure_port in fetcher.iter_secure_origins(): + if not _compare_origin_part(secure_scheme, scheme): + continue + try: + addr = ipaddress.ip_address(host) + network = ipaddress.ip_network(secure_host) + except ValueError: + # Either addr or network is invalid + if not _compare_origin_part(secure_host, host): + continue + else: + if addr not in network: + continue + + if not _compare_origin_part(secure_port, "*" if port is None else str(port)): + continue + # We've got here, so all the parts match + return True + + logger.warning( + "Skipping %s for not being trusted, please add it to `trusted_hosts` list", + location.redacted, + ) + return False + + def parse_html_page(page: IndexPage) -> Iterable[Link]: """PEP 503 simple index API""" parser = IndexHTMLParser() @@ -111,7 +149,7 @@ def parse_json_response(page: IndexPage) -> Iterable[Link]: def collect_links_from_location( - session: PyPISession, location: Link, expand: bool = False + session: Fetcher, location: Link, expand: bool = False ) -> Iterable[Link]: """Collect package links from a remote URL or local path. @@ -141,7 +179,7 @@ def collect_links_from_location( @functools.lru_cache(maxsize=None) -def fetch_page(session: PyPISession, location: Link) -> IndexPage: +def fetch_page(session: Fetcher, location: Link) -> IndexPage: if location.is_vcs: raise LinkCollectError("It is a VCS link.") resp = _get_html_response(session, location) @@ -149,12 +187,12 @@ def fetch_page(session: PyPISession, location: Link) -> IndexPage: cache_text = " (from cache)" if from_cache else "" logger.debug("Fetching HTML page %s%s", location.redacted, cache_text) return IndexPage( - Link(resp.url), resp.content, resp.encoding, resp.headers["Content-Type"] + Link(str(resp.url)), resp.content, resp.encoding, resp.headers["Content-Type"] ) -def _collect_links_from_index(session: PyPISession, location: Link) -> Iterable[Link]: - if not session.is_secure_origin(location): +def _collect_links_from_index(session: Fetcher, location: Link) -> Iterable[Link]: + if not is_secure_origin(session, location): return [] try: page = fetch_page(session, location) @@ -173,7 +211,7 @@ def _is_html_file(file_url: str) -> bool: return mimetypes.guess_type(file_url, strict=False)[0] == "text/html" -def _get_html_response(session: PyPISession, location: Link) -> Response: +def _get_html_response(session: Fetcher, location: Link) -> Response: if is_archive_file(location.filename): # If the URL looks like a file, send a HEAD request to ensure # the link is an HTML page to avoid downloading a large file. 
@@ -199,7 +237,7 @@ def _get_html_response(session: PyPISession, location: Link) -> Response: return resp -def _ensure_index_response(session: PyPISession, location: Link) -> None: +def _ensure_index_response(session: Fetcher, location: Link) -> None: if location.parsed.scheme not in {"http", "https"}: raise LinkCollectError( "NotHTTP: the file looks like an archive but its content-type " @@ -212,7 +250,10 @@ def _ensure_index_response(session: PyPISession, location: Link) -> None: def _check_for_status(resp: Response) -> None: - reason = resp.reason + if hasattr(resp, "reason"): + reason = resp.reason + else: + reason = resp.reason_phrase if isinstance(reason, bytes): try: diff --git a/src/unearth/evaluator.py b/src/unearth/evaluator.py index e4d05bb..4930e65 100644 --- a/src/unearth/evaluator.py +++ b/src/unearth/evaluator.py @@ -20,8 +20,8 @@ parse_wheel_filename, ) from packaging.version import InvalidVersion, Version -from requests import Session +from unearth.fetchers import Fetcher from unearth.link import Link from unearth.pep425tags import get_supported from unearth.utils import ( @@ -308,10 +308,10 @@ def evaluate_package( return True -def _get_hash(link: Link, hash_name: str, session: Session) -> str: +def _get_hash(link: Link, hash_name: str, session: Fetcher) -> str: hasher = hashlib.new(hash_name) - with session.get(link.normalized, stream=True) as resp: - for chunk in resp.iter_content(chunk_size=1024 * 8): + with session.get_stream(link.normalized) as resp: + for chunk in resp.iter_bytes(chunk_size=1024 * 8): hasher.update(chunk) digest = hasher.hexdigest() if not link.hashes: @@ -321,7 +321,7 @@ def _get_hash(link: Link, hash_name: str, session: Session) -> str: def validate_hashes( - package: Package, hashes: dict[str, list[str]], session: Session + package: Package, hashes: dict[str, list[str]], session: Fetcher ) -> bool: if not hashes: return True diff --git a/src/unearth/fetchers/__init__.py b/src/unearth/fetchers/__init__.py new file mode 100644 index 0000000..67ad679 --- /dev/null +++ b/src/unearth/fetchers/__init__.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import ContextManager, Iterable, Iterator, Mapping, Protocol + +from unearth.fetchers.sync import PyPIClient as PyPIClient + +DEFAULT_MAX_RETRIES = 5 +DEFAULT_SECURE_ORIGINS = [ + ("https", "*", "*"), + ("wss", "*", "*"), + ("*", "localhost", "*"), + ("*", "127.0.0.0/8", "*"), + ("*", "::1/128", "*"), + ("file", "*", "*"), +] + + +class Response(Protocol): + status_code: int + headers: Mapping[str, str] + encoding: str | None + url: str | None + + @property + def content(self) -> bytes: ... + + def json(self) -> dict: ... + + def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: ... + + @property + def reason_phrase(self) -> str: ... + + def raise_for_status(self) -> None: ... + + +class Fetcher(Protocol): + def get( + self, url: str, *, headers: Mapping[str, str] | None = None + ) -> Response: ... + + def head( + self, url: str, *, headers: Mapping[str, str] | None = None + ) -> Response: ... + + def get_stream( + self, url: str, *, headers: Mapping[str, str] | None = None + ) -> ContextManager[Response]: ... + + def __hash__(self) -> int: ... + + def iter_secure_origins(self) -> Iterable[tuple[str, str, str]]: ... 
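Since `fetch_page` in `collector.py` is wrapped in `functools.lru_cache`, the `Fetcher` protocol above requires `__hash__`: any object passed as a session must be hashable so fetched pages can be cached per fetcher. As a rough sketch of what a third-party fetcher satisfying the protocol could look like (the `LoggingFetcher` name and its print-based logging are illustrative assumptions, not part of this change):

```python
from __future__ import annotations

import contextlib
from typing import Iterable, Iterator, Mapping

import httpx


class LoggingFetcher:
    """Illustrative Fetcher implementation wrapping httpx.Client and logging requests."""

    def __init__(self) -> None:
        self._client = httpx.Client(timeout=10.0)

    def get(self, url: str, *, headers: Mapping[str, str] | None = None) -> httpx.Response:
        print(f"GET {url}")
        return self._client.get(url, headers=headers)

    def head(self, url: str, *, headers: Mapping[str, str] | None = None) -> httpx.Response:
        print(f"HEAD {url}")
        return self._client.head(url, headers=headers)

    @contextlib.contextmanager
    def get_stream(
        self, url: str, *, headers: Mapping[str, str] | None = None
    ) -> Iterator[httpx.Response]:
        # httpx.Response provides iter_bytes, reason_phrase and raise_for_status,
        # so it structurally satisfies the Response protocol above.
        with self._client.stream("GET", url, headers=headers) as resp:
            yield resp

    def iter_secure_origins(self) -> Iterable[tuple[str, str, str]]:
        # Conservative subset of DEFAULT_SECURE_ORIGINS: HTTPS and local files only.
        yield ("https", "*", "*")
        yield ("file", "*", "*")

    # The default object.__hash__ already satisfies the protocol's
    # __hash__ requirement, so no override is needed here.
```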
diff --git a/src/unearth/fetchers/legacy.py b/src/unearth/fetchers/legacy.py
new file mode 100644
index 0000000..5caffeb
--- /dev/null
+++ b/src/unearth/fetchers/legacy.py
@@ -0,0 +1,168 @@
+from __future__ import annotations
+
+import contextlib
+import email.utils
+import io
+import logging
+import mimetypes
+import os
+import warnings
+from pathlib import Path
+from typing import Any, Iterable, Iterator, cast
+
+import urllib3
+
+try:
+    from requests import Session, adapters
+except ModuleNotFoundError:
+    raise ModuleNotFoundError(
+        "requests is required to use PyPISession, please install `unearth[legacy]`"
+    ) from None
+from requests.models import PreparedRequest, Response
+
+from unearth.fetchers import DEFAULT_MAX_RETRIES, DEFAULT_SECURE_ORIGINS
+from unearth.link import Link
+from unearth.utils import build_url_from_netloc, parse_netloc
+
+logger = logging.getLogger(__name__)
+
+
+class InsecureMixin:
+    def cert_verify(self, conn, url, verify, cert):
+        return super().cert_verify(conn, url, verify=False, cert=cert)
+
+    def send(self, request, *args, **kwargs):
+        with warnings.catch_warnings():
+            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+            return super().send(request, *args, **kwargs)
+
+
+class InsecureHTTPAdapter(InsecureMixin, adapters.HTTPAdapter):
+    pass
+
+
+class LocalFSAdapter(adapters.BaseAdapter):
+    def send(self, request: PreparedRequest, *args: Any, **kwargs: Any) -> Response:
+        link = Link(cast(str, request.url))
+        path = link.file_path
+        resp = Response()
+        resp.status_code = 200
+        resp.url = cast(str, request.url)
+        resp.request = request
+
+        try:
+            stats = os.stat(path)
+        except OSError as exc:
+            # Format the raised exception as an io.BytesIO body
+            # to return a better error message:
+            resp.status_code = 404
+            resp.reason = type(exc).__name__
+            resp.raw = io.BytesIO(f"{resp.reason}: {exc}".encode("utf8"))
+        else:
+            modified = email.utils.formatdate(stats.st_mtime, usegmt=True)
+            content_type = mimetypes.guess_type(path)[0] or "text/plain"
+            resp.headers.update(
+                {
+                    "Content-Type": content_type,
+                    "Content-Length": str(stats.st_size),
+                    "Last-Modified": modified,
+                }
+            )
+
+            resp.raw = open(path, "rb")
+            resp.close = resp.raw.close  # type: ignore[method-assign]
+
+        return resp
+
+    def close(self) -> None:
+        pass
+
+
+class PyPISession(Session):
+    """
+    A session with retries enabled and specific hosts trusted.
+
+    Args:
+        retries: The number of retries to attempt.
+        trusted_hosts: The hosts to trust.
+        ca_certificates: The path to a file where the certificates for
+            CAs reside. These are used when verifying the host
+            certificates of the index servers. When left unset, the
+            default certificates of the requests library will be used.
+    """
+
+    #: The adapter class to use for secure connections.
+    secure_adapter_cls = adapters.HTTPAdapter
+    #: The adapter class to use for insecure connections.
+    insecure_adapter_cls = InsecureHTTPAdapter
+
+    def __init__(
+        self,
+        *,
+        index_urls: Iterable[str] = (),
+        retries: int = DEFAULT_MAX_RETRIES,
+        trusted_hosts: Iterable[str] = (),
+        ca_certificates: Path | None = None,
+        timeout: float | tuple[float, float] | urllib3.Timeout = 10,
+    ) -> None:
+        super().__init__()
+
+        retry = urllib3.Retry(
+            total=retries,
+            # A 500 may indicate transient error in Amazon S3
+            # A 520 or 527 - may indicate transient error in CloudFlare
+            status_forcelist=[500, 503, 520, 527],
+            backoff_factor=0.25,
+        )
+        self._insecure_adapter = self.insecure_adapter_cls(max_retries=retry)
+        secure_adapter = self.secure_adapter_cls(max_retries=retry)
+
+        self.mount("https://", secure_adapter)
+        self.mount("http://", self._insecure_adapter)
+        self.mount("file://", LocalFSAdapter())
+
+        self.timeout = timeout
+        self._trusted_host_ports: set[tuple[str, int | None]] = set()
+
+        for host in trusted_hosts:
+            self._add_trusted_host(host)
+
+        if ca_certificates is not None:
+            self.set_ca_certificates(ca_certificates)
+
+    def send(self, request: PreparedRequest, **kwargs: Any) -> Response:
+        if kwargs.get("timeout") is None:
+            kwargs["timeout"] = self.timeout
+        return super().send(request, **kwargs)
+
+    def set_ca_certificates(self, cert_file: Path):
+        """
+        Set one or multiple certificate authorities which sign the
+        server's certs.
+        """
+        self.verify = str(cert_file)
+
+    def _add_trusted_host(self, host: str) -> None:
+        """Trust the given host by not verifying the SSL certificate."""
+        hostname, port = parse_netloc(host)
+        self._trusted_host_ports.add((hostname, port))
+        for scheme in ("https", "http"):
+            url = build_url_from_netloc(host, scheme=scheme)
+            self.mount(url + "/", self._insecure_adapter)
+            if port is None:
+                # Allow all ports for this host
+                self.mount(url + ":", self._insecure_adapter)
+
+    def iter_secure_origins(self) -> Iterable[tuple[str, str, str]]:
+        yield from DEFAULT_SECURE_ORIGINS
+        for host, port in self._trusted_host_ports:
+            yield ("*", host, "*" if port is None else str(port))
+
+    @contextlib.contextmanager
+    def get_stream(
+        self, url: str, *, headers: dict[str, str] | None = None
+    ) -> Iterator[Response]:
+        """Stream the response from the given URL."""
+        with self.get(url, headers=headers, stream=True) as resp:
+            resp.iter_bytes = resp.iter_content  # type: ignore[attr-defined]
+            yield resp
diff --git a/src/unearth/fetchers/sync.py b/src/unearth/fetchers/sync.py
new file mode 100644
index 0000000..ba7f691
--- /dev/null
+++ b/src/unearth/fetchers/sync.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import email
+import mimetypes
+import os
+from typing import TYPE_CHECKING
+
+import httpx
+from httpx._config import DEFAULT_LIMITS
+from httpx._content import IteratorByteStream
+
+from unearth.link import Link
+from unearth.utils import parse_netloc
+
+if TYPE_CHECKING:
+    from typing import Any, ContextManager, Iterable, Mapping
+
+    from httpx._types import CertTypes, TimeoutTypes, VerifyTypes
+
+
+def is_absolute_url(self) -> bool:
+    return bool(self._uri_reference.scheme or self._uri_reference.host)
+
+
+# Patch the is_absolute_url method of httpx.URL to allow file:// URLs
+httpx.URL.is_absolute_url = property(is_absolute_url)
+
+
+class LocalFSTransport(httpx.BaseTransport):
+    def handle_request(self, request: httpx.Request) -> httpx.Response:
+        link = Link(str(request.url))
+        path = link.file_path
+        if request.method != "GET":
+            return httpx.Response(status_code=405)
+
+        try:
+            stats = os.stat(path)
+        except OSError as exc:
+            # Render the raised exception as the response text
+            # to return a better error message:
+            return httpx.Response(status_code=404, text=f"{type(exc).__name__}: {exc}")
+        else:
+            modified = email.utils.formatdate(stats.st_mtime, usegmt=True)
+            content_type = mimetypes.guess_type(path)[0] or "text/plain"
+            headers = {
+                "Content-Type": content_type,
+                "Content-Length": str(stats.st_size),
+                "Last-Modified": modified,
+            }
+            return httpx.Response(
+                status_code=200,
+                headers=headers,
+                stream=IteratorByteStream(path.open("rb")),
+            )
+
+
+class PyPIClient(httpx.Client):
+    """
+    A :class:`httpx.Client` subclass that supports file:// URLs and trusted hosts configuration.
+
+    Args:
+        trusted_hosts: A list of trusted hosts. If a host is trusted, the client will not verify the SSL certificate.
+        \\**kwargs: Additional keyword arguments to pass to the :class:`httpx.Client` constructor.
+    """
+
+    def __init__(
+        self,
+        *,
+        trusted_hosts: Iterable[str] = (),
+        verify: VerifyTypes = True,
+        cert: CertTypes | None = None,
+        http1: bool = True,
+        http2: bool = False,
+        limits: httpx.Limits = DEFAULT_LIMITS,
+        trust_env: bool = True,
+        timeout: TimeoutTypes = 10.0,
+        **kwargs: Any,
+    ) -> None:
+        self._trusted_host_ports: set[tuple[str, int | None]] = set()
+        # httpx has no built-in retry support, so retries are not implemented here for simplicity
+        insecure_transport = httpx.HTTPTransport(
+            verify=False,
+            cert=cert,
+            http1=http1,
+            http2=http2,
+            limits=limits,
+            trust_env=trust_env,
+        )
+
+        mounts: dict[str, httpx.BaseTransport] = {"file://": LocalFSTransport()}
+        for host in trusted_hosts:
+            hostname, port = parse_netloc(host)
+            self._trusted_host_ports.add((hostname, port))
+            mounts[f"all://{host}"] = insecure_transport
+
+        mounts.update(kwargs.pop("mounts", {}))
+
+        super().__init__(
+            verify=verify,
+            cert=cert,
+            http1=http1,
+            http2=http2,
+            limits=limits,
+            trust_env=trust_env,
+            timeout=timeout,
+            mounts=mounts,
+            **kwargs,
+        )
+
+    def get_stream(
+        self, url: str, *, headers: Mapping[str, str] | None = None
+    ) -> ContextManager[httpx.Response]:
+        return self.stream("GET", url, headers=headers)
+
+    def iter_secure_origins(self) -> Iterable[tuple[str, str, str]]:
+        from unearth.fetchers import DEFAULT_SECURE_ORIGINS
+
+        yield from DEFAULT_SECURE_ORIGINS
+        for host, port in self._trusted_host_ports:
+            yield ("*", host, "*" if port is None else str(port))
diff --git a/src/unearth/finder.py b/src/unearth/finder.py
index c4ebf94..a703e23 100644
--- a/src/unearth/finder.py
+++ b/src/unearth/finder.py
@@ -7,15 +7,17 @@
 import itertools
 import os
 import pathlib
+import warnings
 from datetime import datetime
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Iterable, NamedTuple, Sequence
+from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Sequence
 from urllib.parse import urljoin
 
 import packaging.requirements
 from packaging.utils import BuildTag, canonicalize_name, parse_wheel_filename
 from packaging.version import parse as parse_version
 
+from unearth.auth import MultiDomainBasicAuth
 from unearth.collector import collect_links_from_location
 from unearth.evaluator import (
     Evaluator,
@@ -26,9 +28,10 @@
     is_equality_specifier,
     validate_hashes,
 )
+from unearth.fetchers import Fetcher
+from unearth.fetchers.sync import PyPIClient
 from unearth.link import Link
 from unearth.preparer import noop_download_reporter, noop_unpack_reporter, unpack_link
-from unearth.session import PyPISession
 from unearth.utils import LazySequence
 
 if TYPE_CHECKING:
@@ -44,6 +47,21 @@ class 
Source(TypedDict): Source = dict +def _check_legacy_session(session: Any) -> None: + try: + from requests import Session + except ModuleNotFoundError: + return + + if isinstance(session, Session): + warnings.warn( + "The legacy requests.Session is used, which is deprecated and will be removed in the next release. " + "Please use `httpx.Client` instead. ", + DeprecationWarning, + stacklevel=2, + ) + + class BestMatch(NamedTuple): """The best match for a package.""" @@ -80,7 +98,7 @@ class PackageFinder: def __init__( self, - session: PyPISession | None = None, + session: Fetcher | None = None, *, index_urls: Iterable[str] = (), find_links: Iterable[str] = (), @@ -108,6 +126,7 @@ def __init__( self.only_binary = {canonicalize_name(name) for name in only_binary} self.prefer_binary = {canonicalize_name(name) for name in prefer_binary} self.trusted_hosts = trusted_hosts + _check_legacy_session(session) self._session = session self.respect_source_order = respect_source_order self.verbosity = verbosity @@ -118,14 +137,13 @@ def __init__( } @property - def session(self) -> PyPISession: + def session(self) -> Fetcher: if self._session is None: index_urls = [ source["url"] for source in self.sources if source["type"] == "index" ] - session = PyPISession( - index_urls=index_urls, trusted_hosts=self.trusted_hosts - ) + session = PyPIClient(trusted_hosts=self.trusted_hosts) + session.auth = MultiDomainBasicAuth(index_urls=index_urls) atexit.register(session.close) self._session = session return self._session diff --git a/src/unearth/preparer.py b/src/unearth/preparer.py index 18714c2..e50b7f6 100644 --- a/src/unearth/preparer.py +++ b/src/unearth/preparer.py @@ -14,9 +14,10 @@ from pathlib import Path from typing import TYPE_CHECKING, Iterable, cast -from requests import HTTPError, Session +import httpx from unearth.errors import HashMismatchError, UnpackError +from unearth.fetchers import Fetcher from unearth.link import Link from unearth.utils import ( BZ2_EXTENSIONS, @@ -29,6 +30,14 @@ ) from unearth.vcs import vcs_support +HTTPErrors: tuple[type[Exception], ...] = (httpx.HTTPError,) +try: + from requests import HTTPError + + HTTPErrors += (HTTPError,) +except ModuleNotFoundError: + pass + if TYPE_CHECKING: from typing import Protocol @@ -282,7 +291,7 @@ def _untar_archive(filename: Path, location: Path, reporter: UnpackReporter) -> def unpack_link( - session: Session, + session: Fetcher, link: Link, download_dir: Path, location: Path, @@ -296,7 +305,7 @@ def unpack_link( The link can be a VCS link or a file link. 
     Args:
-        session (Session): the requests session
+        session (Fetcher): the fetcher to use for downloading
         link (Link): the link to unpack
         download_dir (Path): the directory to download the file to
         location (Path): the destination directory
@@ -326,10 +335,10 @@
     # A remote artfiact link, check the download dir first
     artifact = download_dir / link.filename
     if not _check_downloaded(artifact, hashes):
-        with session.get(link.normalized, stream=True) as resp:
+        with session.get_stream(link.normalized) as resp:
             try:
                 resp.raise_for_status()
-            except HTTPError as e:
+            except HTTPErrors as e:
                 raise UnpackError(f"Download failed: {e}") from None
             try:
                 total = int(resp.headers["Content-Length"])
@@ -343,7 +352,7 @@
             with artifact.open("wb") as f:
                 callback = functools.partial(download_reporter, link, total=total)
                 for chunk in iter_with_callback(
-                    resp.iter_content(chunk_size=READ_CHUNK_SIZE),
+                    resp.iter_bytes(chunk_size=READ_CHUNK_SIZE),
                     callback,
                     stepper=len,
                 ):
diff --git a/src/unearth/session.py b/src/unearth/session.py
index 04fb91c..779a950 100644
--- a/src/unearth/session.py
+++ b/src/unearth/session.py
@@ -1,206 +1,14 @@
-from __future__ import annotations
-
-import email.utils
-import io
-import ipaddress
-import logging
-import mimetypes
-import os
 import warnings
-from pathlib import Path
-from typing import Any, Iterable, cast
-
-import requests.adapters
-import urllib3
-from requests import Session
-from requests.models import PreparedRequest, Response
-
-from unearth.auth import MultiDomainBasicAuth
-from unearth.link import Link
-from unearth.utils import build_url_from_netloc, parse_netloc
-
-logger = logging.getLogger(__name__)
-
-DEFAULT_MAX_RETRIES = 5
-DEFAULT_SECURE_ORIGINS = [
-    ("https", "*", "*"),
-    ("wss", "*", "*"),
-    ("*", "localhost", "*"),
-    ("*", "127.0.0.0/8", "*"),
-    ("*", "::1/128", "*"),
-    ("file", "*", "*"),
-]
-
-
-def _compare_origin_part(allowed: str, actual: str) -> bool:
-    return allowed == "*" or allowed == actual
-
-
-class InsecureMixin:
-    def cert_verify(self, conn, url, verify, cert):
-        return super().cert_verify(conn, url, verify=False, cert=cert)
-
-    def send(self, request, *args, **kwargs):
-        with warnings.catch_warnings():
-            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-            return super().send(request, *args, **kwargs)
-
-
-class InsecureHTTPAdapter(InsecureMixin, requests.adapters.HTTPAdapter):
-    pass
-
-
-class LocalFSAdapter(requests.adapters.BaseAdapter):
-    def send(self, request: PreparedRequest, *args: Any, **kwargs: Any) -> Response:
-        link = Link(cast(str, request.url))
-        path = link.file_path
-        resp = Response()
-        resp.status_code = 200
-        resp.url = cast(str, request.url)
-        resp.request = request
-
-        try:
-            stats = os.stat(path)
-        except OSError as exc:
-            # format the exception raised as a io.BytesIO object,
-            # to return a better error message:
-            resp.status_code = 404
-            resp.reason = type(exc).__name__
-            resp.raw = io.BytesIO(f"{resp.reason}: {exc}".encode("utf8"))
-        else:
-            modified = email.utils.formatdate(stats.st_mtime, usegmt=True)
-            content_type = mimetypes.guess_type(path)[0] or "text/plain"
-            resp.headers.update(
-                {
-                    "Content-Type": content_type,
-                    "Content-Length": str(stats.st_size),
-                    "Last-Modified": modified,
-                }
-            )
-
-            resp.raw = open(path, "rb")
-            resp.close = resp.raw.close  # type: ignore[method-assign]
-
-        return resp
-
-    def close(self) -> None:
-        pass
-
-
-class PyPISession(Session):
-    """
-    A session with caching enabled and specific hosts trusted.
- - Args: - index_urls: The PyPI index URLs to use. - retries: The number of retries to attempt. - trusted_hosts: The hosts to trust. - ca_certificates: The path to a file where the certificates for - CAs reside. These are used when verifying the host - certificates of the index servers. When left unset, the - default certificates of the requests library will be used. - """ - - #: The adapter class to use for secure connections. - secure_adapter_cls = requests.adapters.HTTPAdapter - #: The adapter class to use for insecure connections. - insecure_adapter_cls = InsecureHTTPAdapter - - def __init__( - self, - *, - index_urls: Iterable[str] = (), - retries: int = DEFAULT_MAX_RETRIES, - trusted_hosts: Iterable[str] = (), - ca_certificates: Path | None = None, - timeout: float | tuple[float, float] | urllib3.Timeout = 10, - ) -> None: - super().__init__() - - retry = urllib3.Retry( - total=retries, - # A 500 may indicate transient error in Amazon S3 - # A 520 or 527 - may indicate transient error in CloudFlare - status_forcelist=[500, 503, 520, 527], - backoff_factor=0.25, - ) - self._insecure_adapter = self.insecure_adapter_cls(max_retries=retry) - secure_adapter = self.secure_adapter_cls(max_retries=retry) - - self.mount("https://", secure_adapter) - self.mount("http://", self._insecure_adapter) - self.mount("file://", LocalFSAdapter()) - - self.timeout = timeout - self._trusted_host_ports: set[tuple[str, int | None]] = set() - - for host in trusted_hosts: - self.add_trusted_host(host) - self.auth = MultiDomainBasicAuth(index_urls=index_urls) - - if ca_certificates is not None: - self.set_ca_certificates(ca_certificates) - - def send(self, request: PreparedRequest, **kwargs: Any) -> Response: - if kwargs.get("timeout") is None: - kwargs["timeout"] = self.timeout - return super().send(request, **kwargs) - - def set_ca_certificates(self, cert_file: Path): - """ - Set one or multiple certificate authorities which sign the - server's certs. - """ - self.verify = str(cert_file) - - def add_trusted_host(self, host: str) -> None: - """Trust the given host by not verifying the SSL certificate.""" - hostname, port = parse_netloc(host) - self._trusted_host_ports.add((hostname, port)) - for scheme in ("https", "http"): - url = build_url_from_netloc(host, scheme=scheme) - self.mount(url + "/", self._insecure_adapter) - if port is None: - # Allow all ports for this host - self.mount(url + ":", self._insecure_adapter) - - def iter_secure_origins(self) -> Iterable[tuple[str, str, str]]: - yield from DEFAULT_SECURE_ORIGINS - for host, port in self._trusted_host_ports: - yield ("*", host, "*" if port is None else str(port)) - - def is_secure_origin(self, location: Link) -> bool: - """ - Determine if the origin is a trusted host. - - Args: - location (Link): The location to check. 
- """ - _, _, scheme = location.parsed.scheme.rpartition("+") - host, port = location.parsed.hostname or "", location.parsed.port - for secure_scheme, secure_host, secure_port in self.iter_secure_origins(): - if not _compare_origin_part(secure_scheme, scheme): - continue - try: - addr = ipaddress.ip_address(host) - network = ipaddress.ip_network(secure_host) - except ValueError: - # Either addr or network is invalid - if not _compare_origin_part(secure_host, host): - continue - else: - if addr not in network: - continue - - if not _compare_origin_part( - secure_port, "*" if port is None else str(port) - ): - continue - # We've got here, so all the parts match - return True - logger.warning( - "Skipping %s for not being trusted, please add it to `trusted_hosts` list", - location.redacted, - ) - return False +from .fetchers.legacy import ( # noqa: F401 + InsecureHTTPAdapter, + InsecureMixin, + PyPISession, +) + +warnings.warn( + "unearth.session has been deprecated and will be removed " + "in the next minor release. Please import from unearth.fetchers instead.", + DeprecationWarning, + stacklevel=1, +) diff --git a/src/unearth/utils.py b/src/unearth/utils.py index 84065ac..2b765e2 100644 --- a/src/unearth/utils.py +++ b/src/unearth/utils.py @@ -283,3 +283,25 @@ def commonprefix(*m: str) -> str: if c != s2[i]: return s1[:i] return s1 + + +def get_netrc_auth(url: str) -> tuple[str, str] | None: + """Get the auth for the given url from the netrc file.""" + try: + from netrc import netrc + except ImportError: + return None + from httpx import URL + + hostname = URL(url).host + + try: + authenticator = netrc(os.getenv("NETRC")) + except (FileNotFoundError, TypeError): + return None + info = authenticator.authenticators(hostname) + + if info is None: + return None + + return info[0], info[2] diff --git a/tests/conftest.py b/tests/conftest.py index 43c965b..70c2bc3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,23 @@ """Configuration for the pytest test suite.""" +from __future__ import annotations + import os from ssl import SSLContext -from unittest import mock +from typing import TYPE_CHECKING import flask import pytest import trustme +from httpx import WSGITransport from wsgiadapter import WSGIAdapter as _WSGIAdapter from tests.fixtures.app import BASE_DIR, create_app -from unearth.session import InsecureMixin, PyPISession +from unearth.fetchers import PyPIClient +from unearth.fetchers.legacy import InsecureMixin, PyPISession + +if TYPE_CHECKING: + from typing import Literal class WSGIAdapter(_WSGIAdapter): @@ -58,16 +65,7 @@ def fixtures_dir(): @pytest.fixture() def pypi(): - wsgi_app = create_app() - with mock.patch.object( - PyPISession, "insecure_adapter_cls", return_value=InsecureWSGIAdapter(wsgi_app) - ): - with mock.patch.object( - PyPISession, - "secure_adapter_cls", - return_value=WSGIAdapter(wsgi_app), - ): - yield wsgi_app + return create_app() @pytest.fixture() @@ -92,13 +90,39 @@ def require_basic_auth(): return pypi +@pytest.fixture(params=["sync", "legacy"]) +def fetcher_type(request) -> Literal["sync", "legacy"]: + return request.param + + +@pytest.fixture() +def session(fetcher_type): + if fetcher_type == "sync": + client = PyPIClient() + else: + client = PyPISession() + try: + yield client + finally: + client.close() + + @pytest.fixture() -def session(): - s = PyPISession() +def pypi_session(pypi, fetcher_type, mocker): + if fetcher_type == "sync": + client = PyPIClient(transport=WSGITransport(pypi)) + else: + mocker.patch.object( + PyPISession, 
"insecure_adapter_cls", return_value=InsecureWSGIAdapter(pypi) + ) + mocker.patch.object( + PyPISession, "secure_adapter_cls", return_value=WSGIAdapter(pypi) + ) + client = PyPISession() try: - yield s + yield client finally: - s.close() + client.close() @pytest.fixture(params=["html", "json"]) diff --git a/tests/fixtures/app.py b/tests/fixtures/app.py index ea5edb0..65fabd7 100644 --- a/tests/fixtures/app.py +++ b/tests/fixtures/app.py @@ -27,7 +27,7 @@ @bp.route("/files/") def package_file(path): - return flask.send_from_directory(BASE_DIR / "files", path) + return flask.send_from_directory(BASE_DIR / "files", path, as_attachment=True) @bp.route("/simple/") diff --git a/tests/test_collector.py b/tests/test_collector.py index 66d1605..d08bf7c 100644 --- a/tests/test_collector.py +++ b/tests/test_collector.py @@ -4,37 +4,40 @@ from unearth.link import Link -def test_collector_skip_insecure_hosts(pypi, session, caplog): +def test_collector_skip_insecure_hosts(pypi_session, caplog): collected = list( - collect_links_from_location(session, Link("http://insecure.com/simple/click")) + collect_links_from_location( + pypi_session, Link("http://insecure.com/simple/click") + ) ) assert not collected assert "not being trusted" in caplog.records[0].message -def test_collector_skip_vcs_link(pypi, session, caplog): +def test_collector_skip_vcs_link(pypi_session, caplog): collected = list( collect_links_from_location( - session, Link("git+https://github.com/pallets/click.git") + pypi_session, Link("git+https://github.com/pallets/click.git") ) ) assert not collected assert "It is a VCS link" in caplog.records[0].message -def test_collect_links_from_404_page(pypi, session): +def test_collect_links_from_404_page(pypi_session): collected = list( collect_links_from_location( - session, Link("https://test.pypi.org/simple/not-found") + pypi_session, Link("https://test.pypi.org/simple/not-found") ) ) assert not collected -def test_skip_non_html_archive(pypi, session, caplog): +def test_skip_non_html_archive(pypi_session, caplog): collected = list( collect_links_from_location( - session, Link("https://test.pypi.org/files/click-8.1.3-py3-none-any.whl") + pypi_session, + Link("https://test.pypi.org/files/click-8.1.3-py3-none-any.whl"), ) ) assert not collected @@ -42,10 +45,10 @@ def test_skip_non_html_archive(pypi, session, caplog): @pytest.mark.usefixtures("content_type") -def test_collect_links_from_index_page(pypi, session): +def test_collect_links_from_index_page(pypi_session): collected = sorted( collect_links_from_location( - session, Link("https://test.pypi.org/simple/click") + pypi_session, Link("https://test.pypi.org/simple/click") ), key=lambda link: link.filename, ) @@ -54,10 +57,10 @@ def test_collect_links_from_index_page(pypi, session): @pytest.mark.parametrize("filename", ["findlinks", "findlinks/index.html"]) -def test_collect_links_from_local_file(pypi, session, fixtures_dir, filename): +def test_collect_links_from_local_file(pypi_session, fixtures_dir, filename): link = Link.from_path(fixtures_dir / filename) collected = sorted( - collect_links_from_location(session, link), + collect_links_from_location(pypi_session, link), key=lambda link: link.filename, ) assert [link.filename for link in collected] == [ @@ -68,10 +71,10 @@ def test_collect_links_from_local_file(pypi, session, fixtures_dir, filename): ] -def test_collect_links_from_local_dir_expand(pypi, session, fixtures_dir): +def test_collect_links_from_local_dir_expand(pypi_session, fixtures_dir): link = Link.from_path(fixtures_dir / 
"findlinks") collected = sorted( - collect_links_from_location(session, link, expand=True), + collect_links_from_location(pypi_session, link, expand=True), key=lambda link: link.filename, ) assert [link.filename for link in collected] == [ diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py index f5804b8..0231687 100644 --- a/tests/test_evaluator.py +++ b/tests/test_evaluator.py @@ -24,17 +24,6 @@ ] -@pytest.fixture() -def session(): - from unearth.session import PyPISession - - sess = PyPISession() - try: - yield sess - finally: - sess.close() - - @pytest.mark.parametrize("link", BINARY_LINKS) def test_only_binary_is_allowed(link): format_control = FormatControl(only_binary={"foo"}) @@ -175,7 +164,9 @@ def test_evaluate_against_missing_version(link): def test_evaluate_against_allowed_hashes(url, match, session): package = Package("click", "8.1.3", link=Link(url)) result = validate_hashes( - package, {"sha256": ["1234567890abcdef", "fedcba0987654321"]}, session=session + package, + {"sha256": ["1234567890abcdef", "fedcba0987654321"]}, + session=session, ) assert result is match @@ -200,7 +191,7 @@ def test_evaluate_allow_all_hashes(url, session): "https://test.pypi.org/files/click-8.1.3-py3-none-any.whl#md5=1111222", ], ) -def test_retrieve_hash_from_internet(pypi, session, url): +def test_retrieve_hash_from_internet(pypi_session, url): link = Link(url) package = Package("click", "8.1.3", link=link) assert validate_hashes( @@ -210,7 +201,7 @@ def test_retrieve_hash_from_internet(pypi, session, url): "bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48" ] }, - session=session, + session=pypi_session, ) hash_name, hash = next(iter(link.hashes.items())) assert hash_name == "sha256" diff --git a/tests/test_finder.py b/tests/test_finder.py index a2c1b55..0e8ea52 100644 --- a/tests/test_finder.py +++ b/tests/test_finder.py @@ -6,7 +6,7 @@ from unearth.evaluator import TargetPython from unearth.finder import PackageFinder -pytestmark = pytest.mark.usefixtures("pypi", "content_type") +pytestmark = pytest.mark.usefixtures("content_type") DEFAULT_INDEX_URL = "https://pypi.org/simple/" @@ -36,16 +36,18 @@ ), ], ) -def test_find_most_matching_wheel(session, target_python, filename): +def test_find_most_matching_wheel(pypi_session, target_python, filename): finder = PackageFinder( - session=session, index_urls=[DEFAULT_INDEX_URL], target_python=target_python + session=pypi_session, + index_urls=[DEFAULT_INDEX_URL], + target_python=target_python, ) assert finder.find_best_match("black").best.link.filename == filename -def test_find_package_with_format_control(session): +def test_find_package_with_format_control(pypi_session): finder = PackageFinder( - session=session, + session=pypi_session, index_urls=[DEFAULT_INDEX_URL], target_python=TargetPython( (3, 9), abis=["cp39"], impl="cp", platforms=["win_amd64"] @@ -60,9 +62,9 @@ def test_find_package_with_format_control(session): ) -def test_find_package_no_binary_for_all(session): +def test_find_package_no_binary_for_all(pypi_session): finder = PackageFinder( - session=session, + session=pypi_session, index_urls=[DEFAULT_INDEX_URL], target_python=TargetPython( (3, 9), abis=["cp39"], impl="cp", platforms=["win_amd64"] @@ -73,9 +75,9 @@ def test_find_package_no_binary_for_all(session): assert finder.find_best_match("first").best.link.filename == "first-2.0.2.tar.gz" -def test_find_package_prefer_binary(session): +def test_find_package_prefer_binary(pypi_session): finder = PackageFinder( - session=session, + session=pypi_session, 
index_urls=[DEFAULT_INDEX_URL], target_python=TargetPython( (3, 9), abis=["cp39"], impl="cp", platforms=["win_amd64"] @@ -88,9 +90,9 @@ def test_find_package_prefer_binary(session): ) -def test_find_package_with_hash_allowance(session): +def test_find_package_with_hash_allowance(pypi_session): finder = PackageFinder( - session=session, + session=pypi_session, index_urls=[DEFAULT_INDEX_URL], target_python=TargetPython( (3, 9), abis=["cp39"], impl="cp", platforms=["win_amd64"] @@ -110,9 +112,9 @@ def test_find_package_with_hash_allowance(session): @pytest.mark.parametrize("ignore_compat", [True, False]) -def test_find_package_ignoring_compatibility(session, ignore_compat): +def test_find_package_ignoring_compatibility(pypi_session, ignore_compat): finder = PackageFinder( - session=session, + session=pypi_session, index_urls=[DEFAULT_INDEX_URL], target_python=TargetPython( (3, 9), abis=["cp39"], impl="cp", platforms=["win_amd64"] @@ -123,9 +125,9 @@ def test_find_package_ignoring_compatibility(session, ignore_compat): assert len(all_available) == (6 if ignore_compat else 3) -def test_find_package_with_version_specifier(session): +def test_find_package_with_version_specifier(pypi_session): finder = PackageFinder( - session=session, + session=pypi_session, index_urls=[DEFAULT_INDEX_URL], ignore_compatibility=True, ) @@ -136,9 +138,9 @@ def test_find_package_with_version_specifier(session): assert len(matches) == 0 -def test_find_package_allowing_prereleases(session): +def test_find_package_allowing_prereleases(pypi_session): finder = PackageFinder( - session=session, + session=pypi_session, index_urls=[DEFAULT_INDEX_URL], ignore_compatibility=True, ) @@ -154,9 +156,9 @@ def test_find_package_allowing_prereleases(session): assert len(matches) == 0 -def test_find_requirement_with_link(session): +def test_find_requirement_with_link(pypi_session): finder = PackageFinder( - session=session, + session=pypi_session, index_urls=[DEFAULT_INDEX_URL], ignore_compatibility=True, ) @@ -167,10 +169,10 @@ def test_find_requirement_with_link(session): assert matches[0].link.normalized == "https://pypi.org/files/first-2.0.2.tar.gz" -def test_find_requirement_preference(session, fixtures_dir): +def test_find_requirement_preference(pypi_session, fixtures_dir): find_link = Link.from_path(fixtures_dir / "findlinks/index.html") finder = PackageFinder( - session=session, index_urls=[DEFAULT_INDEX_URL], ignore_compatibility=True + session=pypi_session, index_urls=[DEFAULT_INDEX_URL], ignore_compatibility=True ) finder.add_find_links(find_link.normalized) best = finder.find_best_match("first").best @@ -178,10 +180,10 @@ def test_find_requirement_preference(session, fixtures_dir): assert best.link.comes_from == find_link.normalized -def test_find_requirement_preference_respect_source_order(session, fixtures_dir): +def test_find_requirement_preference_respect_source_order(pypi_session, fixtures_dir): find_link = Link.from_path(fixtures_dir / "findlinks/index.html") finder = PackageFinder( - session=session, + session=pypi_session, index_urls=[DEFAULT_INDEX_URL], ignore_compatibility=True, respect_source_order=True, @@ -192,9 +194,9 @@ def test_find_requirement_preference_respect_source_order(session, fixtures_dir) assert best.link.comes_from == "https://pypi.org/simple/first/" -def test_download_package_file(session, tmp_path): +def test_download_package_file(pypi_session, tmp_path): finder = PackageFinder( - session=session, + session=pypi_session, index_urls=[DEFAULT_INDEX_URL], ignore_compatibility=True, ) @@ -231,9 
+233,9 @@ def unpack_reporter(filename, completed, total): assert filename == downloaded -def test_exclude_newer_than(session, content_type): +def test_exclude_newer_than(pypi_session, content_type): finder = PackageFinder( - session=session, + session=pypi_session, index_urls=[DEFAULT_INDEX_URL], ignore_compatibility=True, exclude_newer_than=datetime.datetime( diff --git a/tests/test_session.py b/tests/test_session.py index b321606..06999e1 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,12 +1,26 @@ import logging -from unittest import mock import pytest from unearth.auth import MultiDomainBasicAuth +from unearth.collector import is_secure_origin +from unearth.fetchers.legacy import PyPISession +from unearth.fetchers.sync import PyPIClient from unearth.link import Link +@pytest.fixture +def private_session(fetcher_type): + if fetcher_type == "sync": + session = PyPIClient(trusted_hosts=["example.org", "192.168.0.1:8080"]) + else: + session = PyPISession(trusted_hosts=["example.org", "192.168.0.1:8080"]) + try: + yield session + finally: + session.close() + + @pytest.mark.parametrize( "url, is_secure", [ @@ -25,73 +39,80 @@ ("http://192.168.0.1:8080/simple", True), ], ) -def test_session_is_secure_origin(session, url, is_secure): - for host in ["example.org", "192.168.0.1:8080"]: - session.add_trusted_host(host) - assert session.is_secure_origin(Link(url)) == is_secure +def test_session_is_secure_origin(private_session, url, is_secure): + assert is_secure_origin(private_session, Link(url)) == is_secure def test_session_with_selfsigned_ca( - httpserver, custom_certificate_authority, session, tmp_path + httpserver, custom_certificate_authority, fetcher_type, tmp_path ): ca_cert = tmp_path / "ca.crt" custom_certificate_authority.cert_pem.write_to_path(ca_cert) - session.set_ca_certificates(ca_cert) + if fetcher_type == "sync": + session = PyPIClient(verify=str(ca_cert)) + else: + session = PyPISession(ca_certificates=ca_cert) httpserver.expect_request("/").respond_with_json({}) - assert session.get(httpserver.url_for("/")).json() == {} + with session: + assert session.get(httpserver.url_for("/")).json() == {} -def test_session_auth_401_if_no_prompting(pypi_auth, session): - session.auth = MultiDomainBasicAuth(prompting=False) - resp = session.get("https://pypi.org/simple") +@pytest.mark.usefixtures("pypi_auth") +def test_session_auth_401_if_no_prompting(pypi_session): + pypi_session.auth = MultiDomainBasicAuth(prompting=False) + resp = pypi_session.get("https://pypi.org/simple") assert resp.status_code == 401 -def test_session_auth_from_source_urls(pypi_auth, session): - session.auth = MultiDomainBasicAuth( +@pytest.mark.usefixtures("pypi_auth") +def test_session_auth_from_source_urls(pypi_session): + pypi_session.auth = MultiDomainBasicAuth( prompting=False, index_urls=["https://test:password@pypi.org/simple"] ) - resp = session.get("https://pypi.org/simple/click") + resp = pypi_session.get("https://pypi.org/simple/click") assert resp.status_code == 200 assert not any(r.status_code == 401 for r in resp.history) -def test_session_auth_with_empty_password(pypi_auth, session, monkeypatch): +@pytest.mark.usefixtures("pypi_auth") +def test_session_auth_with_empty_password(pypi_session, monkeypatch): monkeypatch.setenv("PYPI_PASSWORD", "") - session.auth = MultiDomainBasicAuth( + pypi_session.auth = MultiDomainBasicAuth( prompting=False, index_urls=["https://test:@pypi.org/simple"] ) - resp = session.get("https://pypi.org/simple/click") + resp = 
pypi_session.get("https://pypi.org/simple/click")
     assert resp.status_code == 200
     assert not any(r.status_code == 401 for r in resp.history)
 
 
-def test_session_auth_from_prompting(pypi_auth, session):
-    with mock.patch.object(
+@pytest.mark.usefixtures("pypi_auth")
+def test_session_auth_from_prompting(pypi_session, mocker):
+    pypi_session.auth = MultiDomainBasicAuth(prompting=True)
+    mocker.patch.object(
         MultiDomainBasicAuth,
         "_prompt_for_password",
         return_value=("test", "password", False),
-    ):
-        session.auth = MultiDomainBasicAuth(prompting=True)
-        resp = session.get("https://pypi.org/simple/click")
+    )
+    resp = pypi_session.get("https://pypi.org/simple/click")
     assert resp.status_code == 200
     assert any(r.status_code == 401 for r in resp.history)
-    resp = session.get("https://pypi.org/simple/click")
+    resp = pypi_session.get("https://pypi.org/simple/click")
     assert resp.status_code == 200
     assert not any(r.status_code == 401 for r in resp.history)
 
 
-def test_session_auth_warn_agains_wrong_credentials(pypi_auth, session, caplog):
+@pytest.mark.usefixtures("pypi_auth")
+def test_session_auth_warn_against_wrong_credentials(pypi_session, caplog, mocker):
     caplog.set_level(logging.WARNING)
-    with mock.patch.object(
+    mocker.patch.object(
         MultiDomainBasicAuth,
         "_prompt_for_password",
         return_value=("test", "incorrect", False),
-    ):
-        session.auth = MultiDomainBasicAuth(prompting=True)
-        resp = session.get("https://pypi.org/simple/click")
+    )
+    pypi_session.auth = MultiDomainBasicAuth(prompting=True)
+    resp = pypi_session.get("https://pypi.org/simple/click")
     assert resp.status_code == 401
     record = caplog.records[-1]
     assert record.levelname == "WARNING"
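Taken together, the wiring for downstream users looks roughly like the sketch below, mirroring what `PackageFinder.session` now constructs when no session is passed. The trusted host and package name are placeholder values, not part of this change:

```python
from unearth.auth import MultiDomainBasicAuth
from unearth.fetchers import PyPIClient
from unearth.finder import PackageFinder

index_urls = ["https://pypi.org/simple/"]

# PyPIClient is an httpx.Client, so it can be used as a context manager
# or closed explicitly when done.
with PyPIClient(trusted_hosts=["pypi.example.internal:8080"]) as client:
    # Credentials for the index URLs are resolved per request, as in the finder.
    client.auth = MultiDomainBasicAuth(prompting=False, index_urls=index_urls)
    finder = PackageFinder(session=client, index_urls=index_urls)
    best = finder.find_best_match("click").best
    if best is not None:
        print(best.link.filename)
```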