diff --git a/src/blueapi/config.py b/src/blueapi/config.py index e99d02b15..1055eea8b 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -46,8 +46,8 @@ class TiledConfig(BaseModel): Config for connecting to a tiled instance """ - uri: str - api_key: str + host: str + port: int class WorkerEventConfig(BlueapiBaseModel): diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 620b8114a..162064c14 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -3,18 +3,17 @@ from functools import cache from typing import Any -from bluesky.callbacks.tiled_writer import TiledWriter from bluesky_stomp.messaging import StompClient from bluesky_stomp.models import Broker, DestinationBase, MessageTopic -from tiled.client import from_uri -from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig, TiledConfig +from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig from blueapi.core.context import BlueskyContext from blueapi.core.event import EventStream from blueapi.service.model import DeviceModel, PlanModel, WorkerTask from blueapi.worker.event import TaskStatusEnum, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TaskWorker, TrackableTask +from blueapi.worker.tiled import TiledConnection """This module provides interface between web application and underlying Bluesky context and worker""" @@ -42,27 +41,16 @@ def context() -> BlueskyContext: @cache def worker() -> TaskWorker: + conf = config() worker = TaskWorker( context(), - broadcast_statuses=config().env.events.broadcast_status_events, + broadcast_statuses=conf.env.events.broadcast_status_events, + tiled_inserter=TiledConnection(conf.tiled) if conf.tiled else None, ) worker.start() return worker -@cache -def tiled_inserter(): - tiled_config: TiledConfig | None = config().tiled - if tiled_config is not None: - client = from_uri(tiled_config.uri, api_key=tiled_config.api_key) - - ctx = context() - ctx.run_engine.subscribe(TiledWriter(client)) - return client - else: - return None - - @cache def stomp_client() -> StompClient | None: stomp_config: StompConfig | None = config().stomp @@ -101,7 +89,6 @@ def setup(config: ApplicationConfig) -> None: logging.basicConfig(format="%(asctime)s - %(message)s", level=config.logging.level) worker() stomp_client() - tiled_inserter() def teardown() -> None: diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index 546c5f3b5..67c22119c 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -9,6 +9,7 @@ from typing import Any, Generic, TypeVar from bluesky.protocols import Status +from httpx import Headers from observability_utils.tracing import ( add_span_attributes, get_tracer, @@ -32,6 +33,7 @@ from blueapi.core.bluesky_event_loop import configure_bluesky_event_loop from blueapi.utils.base_model import BlueapiBaseModel from blueapi.utils.thread_exception import handle_all_exceptions +from blueapi.worker.tiled import TiledConnection from .event import ( ProgressEvent, @@ -112,9 +114,11 @@ def __init__( ctx: BlueskyContext, start_stop_timeout: float = DEFAULT_START_STOP_TIMEOUT, broadcast_statuses: bool = True, + tiled_inserter: TiledConnection | None = None, ) -> None: self._ctx = ctx self._start_stop_timeout = start_stop_timeout + self._tiled_inserter = tiled_inserter self._tasks = {} @@ -194,13 +198,25 @@ def get_active_task(self) -> TrackableTask[Task] | None: return current @start_as_current_span(TRACER, "task_id") - def begin_task(self, task_id: str) -> None: + def begin_task(self, task_id: str, headers: Headers | None) -> None: task = self._tasks.get(task_id) + data_subs: list[int] = [] if task is not None: - self._submit_trackable_task(task) + if self._tiled_inserter: + data_subs.append(self._authorize_running_task(headers)) + self._submit_trackable_task(task, data_subs) + else: raise KeyError(f"No pending task with ID {task_id}") + def _authorize_running_task(self, headers: Headers | None) -> int: + assert self._tiled_inserter + # https://github.com/DiamondLightSource/blueapi/issues/774 + # If users should only be able to run their own scans, pass headers + # as part of submitting a task, cache in TrackableTask field and check + # that token belongs to same user (but may be newer token!) + return self.data_events.subscribe(self._tiled_inserter(headers)) + @start_as_current_span(TRACER, "task.name", "task.params") def submit_task(self, task: Task) -> str: task.prepare_params(self._ctx) # Will raise if parameters are invalid @@ -218,7 +234,9 @@ def submit_task(self, task: Task) -> str: "trackable_task.task.name", "trackable_task.task.params", ) - def _submit_trackable_task(self, trackable_task: TrackableTask) -> None: + def _submit_trackable_task( + self, trackable_task: TrackableTask, data_subs: list[int] | None = None + ) -> None: if self.state is not WorkerState.IDLE: raise WorkerBusyError(f"Worker is in state {self.state}") @@ -235,17 +253,18 @@ def mark_task_as_started(event: WorkerEvent, _: str | None) -> None: sub = self.worker_events.subscribe(mark_task_as_started) try: self._current_task_otel_context = get_current() - sub = self.worker_events.subscribe(mark_task_as_started) """ Cache the current trace context as the one for this task id """ self._task_channel.put_nowait(trackable_task) - task_started.wait(timeout=5.0) - if not task_started.is_set(): + if not task_started.wait(timeout=5.0): raise TimeoutError("Failed to start plan within timeout") except Full as f: LOGGER.error("Cannot submit task while another is running") raise WorkerBusyError("Cannot submit task while another is running") from f finally: self.worker_events.unsubscribe(sub) + if data_subs: + for data_sub in data_subs: + self.data_events.unsubscribe(data_sub) @start_as_current_span(TRACER) def start(self) -> None: diff --git a/src/blueapi/worker/tiled.py b/src/blueapi/worker/tiled.py new file mode 100644 index 000000000..816f0156a --- /dev/null +++ b/src/blueapi/worker/tiled.py @@ -0,0 +1,23 @@ +from bluesky.callbacks.tiled_writer import TiledWriter +from httpx import Headers +from tiled.client import from_context +from tiled.client.context import Context as TiledContext + +from blueapi.config import TiledConfig +from blueapi.core.bluesky_types import DataEvent + + +class TiledConverter: + def __init__(self, tiled_context: TiledContext): + self._writer: TiledWriter = TiledWriter(from_context(tiled_context)) + + def __call__(self, data: DataEvent, _: str | None = None) -> None: + self._writer(data.name, data.doc) + + +class TiledConnection: + def __init__(self, config: TiledConfig): + self.uri = f"{config.host}:{config.port}" + + def __call__(self, headers: Headers | None): + return TiledConverter(TiledContext(self.uri, headers=headers))