From 053815476767c196c83c0652e79f7d5f9716f573 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Wed, 16 Nov 2022 13:00:35 +0100 Subject: [PATCH] Add type hints to requirements script (#82075) --- script/gen_requirements_all.py | 70 +++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/script/gen_requirements_all.py b/script/gen_requirements_all.py index bbc970f91785e8..264d9ff9f8a172 100755 --- a/script/gen_requirements_all.py +++ b/script/gen_requirements_all.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 -"""Generate an updated requirements_all.txt.""" +"""Generate updated constraint and requirements files.""" +from __future__ import annotations + import difflib import importlib import os @@ -7,6 +9,7 @@ import pkgutil import re import sys +from typing import Any from homeassistant.util.yaml.loader import load_yaml from script.hassfest.model import Integration @@ -157,7 +160,7 @@ PACKAGE_REGEX = re.compile(r"^(?:--.+\s)?([-_\.\w\d]+).*==.+$") -def has_tests(module: str): +def has_tests(module: str) -> bool: """Test if a module has tests. Module format: homeassistant.components.hue @@ -169,11 +172,11 @@ def has_tests(module: str): return path.exists() -def explore_module(package, explore_children): +def explore_module(package: str, explore_children: bool) -> list[str]: """Explore the modules.""" module = importlib.import_module(package) - found = [] + found: list[str] = [] if not hasattr(module, "__path__"): return found @@ -187,14 +190,17 @@ def explore_module(package, explore_children): return found -def core_requirements(): +def core_requirements() -> list[str]: """Gather core requirements out of pyproject.toml.""" with open("pyproject.toml", "rb") as fp: data = tomllib.load(fp) - return data["project"]["dependencies"] + dependencies: list[str] = data["project"]["dependencies"] + return dependencies -def gather_recursive_requirements(domain, seen=None): +def gather_recursive_requirements( + domain: str, seen: set[str] | None = None +) -> set[str]: """Recursively gather requirements from a module.""" if seen is None: seen = set() @@ -221,18 +227,18 @@ def normalize_package_name(requirement: str) -> str: return package -def comment_requirement(req): +def comment_requirement(req: str) -> bool: """Comment out requirement. Some don't install on all systems.""" return any( normalize_package_name(req) == ign for ign in COMMENT_REQUIREMENTS_NORMALIZED ) -def gather_modules(): +def gather_modules() -> dict[str, list[str]] | None: """Collect the information.""" - reqs = {} + reqs: dict[str, list[str]] = {} - errors = [] + errors: list[str] = [] gather_requirements_from_manifests(errors, reqs) gather_requirements_from_modules(errors, reqs) @@ -248,7 +254,9 @@ def gather_modules(): return reqs -def gather_requirements_from_manifests(errors, reqs): +def gather_requirements_from_manifests( + errors: list[str], reqs: dict[str, list[str]] +) -> None: """Gather all of the requirements from manifests.""" integrations = Integration.load_dir(Path("homeassistant/components")) for domain in sorted(integrations): @@ -266,7 +274,9 @@ def gather_requirements_from_manifests(errors, reqs): ) -def gather_requirements_from_modules(errors, reqs): +def gather_requirements_from_modules( + errors: list[str], reqs: dict[str, list[str]] +) -> None: """Collect the requirements from the modules directly.""" for package in sorted( explore_module("homeassistant.scripts", True) @@ -283,7 +293,12 @@ def gather_requirements_from_modules(errors, reqs): process_requirements(errors, module.REQUIREMENTS, package, reqs) -def process_requirements(errors, module_requirements, package, reqs): +def process_requirements( + errors: list[str], + module_requirements: list[str], + package: str, + reqs: dict[str, list[str]], +) -> None: """Process all of the requirements.""" for req in module_requirements: if "://" in req: @@ -293,7 +308,7 @@ def process_requirements(errors, module_requirements, package, reqs): reqs.setdefault(req, []).append(package) -def generate_requirements_list(reqs): +def generate_requirements_list(reqs: dict[str, list[str]]) -> str: """Generate a pip file based on requirements.""" output = [] for pkg, requirements in sorted(reqs.items(), key=lambda item: item[0]): @@ -307,7 +322,7 @@ def generate_requirements_list(reqs): return "".join(output) -def requirements_output(reqs): +def requirements_output() -> str: """Generate output for requirements.""" output = [ "-c homeassistant/package_constraints.txt\n", @@ -320,7 +335,7 @@ def requirements_output(reqs): return "".join(output) -def requirements_all_output(reqs): +def requirements_all_output(reqs: dict[str, list[str]]) -> str: """Generate output for requirements_all.""" output = [ "# Home Assistant Core, full dependency set\n", @@ -331,7 +346,7 @@ def requirements_all_output(reqs): return "".join(output) -def requirements_test_all_output(reqs): +def requirements_test_all_output(reqs: dict[str, list[str]]) -> str: """Generate output for test_requirements.""" output = [ "# Home Assistant tests, full dependency set\n", @@ -356,15 +371,18 @@ def requirements_test_all_output(reqs): return "".join(output) -def requirements_pre_commit_output(): +def requirements_pre_commit_output() -> str: """Generate output for pre-commit dependencies.""" source = ".pre-commit-config.yaml" - pre_commit_conf = load_yaml(source) - reqs = [] + pre_commit_conf: dict[str, list[dict[str, Any]]] + pre_commit_conf = load_yaml(source) # type: ignore[assignment] + reqs: list[str] = [] + hook: dict[str, Any] for repo in (x for x in pre_commit_conf["repos"] if x.get("rev")): + rev: str = repo["rev"] for hook in repo["hooks"]: if hook["id"] not in IGNORE_PRE_COMMIT_HOOK_ID: - reqs.append(f"{hook['id']}=={repo['rev'].lstrip('v')}") + reqs.append(f"{hook['id']}=={rev.lstrip('v')}") reqs.extend(x for x in hook.get("additional_dependencies", ())) output = [ f"# Automatically generated " @@ -375,7 +393,7 @@ def requirements_pre_commit_output(): return "\n".join(output) + "\n" -def gather_constraints(): +def gather_constraints() -> str: """Construct output for constraint file.""" return ( "\n".join( @@ -392,7 +410,7 @@ def gather_constraints(): ) -def diff_file(filename, content): +def diff_file(filename: str, content: str) -> list[str]: """Diff a file.""" return list( difflib.context_diff( @@ -404,7 +422,7 @@ def diff_file(filename, content): ) -def main(validate): +def main(validate: bool) -> int: """Run the script.""" if not os.path.isfile("requirements_all.txt"): print("Run this from HA root dir") @@ -415,7 +433,7 @@ def main(validate): if data is None: return 1 - reqs_file = requirements_output(data) + reqs_file = requirements_output() reqs_all_file = requirements_all_output(data) reqs_test_all_file = requirements_test_all_output(data) reqs_pre_commit_file = requirements_pre_commit_output()