Skip to content

Commit

Permalink
Make hassfest strictly typed (home-assistant#82091)
Browse files Browse the repository at this point in the history
  • Loading branch information
akx authored Nov 23, 2022
1 parent 0b5357d commit 97b40b5
Show file tree
Hide file tree
Showing 20 changed files with 132 additions and 97 deletions.
13 changes: 10 additions & 3 deletions script/hassfest/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Validate manifests."""
from __future__ import annotations

import argparse
import pathlib
import sys
Expand Down Expand Up @@ -55,7 +57,7 @@
]


def valid_integration_path(integration_path):
def valid_integration_path(integration_path: pathlib.Path | str) -> pathlib.Path:
"""Test if it's a valid integration."""
path = pathlib.Path(integration_path)
if not path.is_dir():
Expand Down Expand Up @@ -124,7 +126,7 @@ def get_config() -> Config:
)


def main():
def main() -> int:
"""Validate manifests."""
try:
config = get_config()
Expand Down Expand Up @@ -218,7 +220,12 @@ def main():
return 1


def print_integrations_status(config, integrations, *, show_fixable_errors=True):
def print_integrations_status(
config: Config,
integrations: list[Integration],
*,
show_fixable_errors: bool = True,
) -> None:
"""Print integration status."""
for integration in sorted(integrations, key=lambda itg: itg.domain):
extra = f" - {integration.path}" if config.specific_integrations else ""
Expand Down
2 changes: 1 addition & 1 deletion script/hassfest/application_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def validate(integrations: dict[str, Integration], config: Config) -> None:
)


def generate(integrations: dict[str, Integration], config: Config):
def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate application_credentials data."""
application_credentials_path = (
config.root / "homeassistant/generated/application_credentials.py"
Expand Down
6 changes: 3 additions & 3 deletions script/hassfest/bluetooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .serializer import format_python_namespace


def generate_and_validate(integrations: list[dict[str, str]]):
def generate_and_validate(integrations: dict[str, Integration]) -> str:
"""Validate and generate bluetooth data."""
match_list = []

Expand All @@ -29,7 +29,7 @@ def generate_and_validate(integrations: list[dict[str, str]]):
)


def validate(integrations: dict[str, Integration], config: Config):
def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate bluetooth file."""
bluetooth_path = config.root / "homeassistant/generated/bluetooth.py"
config.cache["bluetooth"] = content = generate_and_validate(integrations)
Expand All @@ -48,7 +48,7 @@ def validate(integrations: dict[str, Integration], config: Config):
return


def generate(integrations: dict[str, Integration], config: Config):
def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate bluetooth file."""
bluetooth_path = config.root / "homeassistant/generated/bluetooth.py"
with open(str(bluetooth_path), "w") as fp:
Expand Down
2 changes: 1 addition & 1 deletion script/hassfest/brand.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _validate_brand(
):
config.add_error(
"brand",
f"{brand.path.name}: Brand '{brand.brand['domain']}' "
f"{brand.path.name}: Brand '{brand.domain}' "
f"is an integration but is missing in the brand's 'integrations' list'",
)

Expand Down
6 changes: 3 additions & 3 deletions script/hassfest/codeowners.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"""


def generate_and_validate(integrations: dict[str, Integration], config: Config):
def generate_and_validate(integrations: dict[str, Integration], config: Config) -> str:
"""Generate CODEOWNERS."""
parts = [BASE]

Expand Down Expand Up @@ -77,7 +77,7 @@ def generate_and_validate(integrations: dict[str, Integration], config: Config):
return "\n".join(parts)


def validate(integrations: dict[str, Integration], config: Config):
def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate CODEOWNERS."""
codeowners_path = config.root / "CODEOWNERS"
config.cache["codeowners"] = content = generate_and_validate(integrations, config)
Expand All @@ -95,7 +95,7 @@ def validate(integrations: dict[str, Integration], config: Config):
return


def generate(integrations: dict[str, Integration], config: Config):
def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate CODEOWNERS."""
codeowners_path = config.root / "CODEOWNERS"
with open(str(codeowners_path), "w") as fp:
Expand Down
29 changes: 16 additions & 13 deletions script/hassfest/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import json
import pathlib
from typing import Any

