Skip to content

Commit

Permalink
Add type hints to requirements script (home-assistant#82075)
Browse files Browse the repository at this point in the history
  • Loading branch information
epenet authored Nov 16, 2022
1 parent 1582d88 commit 0538154
Showing 1 changed file with 44 additions and 26 deletions.
70 changes: 44 additions & 26 deletions script/gen_requirements_all.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#!/usr/bin/env python3
"""Generate an updated requirements_all.txt."""
"""Generate updated constraint and requirements files."""
from __future__ import annotations

import difflib
import importlib
import os
from pathlib import Path
import pkgutil
import re
import sys
from typing import Any

from homeassistant.util.yaml.loader import load_yaml
from script.hassfest.model import Integration
Expand Down Expand Up @@ -157,7 +160,7 @@
PACKAGE_REGEX = re.compile(r"^(?:--.+\s)?([-_\.\w\d]+).*==.+$")


def has_tests(module: str):
def has_tests(module: str) -> bool:
"""Test if a module has tests.
Module format: homeassistant.components.hue
Expand All @@ -169,11 +172,11 @@ def has_tests(module: str):
return path.exists()


def explore_module(package, explore_children):
def explore_module(package: str, explore_children: bool) -> list[str]:
"""Explore the modules."""
module = importlib.import_module(package)

found = []
found: list[str] = []

if not hasattr(module, "__path__"):
return found
Expand All @@ -187,14 +190,17 @@ def explore_module(package, explore_children):
return found


def core_requirements():
def core_requirements() -> list[str]:
"""Gather core requirements out of pyproject.toml."""
with open("pyproject.toml", "rb") as fp:
data = tomllib.load(fp)
return data["project"]["dependencies"]
dependencies: list[str] = data["project"]["dependencies"]
return dependencies


def gather_recursive_requirements(domain, seen=None):
def gather_recursive_requirements(
domain: str, seen: set[str] | None = None
) -> set[str]:
"""Recursively gather requirements from a module."""
if seen is None:
seen = set()
Expand All @@ -221,18 +227,18 @@ def normalize_package_name(requirement: str) -> str:
return package


def comment_requirement(req):
def comment_requirement(req: str) -> bool:
"""Comment out requirement. Some don't install on all systems."""
return any(
normalize_package_name(req) == ign for ign in COMMENT_REQUIREMENTS_NORMALIZED
)


def gather_modules():
def gather_modules() -> dict[str, list[str]] | None:
"""Collect the information."""
reqs = {}
reqs: dict[str, list[str]] = {}

errors = []
errors: list[str] = []

gather_requirements_from_manifests(errors, reqs)
gather_requirements_from_modules(errors, reqs)
Expand All @@ -248,7 +254,9 @@ def gather_modules():
return reqs


def gather_requirements_from_manifests(errors, reqs):
def gather_requirements_from_manifests(
errors: list[str], reqs: dict[str, list[str]]
) -> None:
"""Gather all of the requirements from manifests."""
integrations = Integration.load_dir(Path("homeassistant/components"))
for domain in sorted(integrations):
Expand All @@ -266,7 +274,9 @@ def gather_requirements_from_manifests(errors, reqs):
)


def gather_requirements_from_modules(errors, reqs):
def gather_requirements_from_modules(
errors: list[str], reqs: dict[str, list[str]]
) -> None:
"""Collect the requirements from the modules directly."""
for package in sorted(
explore_module("homeassistant.scripts", True)
Expand All @@ -283,7 +293,12 @@ def gather_requirements_from_modules(errors, reqs):
process_requirements(errors, module.REQUIREMENTS, package, reqs)


def process_requirements(errors, module_requirements, package, reqs):
def process_requirements(
errors: list[str],
module_requirements: list[str],
package: str,
reqs: dict[str, list[str]],
) -> None:
"""Process all of the requirements."""
for req in module_requirements:
if "://" in req:
Expand All @@ -293,7 +308,7 @@ def process_requirements(errors, module_requirements, package, reqs):
reqs.setdefault(req, []).append(package)


def generate_requirements_list(reqs):
def generate_requirements_list(reqs: dict[str, list[str]]) -> str:
"""Generate a pip file based on requirements."""
output = []
for pkg, requirements in sorted(reqs.items(), key=lambda item: item[0]):
Expand All @@ -307,7 +322,7 @@ def generate_requirements_list(reqs):
return "".join(output)


def requirements_output(reqs):
def requirements_output() -> str:
"""Generate output for requirements."""
output = [
"-c homeassistant/package_constraints.txt\n",
Expand All @@ -320,7 +335,7 @@ def requirements_output(reqs):
return "".join(output)


def requirements_all_output(reqs):
def requirements_all_output(reqs: dict[str, list[str]]) -> str:
"""Generate output for requirements_all."""
output = [
"# Home Assistant Core, full dependency set\n",
Expand All @@ -331,7 +346,7 @@ def requirements_all_output(reqs):
return "".join(output)


def requirements_test_all_output(reqs):
def requirements_test_all_output(reqs: dict[str, list[str]]) -> str:
"""Generate output for test_requirements."""
output = [
"# Home Assistant tests, full dependency set\n",
Expand All @@ -356,15 +371,18 @@ def requirements_test_all_output(reqs):
return "".join(output)


def requirements_pre_commit_output():
def requirements_pre_commit_output() -> str:
"""Generate output for pre-commit dependencies."""
source = ".pre-commit-config.yaml"
pre_commit_conf = load_yaml(source)
reqs = []
pre_commit_conf: dict[str, list[dict[str, Any]]]
pre_commit_conf = load_yaml(source) # type: ignore[assignment]
reqs: list[str] = []
hook: dict[str, Any]
for repo in (x for x in pre_commit_conf["repos"] if x.get("rev")):
rev: str = repo["rev"]
for hook in repo["hooks"]:
if hook["id"] not in IGNORE_PRE_COMMIT_HOOK_ID:
reqs.append(f"{hook['id']}=={repo['rev'].lstrip('v')}")
reqs.append(f"{hook['id']}=={rev.lstrip('v')}")
reqs.extend(x for x in hook.get("additional_dependencies", ()))
output = [
f"# Automatically generated "
Expand All @@ -375,7 +393,7 @@ def requirements_pre_commit_output():
return "\n".join(output) + "\n"


def gather_constraints():
def gather_constraints() -> str:
"""Construct output for constraint file."""
return (
"\n".join(
Expand All @@ -392,7 +410,7 @@ def gather_constraints():
)


def diff_file(filename, content):
def diff_file(filename: str, content: str) -> list[str]:
"""Diff a file."""
return list(
difflib.context_diff(
Expand All @@ -404,7 +422,7 @@ def diff_file(filename, content):
)


def main(validate):
def main(validate: bool) -> int:
"""Run the script."""
if not os.path.isfile("requirements_all.txt"):
print("Run this from HA root dir")
Expand All @@ -415,7 +433,7 @@ def main(validate):
if data is None:
return 1

reqs_file = requirements_output(data)
reqs_file = requirements_output()
reqs_all_file = requirements_all_output(data)
reqs_test_all_file = requirements_test_all_output(data)
reqs_pre_commit_file = requirements_pre_commit_output()
Expand Down

0 comments on commit 0538154

Please sign in to comment.