diff --git a/.cspell/frigate-dictionary.txt b/.cspell/frigate-dictionary.txt index 0cbcc4beb5..d5a2e3892b 100644 --- a/.cspell/frigate-dictionary.txt +++ b/.cspell/frigate-dictionary.txt @@ -42,6 +42,7 @@ codeproject colormap colorspace comms +coro ctypeslib CUDA Cuvid diff --git a/frigate/mypy.ini b/frigate/mypy.ini index d8f8493342..dd726f4549 100644 --- a/frigate/mypy.ini +++ b/frigate/mypy.ini @@ -59,3 +59,7 @@ ignore_errors = false [mypy-frigate.watchdog] ignore_errors = false disallow_untyped_calls = false + + +[mypy-frigate.service_manager.*] +ignore_errors = false diff --git a/frigate/service_manager/__init__.py b/frigate/service_manager/__init__.py new file mode 100644 index 0000000000..2da23b8b0a --- /dev/null +++ b/frigate/service_manager/__init__.py @@ -0,0 +1,4 @@ +from .multiprocessing import ServiceProcess +from .service import Service, ServiceManager + +__all__ = ["Service", "ServiceProcess", "ServiceManager"] diff --git a/frigate/service_manager/multiprocessing.py b/frigate/service_manager/multiprocessing.py new file mode 100644 index 0000000000..a716ac0b55 --- /dev/null +++ b/frigate/service_manager/multiprocessing.py @@ -0,0 +1,67 @@ +import asyncio +import multiprocessing as mp +from abc import ABC, abstractmethod +from asyncio.exceptions import TimeoutError +from typing import Optional + +from .multiprocessing_waiter import wait as mp_wait +from .service import Service, ServiceManager + +DEFAULT_STOP_TIMEOUT = 10 # seconds + + +class ServiceProcess(Service, ABC): + _process: mp.Process + + def __init__( + self, + name: Optional[str] = None, + manager: Optional[ServiceManager] = None, + ) -> None: + super().__init__(name=name, manager=manager) + + self._process_lock = asyncio.Lock() + + async def on_start(self) -> None: + async with self._process_lock: + if hasattr(self, "_process"): + if self._process.is_alive(): + self.manager.logger.debug( + "Attempted to start already running process" + f" {self.name} (pid: {self._process.pid})" + ) + return + else: + self._process.close() + + # At this point, the process is either stopped or dead, so we can recreate it. + self._process = mp.Process(name=self.name, target=self.run, daemon=True) + self._process.start() + self.manager.logger.info(f"Started {self.name} (pid: {self._process.pid})") + + async def on_stop(self, *, timeout: Optional[float] = None) -> None: + if timeout is None: + timeout = DEFAULT_STOP_TIMEOUT + + async with self._process_lock: + if not hasattr(self, "_process"): + return # Already stopped. + + self._process.terminate() + try: + await asyncio.wait_for(mp_wait(self._process), timeout) + except TimeoutError: + self.manager.logger.warning( + f"{self.name} is still running after " + f"{timeout} seconds. Killing." + ) + self._process.kill() + await mp_wait(self._process) + + del self._process + + self.manager.logger.info(f"{self.name} stopped") + + @abstractmethod + def run(self) -> None: + pass diff --git a/frigate/service_manager/multiprocessing_waiter.py b/frigate/service_manager/multiprocessing_waiter.py new file mode 100644 index 0000000000..8acdf583c7 --- /dev/null +++ b/frigate/service_manager/multiprocessing_waiter.py @@ -0,0 +1,150 @@ +import asyncio +import functools +import logging +import multiprocessing as mp +import queue +import threading +from multiprocessing.connection import Connection +from multiprocessing.connection import wait as mp_wait +from socket import socket +from typing import Any, Optional, Union + +logger = logging.getLogger(__name__) + + +class MultiprocessingWaiter(threading.Thread): + """A background thread that manages futures for the multiprocessing.connection.wait() method.""" + + def __init__(self) -> None: + super().__init__(daemon=True) + + # Queue of objects to wait for and futures to set results for. + self._queue: queue.Queue[tuple[Any, asyncio.Future[None]]] = queue.Queue() + + # This is required to get mp_wait() to wake up when new objects to wait for are received. + receive, send = mp.Pipe(duplex=False) + self._receive_connection = receive + self._send_connection = send + + def wait_for_sentinel(self, sentinel: Any) -> asyncio.Future[None]: + """Create an asyncio.Future tracking a sentinel for multiprocessing.connection.wait() + + Warning: This method is NOT thread-safe. + """ + # This would be incredibly stupid, but you never know. + assert sentinel != self._receive_connection + + # Send the future to the background thread for processing. + future = asyncio.get_running_loop().create_future() + self._queue.put((sentinel, future)) + + # Notify the background thread. + # + # This is the non-thread-safe part, but since this method is not really meant to be called + # by users, we can get away with not adding a lock at this point (to avoid adding 2 locks). + self._send_connection.send_bytes(b".") + + return future + + def run(self) -> None: + logger.debug("Started background thread") + + wait_dict: dict[Any, set[asyncio.Future[None]]] = { + self._receive_connection: set() + } + while True: + for ready_obj in mp_wait(wait_dict.keys()): + # Make sure we never remove the receive connection from the wait dict + if ready_obj is self._receive_connection: + continue + + logger.debug( + f"Sentinel {ready_obj!r} is ready. " + f"Notifying {len(wait_dict[ready_obj])} future(s)." + ) + + # Go over all the futures attached to this object and mark them as ready. + for fut in wait_dict.pop(ready_obj): + if fut.cancelled(): + logger.debug( + f"A future for sentinel {ready_obj!r} is ready, " + "but the future is cancelled. Skipping." + ) + else: + fut.get_loop().call_soon_threadsafe( + # Note: We need to check fut.cancelled() again, since it might + # have been set before the event loop's definition of "soon". + functools.partial( + lambda fut: fut.cancelled() or fut.set_result(None), fut + ) + ) + + # Check for cancellations in the remaining futures. + done_objects = [] + for obj, fut_set in wait_dict.items(): + if obj is self._receive_connection: + continue + + # Find any cancelled futures and remove them. + cancelled = [fut for fut in fut_set if fut.cancelled()] + fut_set.difference_update(cancelled) + logger.debug( + f"Removing {len(cancelled)} future(s) from sentinel: {obj!r}" + ) + + # Mark objects with no remaining futures for removal. + if len(fut_set) == 0: + done_objects.append(obj) + + # Remove any objects that are done after removing cancelled futures. + for obj in done_objects: + logger.debug( + f"Sentinel {obj!r} no longer has any futures waiting for it." + ) + del wait_dict[obj] + + # Get new objects to wait for from the queue. + while True: + try: + obj, fut = self._queue.get_nowait() + self._receive_connection.recv_bytes(maxlength=1) + self._queue.task_done() + + logger.debug(f"Received new sentinel: {obj!r}") + + wait_dict.setdefault(obj, set()).add(fut) + except queue.Empty: + break + + +waiter_lock = threading.Lock() +waiter_thread: Optional[MultiprocessingWaiter] = None + + +async def wait(object: Union[mp.Process, Connection, socket]) -> None: + """Wait for the supplied object to be ready. + + Under the hood, this uses multiprocessing.connection.wait() and a background thread manage the + returned futures. + """ + global waiter_thread, waiter_lock + + sentinel: Union[Connection, socket, int] + if isinstance(object, mp.Process): + sentinel = object.sentinel + elif isinstance(object, Connection) or isinstance(object, socket): + sentinel = object + else: + raise ValueError(f"Cannot wait for object of type {type(object).__qualname__}") + + with waiter_lock: + if waiter_thread is None: + # Start a new waiter thread. + waiter_thread = MultiprocessingWaiter() + waiter_thread.start() + + # Create the future while still holding the lock, + # since wait_for_sentinel() is not thread safe. + fut = waiter_thread.wait_for_sentinel(sentinel) + + await fut diff --git a/frigate/service_manager/service.py b/frigate/service_manager/service.py new file mode 100644 index 0000000000..eb11a5b652 --- /dev/null +++ b/frigate/service_manager/service.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import asyncio +import atexit +import logging +import threading +from abc import ABC, abstractmethod +from contextvars import ContextVar +from dataclasses import dataclass +from types import TracebackType +from typing import Coroutine, Optional, Union + +from typing_extensions import Self + + +@dataclass +class Command: + coro: Coroutine + done: Optional[threading.Event] = None + + +class Service(ABC): + def __init__( + self, + *, + name: Optional[str] = None, + manager: Optional[ServiceManager] = None, + ): + self.__name = name or type(self).__qualname__ + + self.__manager = manager or ServiceManager.current() + self.__manager.register(self) + + @property + def name(self) -> str: + return self.__name + + @property + def manager(self) -> ServiceManager: + return self.__manager + + def start( + self, + *, + wait: bool = False, + wait_timeout: Optional[float] = None, + ) -> None: + self.manager.run_task(self.on_start(), wait=wait, wait_timeout=wait_timeout) + + @abstractmethod + async def on_start(self) -> None: + pass + + def stop( + self, + *, + timeout: Optional[float] = None, + wait: bool = False, + wait_timeout: Optional[float] = None, + ) -> None: + self.manager.run_task( + self.on_stop(timeout=timeout), + wait=wait, + wait_timeout=wait_timeout, + ) + + @abstractmethod + async def on_stop(self, *, timeout: Optional[float] = None) -> None: + pass + + def restart( + self, + wait: bool = False, + wait_timeout: Optional[float] = None, + ) -> None: + self.manager.run_task( + self.on_restart(), + wait=wait, + wait_timeout=wait_timeout, + ) + + async def on_restart(self) -> None: + await self.on_stop() + await self.on_start() + + +current_service_manager: ContextVar[ServiceManager] = ContextVar( + "current_service_manager" +) + + +class ServiceManager: + _name: str + _services: dict[str, Service] + _services_lock: threading.Lock + + _command_queue: asyncio.Queue + _event_loop: asyncio.AbstractEventLoop + _setup_event: threading.Event + + def __init__(self, *, name: Optional[str] = None): + self._name = name if name is not None else (__package__ or __name__) + self.logger = logging.getLogger(self.name) + + # The set of registered services. + self._services = dict() + self._services_lock = threading.Lock() + + # --- Start the manager thread and wait for it to be ready. --- + + self._setup_event = threading.Event() + + async def start_manager() -> None: + self._event_loop = asyncio.get_running_loop() + self._command_queue = asyncio.Queue() + + self._setup_event.set() + await self._run_manager() + + self._manager_thread = threading.Thread( + name=self.name, + target=lambda: asyncio.run(start_manager()), + daemon=True, + ) + + self._manager_thread.start() + atexit.register(self.shutdown) + + def run_task( + self, + coro: Coroutine, + *, + wait: bool = False, + wait_timeout: Optional[float] = None, + ) -> None: + """Run an async task in the background thread.""" + + if not isinstance(coro, Coroutine): + raise ValueError(f"Cannot schedule task for object of type {type(coro)}") + + cmd = Command(coro=coro) + if wait or wait_timeout is not None: + cmd.done = threading.Event() + + self._send_command(cmd) + + if cmd.done is not None: + cmd.done.wait(timeout=wait_timeout) + + def register(self, service: Service) -> None: + self._ensure_running() + with self._services_lock: + name_conflict: Optional[Service] = next( + ( + existing + for name, existing in self._services.items() + if name == service.name + ), + None, + ) + + if name_conflict is service: + raise RuntimeError(f"Attempt to re-register service: {service.name}") + elif name_conflict is not None: + raise RuntimeError(f"Duplicate service name: {service.name}") + + self.logger.debug(f"Registering service: {service.name}") + self._services[service.name] = service + + def shutdown(self) -> None: + """Shutdown the service manager.""" + + self._send_command(None) + self._manager_thread.join() + + def _ensure_running(self) -> None: + self._setup_event.wait() + if not self._manager_thread.is_alive(): + raise RuntimeError(f"ServiceManager {self.name} is not running") + + def _send_command(self, command: Union[Command, None]) -> None: + self._ensure_running() + asyncio.run_coroutine_threadsafe( + self._command_queue.put(command), self._event_loop + ) + + def __enter__(self) -> Self: + self._context_token = current_service_manager.set(self) + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_info: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + current_service_manager.reset(self._context_token) + del self._context_token + + @classmethod + def current(cls) -> ServiceManager: + current = current_service_manager.get(None) + if current is None: + current = cls() + current_service_manager.set(current) + return current + + @property + def name(self) -> str: + return self._name + + async def _run_manager(self) -> None: + self.logger.info("Started service manager") + + tasks = set() + + def run_command(command: Command) -> None: + def task_done(task: asyncio.Task) -> None: + exc = task.exception() + if exc: + self.logger.exception( + "Exception in service manager task", exc_info=exc + ) + tasks.discard(task) + if command.done is not None: + command.done.set() + + task = asyncio.create_task(command.coro) + tasks.add(task) + task.add_done_callback(task_done) + + # Main command processing loop. + while (command := await self._command_queue.get()) is not None: + run_command(command) + + # Stop all services. + with self._services_lock: + self.logger.debug(f"Stopping {len(self._services)} services") + for service in self._services.values(): + run_command(Command(coro=service.on_stop())) + + # Wait for any pending tasks. + if tasks: + self.logger.debug(f"Waiting for {len(tasks)} tasks to finish") + done, pending = await asyncio.wait(tasks) + if len(pending) > 0: + self.logger.warning(f"{len(pending)} tasks did not finish on shutdown") + + self.logger.info("Exiting service manager")