Skip to content

Commit

Permalink
Made #495 service repsonse backward compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
craigbarratt committed Jul 29, 2023
1 parent 9522331 commit bd45610
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 30 deletions.
15 changes: 15 additions & 0 deletions custom_components/pyscript/const.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
"""Define pyscript-wide constants."""

#
# 2023.7 supports service response; handle older versions by defaulting enum
# Should eventually deprecate this and just use SupportsResponse import
#
try:
from homeassistant.core import SupportsResponse

SERVICE_RESPONSE_NONE = SupportsResponse.NONE
SERVICE_RESPONSE_OPTIONAL = SupportsResponse.OPTIONAL
SERVICE_RESPONSE_ONLY = SupportsResponse.ONLY
except ImportError:
SERVICE_RESPONSE_NONE = None
SERVICE_RESPONSE_OPTIONAL = None
SERVICE_RESPONSE_ONLY = None

DOMAIN = "pyscript"

CONFIG_ENTRY = "config_entry"
Expand Down
8 changes: 6 additions & 2 deletions custom_components/pyscript/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import yaml

from homeassistant.core import SupportsResponse
from homeassistant.const import SERVICE_RELOAD
from homeassistant.helpers.service import async_set_service_schema

Expand All @@ -26,6 +25,7 @@
DOMAIN,
LOGGER_PATH,
SERVICE_JUPYTER_KERNEL_START,
SERVICE_RESPONSE_NONE,
)
from .function import Function
from .state import State
Expand Down Expand Up @@ -505,7 +505,11 @@ async def do_service_call(func, ast_ctx, data):
if name in (SERVICE_RELOAD, SERVICE_JUPYTER_KERNEL_START):
raise SyntaxError(f"{exc_mesg}: @service conflicts with builtin service")
Function.service_register(
trig_ctx_name, domain, name, pyscript_service_factory(func_name, self), dec_kwargs.get("supports_response", SupportsResponse.NONE)
trig_ctx_name,
domain,
name,
pyscript_service_factory(func_name, self),
dec_kwargs.get("supports_response", SERVICE_RESPONSE_NONE),
)
async_set_service_schema(Function.hass, domain, name, service_desc)
self.trigger_service.add(srv_name)
Expand Down
56 changes: 36 additions & 20 deletions custom_components/pyscript/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import logging
import traceback

from homeassistant.core import Context, SupportsResponse
from homeassistant.core import Context

from .const import LOGGER_PATH
from .const import LOGGER_PATH, SERVICE_RESPONSE_NONE, SERVICE_RESPONSE_ONLY

_LOGGER = logging.getLogger(LOGGER_PATH + ".function")

Expand Down Expand Up @@ -332,14 +332,7 @@ async def service_call(cls, domain, name, **kwargs):
elif default:
hass_args[keyword] = default

if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
hass_args["blocking"] = True
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, name) == SupportsResponse.ONLY:
hass_args["return_response"] = True
if "blocking" not in hass_args:
hass_args["blocking"] = True

return await cls.hass.services.async_call(domain, name, kwargs, **hass_args)
return await cls.hass_services_async_call(domain, name, kwargs, **hass_args)

@classmethod
async def service_completions(cls, root):
Expand Down Expand Up @@ -413,19 +406,35 @@ async def service_call(*args, **kwargs):
if len(args) != 0:
raise TypeError(f"service {domain}.{service} takes only keyword arguments")

if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
hass_args["blocking"] = True
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY:
hass_args["return_response"] = True
if "blocking" not in hass_args:
hass_args["blocking"] = True

return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
return await cls.hass_services_async_call(domain, service, kwargs, **hass_args)

return service_call

return service_call_factory(domain, service)

@classmethod
async def hass_services_async_call(cls, domain, service, kwargs, **hass_args):
"""Call a hass async service."""
if SERVICE_RESPONSE_ONLY is None:
# backwards compatibility < 2023.7
await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
else:
# allow service responses >= 2023.7
if (
"return_response" in hass_args
and hass_args["return_response"]
and "blocking" not in hass_args
):
hass_args["blocking"] = True
elif (
"return_response" not in hass_args
and cls.hass.services.supports_response(domain, service) == SERVICE_RESPONSE_ONLY
):
hass_args["return_response"] = True
if "blocking" not in hass_args:
hass_args["blocking"] = True
return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)

@classmethod
async def run_coro(cls, coro, ast_ctx=None):
"""Run coroutine task and update unique task on start and exit."""
Expand Down Expand Up @@ -466,7 +475,9 @@ def create_task(cls, coro, ast_ctx=None):
return cls.hass.loop.create_task(cls.run_coro(coro, ast_ctx=ast_ctx))

@classmethod
def service_register(cls, global_ctx_name, domain, service, callback, supports_response = SupportsResponse.NONE):
def service_register(
cls, global_ctx_name, domain, service, callback, supports_response=SERVICE_RESPONSE_NONE
):
"""Register a new service callback."""
key = f"{domain}.{service}"
if key not in cls.service_cnt:
Expand All @@ -478,7 +489,12 @@ def service_register(cls, global_ctx_name, domain, service, callback, supports_r
f"{global_ctx_name}: can't register service {key}; already defined in {cls.service2global_ctx[key]}"
)
cls.service_cnt[key] += 1
cls.hass.services.async_register(domain, service, callback, supports_response = supports_response)
if SERVICE_RESPONSE_ONLY is None:
# backwards compatibility < 2023.7
cls.hass.services.async_register(domain, service, callback)
else:
# allow service responses >= 2023.7
cls.hass.services.async_register(domain, service, callback, supports_response=supports_response)

@classmethod
def service_remove(cls, global_ctx_name, domain, service):
Expand Down
10 changes: 2 additions & 8 deletions custom_components/pyscript/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import logging

from homeassistant.core import Context, SupportsResponse
from homeassistant.core import Context
from homeassistant.helpers.restore_state import DATA_RESTORE_STATE
from homeassistant.helpers.service import async_get_all_descriptions

Expand Down Expand Up @@ -308,13 +308,7 @@ async def service_call(*args, **kwargs):
elif len(args) != 0:
raise TypeError(f"service {domain}.{service} takes no positional arguments")

if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
hass_args["blocking"] = True
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY:
hass_args["return_response"] = True
if "blocking" not in hass_args:
hass_args["blocking"] = True

# return await Function.hass_services_async_call(domain, service, kwargs, **hass_args)
return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)

return service_call
Expand Down
6 changes: 6 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest

from custom_components.pyscript.function import Function
from custom_components.pyscript.state import State
from homeassistant.core import Context
from homeassistant.helpers.state import State as HassState
Expand All @@ -23,6 +24,7 @@ async def test_service_call(hass):
hass.services, "async_call"
) as call:
State.init(hass)
Function.init(hass)
await State.get_service_params()

func = State.get("test.entity.test")
Expand All @@ -45,3 +47,7 @@ async def test_service_call(hass):
{"other_service_data": "test", "entity_id": "test.entity"},
)
assert call.call_args[1] == {"context": Context(id="test"), "blocking": False}

# Stop all tasks to avoid conflicts with other tests
await Function.waiter_stop()
await Function.reaper_stop()

0 comments on commit bd45610

Please sign in to comment.