Skip to content

Commit

Permalink
break: postpone the validation of hashes (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostming authored Jul 18, 2023
1 parent e5b9e4d commit f0aea16
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 115 deletions.
18 changes: 0 additions & 18 deletions .copier-answers.yml

This file was deleted.

1 change: 1 addition & 0 deletions changelogithub.config.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"types": {
"break": { "title": "💥 Breaking Changes" },
"feat": { "title": "🚀 Features" },
"fix": { "title": "🐞 Bug Fixes" },
"doc": { "title": "📝 Documentation" },
Expand Down
85 changes: 36 additions & 49 deletions src/unearth/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import logging
import sys
from typing import Any
from urllib.parse import urlencode

import packaging.requirements
from packaging.specifiers import InvalidSpecifier, SpecifierSet
Expand Down Expand Up @@ -129,29 +128,26 @@ class Evaluator:
Args:
package_name (str): The links must match the package name
target_python (TargetPython): The links must match the target Python
hashes (dict[str, list[str]): The links must have the correct hashes
ignore_compatibility (bool): Whether to ignore the compatibility check
allow_yanked (bool): Whether to allow yanked candidates
format_control (bool): Format control flags
"""

package_name: str
session: Session
target_python: TargetPython = dc.field(default_factory=TargetPython)
hashes: dict[str, list[str]] = dc.field(default_factory=dict)
ignore_compatibility: bool = False
allow_yanked: bool = False
format_control: FormatControl = dc.field(default_factory=FormatControl)

def __post_init__(self) -> None:
self._canonical_name = canonicalize_name(self.package_name)

def _check_yanked(self, link: Link) -> None:
def check_yanked(self, link: Link) -> None:
if link.yank_reason is not None and not self.allow_yanked:
yank_reason = f"due to {link.yank_reason}" if link.yank_reason else ""
raise LinkMismatchError(f"Yanked {yank_reason}")

def _check_requires_python(self, link: Link) -> None:
def check_requires_python(self, link: Link) -> None:
if not self.ignore_compatibility and link.requires_python:
py_ver = self.target_python.py_ver or sys.version_info[:2]
py_version = ".".join(str(v) for v in py_ver)
Expand All @@ -171,54 +167,14 @@ def _check_requires_python(self, link: Link) -> None:
),
)

def _check_hashes(self, link: Link) -> None:
def hash_mismatch(
hash_name: str, given_hash: str, allowed_hashes: list[str]
) -> None:
raise LinkMismatchError(
f"Hash mismatch, expected: {allowed_hashes}\n"
f"got: {hash_name}:{given_hash}"
)

if not self.hashes:
return
link_hashes = link.hash_option
if link_hashes:
for hash_name, allowed_hashes in self.hashes.items():
if hash_name in link_hashes:
given_hash = link_hashes[hash_name][0]
if given_hash not in allowed_hashes:
hash_mismatch(hash_name, given_hash, allowed_hashes)
return

hash_name, allowed_hashes = next(iter(self.hashes.items()))
given_hash = self._get_hash(link, hash_name)
if given_hash not in allowed_hashes:
hash_mismatch(hash_name, given_hash, allowed_hashes)

def _get_hash(self, link: Link, hash_name: str) -> str:
resp = self.session.get(link.normalized, stream=True)
hasher = hashlib.new(hash_name)
for chunk in resp.iter_content(chunk_size=1024 * 8):
hasher.update(chunk)
digest = hasher.hexdigest()
# Store the hash on the link for future use
fragment_dict = link._fragment_dict
fragment_dict.pop(link.hash_name, None) # type: ignore
fragment_dict[hash_name] = digest
link.__dict__["parsed"] = link.parsed._replace(
fragment=urlencode(fragment_dict)
)
return digest

def evaluate_link(self, link: Link) -> Package | None:
"""
Evaluate the link and return the package if it matches or None if it doesn't.
"""
try:
self.format_control.check_format(link, self.package_name)
self._check_yanked(link)
self._check_requires_python(link)
self.check_yanked(link)
self.check_requires_python(link)
version: str | None = None
if link.is_wheel:
try:
Expand Down Expand Up @@ -260,7 +216,6 @@ def evaluate_link(self, link: Link) -> Package | None:
raise LinkMismatchError(
f"Invalid version in the filename {egg_info}: {version}"
)
self._check_hashes(link)
except LinkMismatchError as e:
logger.debug("Skipping link %s: %s", link, e)
return None
Expand Down Expand Up @@ -299,3 +254,35 @@ def evaluate_package(
)
return False
return True


def _get_hash(link: Link, hash_name: str, session: Session) -> str:
resp = session.get(link.normalized, stream=True)
hasher = hashlib.new(hash_name)
for chunk in resp.iter_content(chunk_size=1024 * 8):
hasher.update(chunk)
digest = hasher.hexdigest()
if not link.hashes:
link.hashes = {}
link.hashes[hash_name] = digest
return digest


def validate_hashes(
package: Package, hashes: dict[str, list[str]], session: Session
) -> bool:
if not hashes:
return True
link = package.link
link_hashes = link.hash_option
if link_hashes:
for hash_name, allowed_hashes in hashes.items():
if hash_name in link_hashes:
given_hash = link_hashes[hash_name][0]
if given_hash not in allowed_hashes:
return False
return True

hash_name, allowed_hashes = next(iter(hashes.items()))
given_hash = _get_hash(link, hash_name, session)
return given_hash in allowed_hashes
55 changes: 36 additions & 19 deletions src/unearth/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import itertools
import os
import pathlib
import warnings
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Iterable, NamedTuple, Sequence
from urllib.parse import urljoin
Expand All @@ -22,6 +23,7 @@
TargetPython,
evaluate_package,
is_equality_specifier,
validate_hashes,
)
from unearth.link import Link
from unearth.preparer import unpack_link
Expand Down Expand Up @@ -145,13 +147,17 @@ def build_evaluator(
Args:
package_name (str): The desired package name
allow_yanked (bool): Whether to allow yanked candidates.
hashes (dict[str, list[str]]|None): The hashes to filter on.
Returns:
Evaluator: The evaluator for the given package name
"""
if hashes:
hashes = {name: sorted(values) for name, values in hashes.items()}
if hashes is not None:
warnings.warn(
"The evaluator no longer validates hashes, "
"please remove the hashes argument",
FutureWarning,
stacklevel=2,
)
canonical_name = canonicalize_name(package_name)
format_control = FormatControl(
no_binary=canonical_name in self.no_binary or ":all:" in self.no_binary,
Expand All @@ -160,11 +166,9 @@ def build_evaluator(
)
return Evaluator(
package_name=package_name,
session=self.session,
target_python=self.target_python,
ignore_compatibility=self.ignore_compatibility,
allow_yanked=allow_yanked,
hashes=hashes or {},
format_control=format_control,
)

Expand Down Expand Up @@ -198,6 +202,14 @@ def _evaluate_packages(
)
return filter(evaluator, packages)

def _evaluate_hashes(
self, packages: Iterable[Package], hashes: dict[str, list[str]]
) -> Iterable[Package]:
evaluator = functools.partial(
validate_hashes, hashes=hashes, session=self.session
)
return filter(evaluator, packages)

def _sort_key(self, package: Package) -> tuple:
"""The key for sort, package with the largest value is the most preferred."""
link = package.link
Expand Down Expand Up @@ -225,19 +237,17 @@ def _find_packages(
self,
package_name: str,
allow_yanked: bool = False,
hashes: dict[str, list[str]] | None = None,
) -> Iterable[Package]:
"""Find all packages with the given name.
Args:
package_name (str): The desired package name
allow_yanked (bool): Whether to allow yanked candidates.
hashes (dict[str, list[str]]|None): The hashes to filter on.
Returns:
Iterable[Package]: The packages with the given name, sorted by best match.
"""
evaluator = self.build_evaluator(package_name, allow_yanked, hashes)
evaluator = self.build_evaluator(package_name, allow_yanked)

def find_one_source(source: Source) -> Iterable[Package]:
if source["type"] == "index":
Expand Down Expand Up @@ -278,20 +288,23 @@ def find_all_packages(
Returns:
Sequence[Package]: The packages list sorted by best match
"""
return LazySequence(self._find_packages(package_name, allow_yanked, hashes))
return LazySequence(
self._evaluate_hashes(
self._find_packages(package_name, allow_yanked), hashes=hashes or {}
)
)

def _find_packages_from_requirement(
self,
requirement: packaging.requirements.Requirement,
allow_yanked: bool | None = None,
hashes: dict[str, list[str]] | None = None,
) -> Iterable[Package]:
if allow_yanked is None:
allow_yanked = is_equality_specifier(requirement.specifier)
if requirement.url:
yield Package(requirement.name, None, link=Link(requirement.url))
else:
yield from self._find_packages(requirement.name, allow_yanked, hashes)
yield from self._find_packages(requirement.name, allow_yanked)

def find_matches(
self,
Expand All @@ -317,10 +330,13 @@ def find_matches(
if isinstance(requirement, str):
requirement = packaging.requirements.Requirement(requirement)
return LazySequence(
self._evaluate_packages(
self._find_packages_from_requirement(requirement, allow_yanked, hashes),
requirement,
allow_prereleases,
self._evaluate_hashes(
self._evaluate_packages(
self._find_packages_from_requirement(requirement, allow_yanked),
requirement,
allow_prereleases,
),
hashes=hashes or {},
)
)

Expand All @@ -347,12 +363,13 @@ def find_best_match(
"""
if isinstance(requirement, str):
requirement = packaging.requirements.Requirement(requirement)
packages = self._find_packages_from_requirement(
requirement, allow_yanked, hashes
)
packages = self._find_packages_from_requirement(requirement, allow_yanked)
candidates = LazySequence(packages)
applicable_candidates = LazySequence(
self._evaluate_packages(packages, requirement, allow_prereleases)
self._evaluate_hashes(
self._evaluate_packages(packages, requirement, allow_prereleases),
hashes=hashes or {},
)
)
best_match = next(iter(applicable_candidates), None)
return BestMatch(best_match, applicable_candidates, candidates)
Expand Down
Loading

0 comments on commit f0aea16

Please sign in to comment.