Skip to content

Commit

Permalink
Merge pull request #145 from mietzen/dev
Browse files Browse the repository at this point in the history
Easier way to omit ssl verification
  • Loading branch information
GrandMoff100 authored Mar 11, 2023
2 parents b8042bf + 724ff56 commit 272b42f
Show file tree
Hide file tree
Showing 19 changed files with 129 additions and 64 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
environment: "Python Package Deployment"
steps:
- uses: actions/checkout@v3
- name: Set up Python
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/test-suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ jobs:
code_functionality:
name: "Code Functionality"
runs-on: ubuntu-latest
environment: "Test Suite"
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# The full version, including alpha/beta/rc tags
with open("../pyproject.toml") as f:
pyproject = f.read()
release = version = re.search("version = \"(.+?)\"", pyproject).group(1)
release = version = re.search('version = "(.+?)"', pyproject).group(1)

# -- General configuration ---------------------------------------------------

Expand Down
12 changes: 9 additions & 3 deletions homeassistant_api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@ class Client(RawClient, RawAsyncClient):
:param async_cache_session: A :py:class:`aiohttp_client_cache.CachedSession` object to use for caching requests. Optional.
""" # pylint: disable=line-too-long

def __init__(self, *args: Any, use_async: bool = False, **kwargs: Any) -> None:
def __init__(
self,
*args: Any,
use_async: bool = False,
verify_ssl: bool = True,
**kwargs: Any
) -> None:
if use_async:
RawAsyncClient.__init__(self, *args, **kwargs)
RawAsyncClient.__init__(self, *args, verify_ssl=verify_ssl, **kwargs)
else:
RawClient.__init__(self, *args, **kwargs)
RawClient.__init__(self, *args, verify_ssl=verify_ssl, **kwargs)
4 changes: 2 additions & 2 deletions homeassistant_api/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ def __init__(self, status_code: int, content: Union[str, bytes]) -> None:
class UnauthorizedError(HomeassistantAPIError):
"""Error raised when an invalid token in used to authenticate with homeassistant."""

def __init__(self):
def __init__(self) -> None:
super().__init__("Invalid authentication token")


class EndpointNotFoundError(HomeassistantAPIError):
"""Error raised when a request is made to a non existing endpoint."""

def __init__(self, path: str):
def __init__(self, path: str) -> None:
super().__init__(f"Cannot make request to the endpoint {path!r}")


Expand Down
4 changes: 2 additions & 2 deletions homeassistant_api/models/entity.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Module for Entity and entity Group data models"""

from datetime import datetime
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from pydantic import Field

Expand Down Expand Up @@ -43,7 +43,7 @@ def get_entity(self, slug: str) -> Optional["Entity"]:
"""Returns Entity with the given name if it exists. Otherwise returns None"""
return self.entities.get(slug)

def __getattr__(self, key: str):
def __getattr__(self, key: str) -> Any:
if key in self.entities:
return self.get_entity(key)
return super().__getattribute__(key)
Expand Down
2 changes: 1 addition & 1 deletion homeassistant_api/models/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class History(BaseModel):
..., description="A tuple of previous states of an entity."
)

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
assert self.entity_id is not None

Expand Down
14 changes: 7 additions & 7 deletions homeassistant_api/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def process_content(self, *, async_: bool = False) -> Any:
mimetype = self._response.headers.get( # type: ignore [arg-type]
"content-type",
"text/plain",
) # type: ignore[arg-type]
)
for processor in self._processors.get(mimetype, ()):
if not async_ ^ inspect.iscoroutinefunction(processor):
logger.debug("Using processor %r on %r", processor, self._response)
Expand Down Expand Up @@ -105,36 +105,36 @@ def process(self) -> Any:

# List of default processors
@Processing.processor("application/json") # type: ignore[arg-type]
def process_json(response: ResponseType) -> Dict[str, Any]:
def process_json(response: ResponseType) -> dict[str, Any]:
"""Returns the json dict content of the response."""
try:
return response.json()
return cast(dict[str, Any], response.json())
except (json.JSONDecodeError, simplejson.JSONDecodeError) as err:
raise MalformedDataError(
f"Home Assistant responded with non-json response: {repr(response.text)}"
) from err


@Processing.processor("text/plain") # type: ignore[arg-type]
@Processing.processor("application/octet-stream") # type: ignore[arg-type]
@Processing.processor("application/octet-stream")
def process_text(response: ResponseType) -> str:
"""Returns the plaintext of the reponse."""
return response.text