from .brand import validate as validate_brands
from .model import Brand, Config, Integration
Expand All @@ -11,12 +12,12 @@
UNIQUE_ID_IGNORE = {"huawei_lte", "mqtt", "adguard"}


def _validate_integration(config: Config, integration: Integration):
def _validate_integration(config: Config, integration: Integration) -> None:
"""Validate config flow of an integration."""
config_flow_file = integration.path / "config_flow.py"

if not config_flow_file.is_file():
if integration.manifest.get("config_flow"):
if (integration.manifest or {}).get("config_flow"):
integration.add_error(
"config_flow",
"Config flows need to be defined in the file config_flow.py",
Expand Down Expand Up @@ -60,9 +61,9 @@ def _validate_integration(config: Config, integration: Integration):
)


def _generate_and_validate(integrations: dict[str, Integration], config: Config):
def _generate_and_validate(integrations: dict[str, Integration], config: Config) -> str:
"""Validate and generate config flow data."""
domains = {
domains: dict[str, list[str]] = {
"integration": [],
"helper": [],
}
Expand All @@ -84,9 +85,9 @@ def _generate_and_validate(integrations: dict[str, Integration], config: Config)


def _populate_brand_integrations(
integration_data: dict,
integration_data: dict[str, Any],
integrations: dict[str, Integration],
brand_metadata: dict,
brand_metadata: dict[str, Any],
sub_integrations: list[str],
) -> None:
"""Add referenced integrations to a brand's metadata."""
Expand All @@ -99,7 +100,7 @@ def _populate_brand_integrations(
"system",
):
continue
metadata = {
metadata: dict[str, Any] = {
"integration_type": integration.integration_type,
}
# Always set the config_flow key to avoid breaking the frontend
Expand All @@ -119,11 +120,13 @@ def _populate_brand_integrations(


def _generate_integrations(
brands: dict[str, Brand], integrations: dict[str, Integration], config: Config
):
brands: dict[str, Brand],
integrations: dict[str, Integration],
config: Config,
) -> str:
"""Generate integrations data."""

result = {
result: dict[str, Any] = {
"integration": {},
"helper": {},
"translated_name": set(),
Expand Down Expand Up @@ -154,7 +157,7 @@ def _generate_integrations(

# Generate the config flow index
for domain in sorted(primary_domains):
metadata = {}
metadata: dict[str, Any] = {}

if brand := brands.get(domain):
metadata["name"] = brand.name
Expand Down Expand Up @@ -199,7 +202,7 @@ def _generate_integrations(
)


def validate(integrations: dict[str, Integration], config: Config):
def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate config flow file."""
config_flow_path = config.root / "homeassistant/generated/config_flows.py"
integrations_path = config.root / "homeassistant/generated/integrations.json"
Expand Down Expand Up @@ -233,7 +236,7 @@ def validate(integrations: dict[str, Integration], config: Config):
)


def generate(integrations: dict[str, Integration], config: Config):
def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate config flow file."""
config_flow_path = config.root / "homeassistant/generated/config_flows.py"
integrations_path = config.root / "homeassistant/generated/integrations.json"
Expand Down
2 changes: 1 addition & 1 deletion script/hassfest/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
}


def validate(integrations: dict[str, Integration], config: Config):
def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate coverage."""
coverage_path = config.root / ".coveragerc"

Expand Down
31 changes: 17 additions & 14 deletions script/hassfest/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@
from homeassistant.const import Platform
from homeassistant.requirements import DISCOVERY_INTEGRATIONS

from .model import Integration
from .model import Config, Integration


class ImportCollector(ast.NodeVisitor):
"""Collect all integrations referenced."""

def __init__(self, integration: Integration):
def __init__(self, integration: Integration) -> None:
"""Initialize the import collector."""
self.integration = integration
self.referenced: dict[Path, set[str]] = {}

# Current file or dir we're inspecting
self._cur_fil_dir = None
self._cur_fil_dir: Path | None = None

def collect(self) -> None:
"""Collect imports from a source file."""
Expand All @@ -32,11 +32,12 @@ def collect(self) -> None:
self.visit(ast.parse(fil.read_text()))
self._cur_fil_dir = None

def _add_reference(self, reference_domain: str):
def _add_reference(self, reference_domain: str) -> None:
"""Add a reference."""
assert self._cur_fil_dir
self.referenced[self._cur_fil_dir].add(reference_domain)

def visit_ImportFrom(self, node):
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Visit ImportFrom node."""
if node.module is None:
return
Expand All @@ -59,14 +60,14 @@ def visit_ImportFrom(self, node):
for name_node in node.names:
self._add_reference(name_node.name)

def visit_Import(self, node):
def visit_Import(self, node: ast.Import) -> None:
"""Visit Import node."""
# import homeassistant.components.hue as hue
for name_node in node.names:
if name_node.name.startswith("homeassistant.components."):
self._add_reference(name_node.name.split(".")[2])

def visit_Attribute(self, node):
def visit_Attribute(self, node: ast.Attribute) -> None:
"""Visit Attribute node."""
# hass.components.hue.async_create()
# Name(id=hass)
Expand Down Expand Up @@ -156,15 +157,16 @@ def visit_Attribute(self, node):

def calc_allowed_references(integration: Integration) -> set[str]:
"""Return a set of allowed references."""
manifest = integration.manifest
allowed_references = (
ALLOWED_USED_COMPONENTS
| set(integration.manifest.get("dependencies", []))
| set(integration.manifest.get("after_dependencies", []))
| set(manifest.get("dependencies", []))
| set(manifest.get("after_dependencies", []))
)

# Discovery requirements are ok if referenced in manifest
for check_domain, to_check in DISCOVERY_INTEGRATIONS.items():
if any(check in integration.manifest for check in to_check):
if any(check in manifest for check in to_check):
allowed_references.add(check_domain)

return allowed_references
Expand All @@ -174,7 +176,7 @@ def find_non_referenced_integrations(
integrations: dict[str, Integration],
integration: Integration,
references: dict[Path, set[str]],
):
) -> set[str]:
"""Find intergrations that are not allowed to be referenced."""
allowed_references = calc_allowed_references(integration)
referenced = set()
Expand Down Expand Up @@ -219,8 +221,9 @@ def find_non_referenced_integrations(


def validate_dependencies(
integrations: dict[str, Integration], integration: Integration
):
integrations: dict[str, Integration],
integration: Integration,
) -> None:
"""Validate all dependencies."""
# Some integrations are allowed to have violations.
if integration.domain in IGNORE_VIOLATIONS:
Expand All @@ -242,7 +245,7 @@ def validate_dependencies(
)


def validate(integrations: dict[str, Integration], config):
def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Handle dependencies for integrations."""
# check for non-existing dependencies
for integration in integrations.values():
Expand Down
6 changes: 3 additions & 3 deletions script/hassfest/dhcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .serializer import format_python_namespace


def generate_and_validate(integrations: list[dict[str, str]]):
def generate_and_validate(integrations: dict[str, Integration]) -> str:
"""Validate and generate dhcp data."""
match_list = []

Expand All @@ -29,7 +29,7 @@ def generate_and_validate(integrations: list[dict[str, str]]):
)


def validate(integrations: dict[str, Integration], config: Config):
def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Validate dhcp file."""
dhcp_path = config.root / "homeassistant/generated/dhcp.py"
config.cache["dhcp"] = content = generate_and_validate(integrations)
Expand All @@ -48,7 +48,7 @@ def validate(integrations: dict[str, Integration], config: Config):
return


def generate(integrations: dict[str, Integration], config: Config):
def generate(integrations: dict[str, Integration], config: Config) -> None:
"""Generate dhcp file."""
dhcp_path = config.root / "homeassistant/generated/dhcp.py"
with open(str(dhcp_path), "w") as fp:
Expand Down
8 changes: 3 additions & 5 deletions script/hassfest/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import json

from .model import Integration
from .model import Config, Integration


def validate_json_files(integration: Integration):
def validate_json_files(integration: Integration) -> None:
"""Validate JSON files for integration."""
for json_file in integration.path.glob("**/*.json"):
if not json_file.is_file():
Expand All @@ -18,10 +18,8 @@ def validate_json_files(integration: Integration):
relative_path = json_file.relative_to(integration.path)
integration.add_error("json", f"Invalid JSON file {relative_path}")

return


def validate(integrations: dict[str, Integration], config):
def validate(integrations: dict[str, Integration], config: Config) -> None:
"""Handle JSON files inside integrations."""
if not config.specific_integrations:
return
Expand Down
Loading

0 comments on commit 97b40b5

Please sign in to comment.