diff --git a/custom_components/pyscript/eval.py b/custom_components/pyscript/eval.py index 301ef3a..fd1f15f 100644 --- a/custom_components/pyscript/eval.py +++ b/custom_components/pyscript/eval.py @@ -15,6 +15,7 @@ import yaml +from homeassistant.core import SupportsResponse from homeassistant.const import SERVICE_RELOAD from homeassistant.helpers.service import async_set_service_schema @@ -377,6 +378,7 @@ async def trigger_init(self, trig_ctx, func_name): "time_trigger": {"kwargs": {dict}}, "task_unique": {"kill_me": {bool, int}}, "time_active": {"hold_off": {int, float}}, + "service": {"supports_response": {str}}, "state_trigger": { "kwargs": {dict}, "state_hold": {int, float}, @@ -485,11 +487,14 @@ async def pyscript_service_handler(call): func_args.update(call.data) async def do_service_call(func, ast_ctx, data): - await func.call(ast_ctx, **data) + retval = await func.call(ast_ctx, **data) if ast_ctx.get_exception_obj(): ast_ctx.get_logger().error(ast_ctx.get_exception_long()) + return retval - Function.create_task(do_service_call(func, ast_ctx, func_args)) + task = Function.create_task(do_service_call(func, ast_ctx, func_args)) + await task + return task.result() return pyscript_service_handler @@ -500,7 +505,7 @@ async def do_service_call(func, ast_ctx, data): if name in (SERVICE_RELOAD, SERVICE_JUPYTER_KERNEL_START): raise SyntaxError(f"{exc_mesg}: @service conflicts with builtin service") Function.service_register( - trig_ctx_name, domain, name, pyscript_service_factory(func_name, self) + trig_ctx_name, domain, name, pyscript_service_factory(func_name, self), dec_kwargs.get("supports_response", SupportsResponse.NONE) ) async_set_service_schema(Function.hass, domain, name, service_desc) self.trigger_service.add(srv_name) diff --git a/custom_components/pyscript/function.py b/custom_components/pyscript/function.py index 6bc0a53..54f49d5 100644 --- a/custom_components/pyscript/function.py +++ b/custom_components/pyscript/function.py @@ -4,7 +4,7 @@ import logging import traceback -from homeassistant.core import Context +from homeassistant.core import Context, SupportsResponse from .const import LOGGER_PATH @@ -324,6 +324,7 @@ async def service_call(cls, domain, name, **kwargs): for keyword, typ, default in [ ("context", [Context], cls.task2context.get(curr_task, None)), ("blocking", [bool], None), + ("return_response", [bool], None), ("limit", [float, int], None), ]: if keyword in kwargs and type(kwargs[keyword]) in typ: @@ -331,7 +332,14 @@ async def service_call(cls, domain, name, **kwargs): elif default: hass_args[keyword] = default - await cls.hass.services.async_call(domain, name, kwargs, **hass_args) + if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args: + hass_args["blocking"] = True + elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, name) == SupportsResponse.ONLY: + hass_args["return_response"] = True + if "blocking" not in hass_args: + hass_args["blocking"] = True + + return await cls.hass.services.async_call(domain, name, kwargs, **hass_args) @classmethod async def service_completions(cls, root): @@ -394,6 +402,7 @@ async def service_call(*args, **kwargs): for keyword, typ, default in [ ("context", [Context], cls.task2context.get(curr_task, None)), ("blocking", [bool], None), + ("return_response", [bool], None), ("limit", [float, int], None), ]: if keyword in kwargs and type(kwargs[keyword]) in typ: @@ -404,7 +413,14 @@ async def service_call(*args, **kwargs): if len(args) != 0: raise TypeError(f"service {domain}.{service} takes only keyword arguments") - await cls.hass.services.async_call(domain, service, kwargs, **hass_args) + if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args: + hass_args["blocking"] = True + elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY: + hass_args["return_response"] = True + if "blocking" not in hass_args: + hass_args["blocking"] = True + + return await cls.hass.services.async_call(domain, service, kwargs, **hass_args) return service_call @@ -450,7 +466,7 @@ def create_task(cls, coro, ast_ctx=None): return cls.hass.loop.create_task(cls.run_coro(coro, ast_ctx=ast_ctx)) @classmethod - def service_register(cls, global_ctx_name, domain, service, callback): + def service_register(cls, global_ctx_name, domain, service, callback, supports_response = SupportsResponse.NONE): """Register a new service callback.""" key = f"{domain}.{service}" if key not in cls.service_cnt: @@ -462,7 +478,7 @@ def service_register(cls, global_ctx_name, domain, service, callback): f"{global_ctx_name}: can't register service {key}; already defined in {cls.service2global_ctx[key]}" ) cls.service_cnt[key] += 1 - cls.hass.services.async_register(domain, service, callback) + cls.hass.services.async_register(domain, service, callback, supports_response = supports_response) @classmethod def service_remove(cls, global_ctx_name, domain, service): diff --git a/custom_components/pyscript/state.py b/custom_components/pyscript/state.py index 1986a8a..3c84d31 100644 --- a/custom_components/pyscript/state.py +++ b/custom_components/pyscript/state.py @@ -3,7 +3,7 @@ import asyncio import logging -from homeassistant.core import Context +from homeassistant.core import Context, SupportsResponse from homeassistant.helpers.restore_state import DATA_RESTORE_STATE from homeassistant.helpers.service import async_get_all_descriptions @@ -290,6 +290,7 @@ async def service_call(*args, **kwargs): for keyword, typ, default in [ ("context", [Context], Function.task2context.get(curr_task, None)), ("blocking", [bool], None), + ("return_response", [bool], None), ("limit", [float, int], None), ]: if keyword in kwargs and type(kwargs[keyword]) in typ: @@ -306,7 +307,15 @@ async def service_call(*args, **kwargs): kwargs[param_name] = args[0] elif len(args) != 0: raise TypeError(f"service {domain}.{service} takes no positional arguments") - await cls.hass.services.async_call(domain, service, kwargs, **hass_args) + + if "return_response" in hass_args and hass_args["return_response"] == True and "blocking" not in hass_args: + hass_args["blocking"] = True + elif "return_response" not in hass_args and cls.hass.services.supports_response(domain, service) == SupportsResponse.ONLY: + hass_args["return_response"] = True + if "blocking" not in hass_args: + hass_args["blocking"] = True + + return await cls.hass.services.async_call(domain, service, kwargs, **hass_args) return service_call