@Processing.processor("application/json") # type: ignore[arg-type]
async def async_process_json(response: AsyncResponseType) -> Dict[str, Any]:
async def async_process_json(response: AsyncResponseType) -> dict[str, Any]:
"""Returns the json dict content of the response."""
try:
return await response.json()
return cast(dict[str, Any], await response.json())
except (json.JSONDecodeError, simplejson.JSONDecodeError) as err:
raise MalformedDataError(
f"Home Assistant responded with non-json response: {repr(await response.text())}"
) from err


@Processing.processor("text/plain") # type: ignore[arg-type]
@Processing.processor("application/octet-stream") # type: ignore[arg-type]
@Processing.processor("application/octet-stream")
async def async_process_text(response: AsyncResponseType) -> str:
"""Returns the plaintext of the reponse."""
return await response.text()
27 changes: 16 additions & 11 deletions homeassistant_api/rawasyncclient.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Module for interacting with Home Assistant asyncronously."""
from __future__ import annotations

import asyncio
import json
import logging
Expand Down Expand Up @@ -54,17 +56,20 @@ def __init__(
Literal[False],
Literal[None],
] = None, # Explicitly disable cache with async_cache_session=False
verify_ssl: bool = True,
**kwargs,
):
RawBaseClient.__init__(self, *args, **kwargs)
connector = aiohttp.TCPConnector(verify_ssl=False) if not verify_ssl else None
if async_cache_session is False:
self.async_cache_session = aiohttp.ClientSession()
self.async_cache_session = aiohttp.ClientSession(connector=connector)
elif async_cache_session is None:
self.async_cache_session = aiohttp_client_cache.CachedSession(
cache=aiohttp_client_cache.CacheBackend(
self.async_cache_session = aiohttp_client_cache.CachedSession( # type: ignore[attr-defined]
cache=aiohttp_client_cache.CacheBackend( # type: ignore[attr-defined]
cache_name="default_async_cache",
expire_after=300,
),
connector=connector,
)
else:
self.async_cache_session = async_cache_session
Expand Down Expand Up @@ -155,14 +160,14 @@ async def async_get_entity_histories(
for states in data:
yield History.parse_obj({"states": states})

async def async_get_rendered_template(self, template: str):
async def async_get_rendered_template(self, template: str) -> str:
"""Renders a given Jinja2 template string with Home Assistant context data."""
try:
return await self.async_request(
return cast(str, await self.async_request(
"template",
json=dict(template=template),
method="POST",
)
))
except RequestError as err:
raise BadTemplateError(
"Your template is invalid. "
Expand Down Expand Up @@ -208,9 +213,9 @@ async def async_get_entities(self) -> Tuple[Group, ...]:

async def async_get_entity(
self,
group_id: str = None,
slug: str = None,
entity_id: str = None,
group_id: str | None = None,
slug: str | None = None,
entity_id: str | None = None,
) -> Optional[Entity]:
"""Returns a Entity model for an :code:`entity_id`"""
if group_id is not None and slug is not None:
Expand Down Expand Up @@ -311,14 +316,14 @@ async def async_get_event(self, name: str) -> Optional[Event]:
return event
return None

async def async_fire_event(self, event_type: str, **event_data) -> str:
async def async_fire_event(self, event_type: str, **event_data: Any) -> str:
"""Fires a given event_type within homeassistant. Must be an existing event_type."""
data = await self.async_request(
join("events", event_type),
method="POST",
json=event_data,
)
return data.get("message", "No message provided")
return cast(str, data.get("message", "No message provided"))

async def async_get_components(self) -> Tuple[str, ...]:
"""Returns a tuple of all registered components."""
Expand Down
4 changes: 2 additions & 2 deletions homeassistant_api/rawbaseclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from datetime import datetime
from posixpath import join
from typing import Dict, Iterable, Optional, Tuple, Union
from typing import Dict, Iterable, Optional, Tuple, Union, Any

from .models import Entity

Expand All @@ -13,7 +13,7 @@ class RawBaseClient:

api_url: str
token: str
global_request_kwargs: Dict[str, str]
global_request_kwargs: Dict[str, Any]

def __init__(
self,
Expand Down
29 changes: 16 additions & 13 deletions homeassistant_api/rawclient.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module for all interaction with homeassistant."""
from __future__ import annotations

