Skip to content

Commit

Permalink
Added support for service responses when calling or creating services.
Browse files Browse the repository at this point in the history
  • Loading branch information
matzman666 committed Jul 22, 2023
1 parent 716ffd7 commit 212cc50
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 10 deletions.
11 changes: 8 additions & 3 deletions custom_components/pyscript/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import yaml

from homeassistant.core import SupportsResponse
from homeassistant.const import SERVICE_RELOAD
from homeassistant.helpers.service import async_set_service_schema

Expand Down Expand Up @@ -377,6 +378,7 @@ async def trigger_init(self, trig_ctx, func_name):
"time_trigger": {"kwargs": {dict}},
"task_unique": {"kill_me": {bool, int}},
"time_active": {"hold_off": {int, float}},
"service": {"supports_response": {str}},
"state_trigger": {
"kwargs": {dict},
"state_hold": {int, float},
Expand Down Expand Up @@ -485,11 +487,14 @@ async def pyscript_service_handler(call):
func_args.update(call.data)

async def do_service_call(func, ast_ctx, data):
await func.call(ast_ctx, **data)
retval = await func.call(ast_ctx, **data)
if ast_ctx.get_exception_obj():
ast_ctx.get_logger().error(ast_ctx.get_exception_long())
return retval

Function.create_task(do_service_call(func, ast_ctx, func_args))
task = Function.create_task(do_service_call(func, ast_ctx, func_args))
await task
return task.result()

return pyscript_service_handler

Expand All @@ -500,7 +505,7 @@ async def do_service_call(func, ast_ctx, data):
if name in (SERVICE_RELOAD, SERVICE_JUPYTER_KERNEL_START):
raise SyntaxError(f"{exc_mesg}: @service conflicts with builtin service")
Function.service_register(
trig_ctx_name, domain, name, pyscript_service_factory(func_name, self)
trig_ctx_name, domain, name, pyscript_service_factory(func_name, self), dec_kwargs.get("supports_response", SupportsResponse.NONE)
)
async_set_service_schema(Function.hass, domain, name, service_desc)
self.trigger_service.add(srv_name)
Expand Down
26 changes: 21 additions & 5 deletions custom_components/pyscript/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import traceback

from homeassistant.core import Context
from homeassistant.core import Context, SupportsResponse

from .const import LOGGER_PATH

Expand Down Expand Up @@ -324,14 +324,22 @@ async def service_call(cls, domain, name, **kwargs):
for keyword, typ, default in [
("context", [Context], cls.task2context.get(curr_task, None)),
("blocking", [bool], None),
("return_response", [bool], None),
("limit", [float, int], None),
]:
if keyword in kwargs and type(kwargs[keyword]) in typ:
hass_args[keyword] = kwargs.pop(keyword)
elif default:
hass_args[keyword] = default

await cls.hass.services.async_call(domain, name, kwargs, **hass_args)
if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
hass_args["blocking"] = True
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, name) == SupportsResponse.ONLY:
hass_args["return_response"] = True
if "blocking" not in hass_args:
hass_args["blocking"] = True

return await cls.hass.services.async_call(domain, name, kwargs, **hass_args)

@classmethod
async def service_completions(cls, root):
Expand Down Expand Up @@ -394,6 +402,7 @@ async def service_call(*args, **kwargs):
for keyword, typ, default in [
("context", [Context], cls.task2context.get(curr_task, None)),
("blocking", [bool], None),
("return_response", [bool], None),
("limit", [float, int], None),
]:
if keyword in kwargs and type(kwargs[keyword]) in typ:
Expand All @@ -404,7 +413,14 @@ async def service_call(*args, **kwargs):
if len(args) != 0:
raise TypeError(f"service {domain}.{service} takes only keyword arguments")

await cls.hass.services.async_call(domain, service, kwargs, **hass_args)
if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
hass_args["blocking"] = True
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY:
hass_args["return_response"] = True
if "blocking" not in hass_args:
hass_args["blocking"] = True

return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)

return service_call

Expand Down Expand Up @@ -450,7 +466,7 @@ def create_task(cls, coro, ast_ctx=None):
return cls.hass.loop.create_task(cls.run_coro(coro, ast_ctx=ast_ctx))

@classmethod
def service_register(cls, global_ctx_name, domain, service, callback):
def service_register(cls, global_ctx_name, domain, service, callback, supports_response = SupportsResponse.NONE):
"""Register a new service callback."""
key = f"{domain}.{service}"
if key not in cls.service_cnt:
Expand All @@ -462,7 +478,7 @@ def service_register(cls, global_ctx_name, domain, service, callback):
f"{global_ctx_name}: can't register service {key}; already defined in {cls.service2global_ctx[key]}"
)
cls.service_cnt[key] += 1
cls.hass.services.async_register(domain, service, callback)
cls.hass.services.async_register(domain, service, callback, supports_response = supports_response)

@classmethod
def service_remove(cls, global_ctx_name, domain, service):
Expand Down
13 changes: 11 additions & 2 deletions custom_components/pyscript/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import logging

from homeassistant.core import Context
from homeassistant.core import Context, SupportsResponse
from homeassistant.helpers.restore_state import DATA_RESTORE_STATE
from homeassistant.helpers.service import async_get_all_descriptions

Expand Down Expand Up @@ -290,6 +290,7 @@ async def service_call(*args, **kwargs):
for keyword, typ, default in [
("context", [Context], Function.task2context.get(curr_task, None)),
("blocking", [bool], None),
("return_response", [bool], None),
("limit", [float, int], None),
]:
if keyword in kwargs and type(kwargs[keyword]) in typ:
Expand All @@ -306,7 +307,15 @@ async def service_call(*args, **kwargs):
kwargs[param_name] = args[0]
elif len(args) != 0:
raise TypeError(f"service {domain}.{service} takes no positional arguments")
await cls.hass.services.async_call(domain, service, kwargs, **hass_args)

if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args:
hass_args["blocking"] = True
elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY:
hass_args["return_response"] = True
if "blocking" not in hass_args:
hass_args["blocking"] = True

return await cls.hass.services.async_call(domain, service, kwargs, **hass_args)

return service_call

Expand Down

0 comments on commit 212cc50

Please sign in to comment.