import json
import logging
Expand Down Expand Up @@ -53,36 +54,38 @@ def __init__(
Literal[False],
Literal[None],
] = None, # Explicitly disable cache with cache_session=False
verify_ssl: bool = True,
**kwargs,
):
RawBaseClient.__init__(self, *args, **kwargs)
self.global_request_kwargs["verify"] = verify_ssl
if cache_session is False:
self.cache_session = requests.Session()
elif cache_session is None:
self.cache_session = requests_cache.CachedSession(
self.cache_session = requests_cache.CachedSession( # type: ignore[attr-defined]
cache_name="default_cache",
backend="memory",
expire_after=300,
)
else:
self.cache_session = cache_session

def __enter__(self):
def __enter__(self) -> "RawClient":
logger.debug("Entering cached requests session %r.", self.cache_session)
self.cache_session.__enter__()
self.check_api_running()
self.check_api_config()
return self

def __exit__(self, _, __, ___):
def __exit__(self, _, __, ___) -> None:
logger.debug("Exiting requests session %r", self.cache_session)
self.cache_session.close()

def request(
self,
path,
path: str,
method="GET",
headers: Dict[str, str] = None,
headers: Dict[str, str] | None = None,
decode_bytes: bool = True,
**kwargs,
) -> Any:
Expand Down Expand Up @@ -202,17 +205,17 @@ def get_entities(self) -> Dict[str, Group]:
group_id, entity_slug = state.entity_id.split(".")
if group_id not in entities:
entities[group_id] = Group(
group_id=cast(str, group_id),
group_id=group_id,
_client=self, # type: ignore[arg-type]
)
entities[group_id]._add_entity(entity_slug, state)
return entities

def get_entity(
self,
group_id: str = None,
slug: str = None,
entity_id: str = None,
group_id: str | None = None,
slug: str | None = None,
entity_id: str | None = None,
) -> Optional[Entity]:
"""Returns an :py:class:`Entity` model for an :code:`entity_id`"""
if group_id is not None and slug is not None:
Expand All @@ -229,11 +232,11 @@ def get_entity(
)
split_group_id, split_slug = state.entity_id.split(".")
group = Group(
group_id=cast(str, split_group_id),
group_id=split_group_id,
_client=self, # type: ignore[arg-type]
)
group._add_entity(cast(str, split_slug), state)
return group.get_entity(cast(str, split_slug))
group._add_entity(split_slug, state)
return group.get_entity(split_slug)

# Services and domain methods
def get_domains(self) -> Dict[str, Domain]:
Expand Down Expand Up @@ -326,7 +329,7 @@ def fire_event(self, event_type: str, **event_data) -> Optional[str]:
method="POST",
json=event_data,
)
return cast(dict, data).get("message")
return cast(dict[str, Any], data).get("message")

def get_components(self) -> Tuple[str, ...]:
"""Returns a tuple of all registered components."""
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,12 @@ select = [
[tool.ruff.per-file-ignores]
"__init__.py" = ['F401']
"conf.py" = ['E402']

[tool.isort]
profile = "black"

[tool.mypy]
disable_error_code = [
"no-untyped-def",
"name-defined",
]
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging
import os
from typing import AsyncGenerator, Generator

Expand All @@ -7,8 +8,6 @@

from homeassistant_api import Client

import logging

TIMEOUT = 300


Expand All @@ -23,6 +22,7 @@ def wait_for_server_fixture() -> None:
client.request(method="get", path="", timeout=TIMEOUT)
logging.info("Server is ready.")


@pytest.fixture(name="cached_client", scope="session")
def setup_cached_client(wait_for_server) -> Generator[Client, None, None]:
"""Initializes the Client and enters a cached session."""
Expand Down
5 changes: 4 additions & 1 deletion tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ def test_get_config(cached_client: Client) -> None:

async def test_async_get_config(async_cached_client: Client) -> None:
"""Tests the `GET /api/config` endpoint."""
assert (await async_cached_client.async_get_config()).get("state") in {"RUNNING", "NOT_RUNNING"}
assert (await async_cached_client.async_get_config()).get("state") in {
"RUNNING",
"NOT_RUNNING",
}


def test_get_logbook_entries(cached_client: Client) -> None:
Expand Down
Loading

0 comments on commit 272b42f

Please sign in to comment.