From 26b12371e1eb7af9b01e2ec14528b86c4c441a0b Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 12 Oct 2023 18:23:51 +0200 Subject: [PATCH] Replace Ypy with pycrdt --- plugins/yjs/fps_yjs/routes.py | 22 +- plugins/yjs/fps_yjs/ydocs/__init__.py | 9 + plugins/yjs/fps_yjs/ydocs/utils.py | 26 + plugins/yjs/fps_yjs/ydocs/ybasedoc.py | 69 +++ plugins/yjs/fps_yjs/ydocs/yblob.py | 39 ++ plugins/yjs/fps_yjs/ydocs/yfile.py | 5 + plugins/yjs/fps_yjs/ydocs/ynotebook.py | 144 ++++++ plugins/yjs/fps_yjs/ydocs/yunicode.py | 33 ++ plugins/yjs/fps_yjs/ywebsocket/__init__.py | 4 + plugins/yjs/fps_yjs/ywebsocket/asgi_server.py | 92 ++++ plugins/yjs/fps_yjs/ywebsocket/awareness.py | 65 +++ .../ywebsocket/django_channels_consumer.py | 196 ++++++++ plugins/yjs/fps_yjs/ywebsocket/websocket.py | 58 +++ .../fps_yjs/ywebsocket/websocket_provider.py | 139 ++++++ .../fps_yjs/ywebsocket/websocket_server.py | 204 ++++++++ plugins/yjs/fps_yjs/ywebsocket/yroom.py | 235 +++++++++ plugins/yjs/fps_yjs/ywebsocket/ystore.py | 447 ++++++++++++++++++ plugins/yjs/fps_yjs/ywebsocket/yutils.py | 155 ++++++ plugins/yjs/pyproject.toml | 10 +- tests/test_server.py | 8 +- 20 files changed, 1942 insertions(+), 18 deletions(-) create mode 100644 plugins/yjs/fps_yjs/ydocs/__init__.py create mode 100644 plugins/yjs/fps_yjs/ydocs/utils.py create mode 100644 plugins/yjs/fps_yjs/ydocs/ybasedoc.py create mode 100644 plugins/yjs/fps_yjs/ydocs/yblob.py create mode 100644 plugins/yjs/fps_yjs/ydocs/yfile.py create mode 100644 plugins/yjs/fps_yjs/ydocs/ynotebook.py create mode 100644 plugins/yjs/fps_yjs/ydocs/yunicode.py create mode 100644 plugins/yjs/fps_yjs/ywebsocket/__init__.py create mode 100644 plugins/yjs/fps_yjs/ywebsocket/asgi_server.py create mode 100644 plugins/yjs/fps_yjs/ywebsocket/awareness.py create mode 100644 plugins/yjs/fps_yjs/ywebsocket/django_channels_consumer.py create mode 100644 plugins/yjs/fps_yjs/ywebsocket/websocket.py create mode 100644 plugins/yjs/fps_yjs/ywebsocket/websocket_provider.py create mode 100644 plugins/yjs/fps_yjs/ywebsocket/websocket_server.py create mode 100644 plugins/yjs/fps_yjs/ywebsocket/yroom.py create mode 100644 plugins/yjs/fps_yjs/ywebsocket/ystore.py create mode 100644 plugins/yjs/fps_yjs/ywebsocket/yutils.py diff --git a/plugins/yjs/fps_yjs/routes.py b/plugins/yjs/fps_yjs/routes.py index 148c5045..ddc373c0 100644 --- a/plugins/yjs/fps_yjs/routes.py +++ b/plugins/yjs/fps_yjs/routes.py @@ -14,17 +14,17 @@ WebSocketDisconnect, status, ) -from jupyter_ydoc import ydocs as YDOCS -from jupyter_ydoc.ybasedoc import YBaseDoc +from .ydocs import ydocs as YDOCS +from .ydocs.ybasedoc import YBaseDoc from jupyverse_api.app import App from jupyverse_api.auth import Auth, User from jupyverse_api.contents import Contents from jupyverse_api.yjs import Yjs from jupyverse_api.yjs.models import CreateDocumentSession from websockets.exceptions import ConnectionClosedOK -from ypy_websocket.websocket_server import WebsocketServer, YRoom -from ypy_websocket.ystore import SQLiteYStore, YDocNotFound -from ypy_websocket.yutils import YMessageType, YSyncMessageType +from .ywebsocket.websocket_server import WebsocketServer, YRoom +from .ywebsocket.ystore import SQLiteYStore, YDocNotFound +from .ywebsocket.yutils import YMessageType, YSyncMessageType YFILE = YDOCS["file"] AWARENESS = 1 @@ -56,8 +56,8 @@ async def collaboration_room_websocket( return websocket, permissions = websocket_permissions await websocket.accept() - ypy_websocket = YpyWebsocket(websocket, path) - await self.room_manager.serve(ypy_websocket, permissions) + ywebsocket = YWebsocket(websocket, path) + await self.room_manager.serve(ywebsocket, permissions) async def create_roomid( self, @@ -95,8 +95,8 @@ def to_datetime(iso_date: str) -> datetime: return datetime.fromisoformat(iso_date.rstrip("Z")) -class YpyWebsocket: - """An wrapper to make a Starlette's WebSocket look like a ypy-websocket's WebSocket""" +class YWebsocket: + """An wrapper to make a Starlette's WebSocket look like a ywebsocket's WebSocket""" def __init__(self, websocket, path: str): self._websocket = websocket @@ -160,7 +160,7 @@ def stop(self): cleaner.cancel() self.websocket_server.stop() - async def serve(self, websocket: YpyWebsocket, permissions) -> None: + async def serve(self, websocket: YWebsocket, permissions) -> None: room = await self.websocket_server.get_room(websocket.path) can_write = permissions is None or "write" in permissions.get("yjs", []) room.on_message = partial(self.filter_message, can_write) @@ -309,7 +309,7 @@ async def maybe_save_document( # if the room cannot be found, don't save try: file_path = await self.get_file_path(file_id, document) - except BaseException: + except Exception: return assert file_path is not None async with self.lock: diff --git a/plugins/yjs/fps_yjs/ydocs/__init__.py b/plugins/yjs/fps_yjs/ydocs/__init__.py new file mode 100644 index 00000000..0704bb37 --- /dev/null +++ b/plugins/yjs/fps_yjs/ydocs/__init__.py @@ -0,0 +1,9 @@ +import sys + + +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points + +ydocs = {ep.name: ep.load() for ep in entry_points(group="jupyverse_ydoc")} diff --git a/plugins/yjs/fps_yjs/ydocs/utils.py b/plugins/yjs/fps_yjs/ydocs/utils.py new file mode 100644 index 00000000..c8e82d9b --- /dev/null +++ b/plugins/yjs/fps_yjs/ydocs/utils.py @@ -0,0 +1,26 @@ +from typing import Dict, List, Type, Union + +INT = Type[int] +FLOAT = Type[float] + + +def cast_all( + o: Union[List, Dict], from_type: Union[INT, FLOAT], to_type: Union[FLOAT, INT] +) -> Union[List, Dict]: + if isinstance(o, list): + for i, v in enumerate(o): + if type(v) is from_type: + v2 = to_type(v) + if v == v2: + o[i] = v2 + elif isinstance(v, (list, dict)): + cast_all(v, from_type, to_type) + elif isinstance(o, dict): + for k, v in o.items(): + if type(v) is from_type: + v2 = to_type(v) + if v == v2: + o[k] = v2 + elif isinstance(v, (list, dict)): + cast_all(v, from_type, to_type) + return o diff --git a/plugins/yjs/fps_yjs/ydocs/ybasedoc.py b/plugins/yjs/fps_yjs/ydocs/ybasedoc.py new file mode 100644 index 00000000..8b68dad4 --- /dev/null +++ b/plugins/yjs/fps_yjs/ydocs/ybasedoc.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Optional + +from pycrdt import Doc, Map + + +class YBaseDoc(ABC): + def __init__(self, ydoc: Optional[Doc] = None): + if ydoc is None: + self._ydoc = Doc() + else: + self._ydoc = ydoc + self._ystate = Map() + self._ydoc["state"] = self._ystate + self._subscriptions: Dict[Any, str] = {} + + @property + @abstractmethod + def version(self) -> str: + ... + + @property + def ystate(self) -> Map: + return self._ystate + + @property + def ydoc(self) -> Doc: + return self._ydoc + + @property + def source(self) -> Any: + return self.get() + + @source.setter + def source(self, value: Any): + return self.set(value) + + @property + def dirty(self) -> Optional[bool]: + return self._ystate.get("dirty") + + @dirty.setter + def dirty(self, value: bool) -> None: + self._ystate["dirty"] = value + + @property + def path(self) -> Optional[str]: + return self._ystate.get("path") + + @path.setter + def path(self, value: str) -> None: + self._ystate["path"] = value + + @abstractmethod + def get(self) -> Any: + ... + + @abstractmethod + def set(self, value: Any) -> None: + ... + + @abstractmethod + def observe(self, callback: Callable[[str, Any], None]) -> None: + ... + + def unobserve(self) -> None: + for k, v in self._subscriptions.items(): + k.unobserve(v) + self._subscriptions = {} diff --git a/plugins/yjs/fps_yjs/ydocs/yblob.py b/plugins/yjs/fps_yjs/ydocs/yblob.py new file mode 100644 index 00000000..74813fa2 --- /dev/null +++ b/plugins/yjs/fps_yjs/ydocs/yblob.py @@ -0,0 +1,39 @@ +import base64 +from functools import partial +from typing import Any, Callable, Optional, Union + +from pycrdt import Doc, Map + +from .ybasedoc import YBaseDoc + + +class YBlob(YBaseDoc): + """ + Extends :class:`YBaseDoc`, and represents a blob document. + It is currently encoded as base64 because of: + https://github.com/y-crdt/ypy/issues/108#issuecomment-1377055465 + The Y document can be set from bytes or from str, in which case it is assumed to be encoded as + base64. + """ + + def __init__(self, ydoc: Optional[Doc] = None): + super().__init__(ydoc) + self._ysource = Map() + self._ydoc["source"] = self._ysource + + @property + def version(self) -> str: + return "1.0.0" + + def get(self) -> bytes: + return base64.b64decode(self._ysource["base64"].encode()) + + def set(self, value: Union[bytes, str]) -> None: + if isinstance(value, bytes): + value = base64.b64encode(value).decode() + self._ysource["base64"] = value + + def observe(self, callback: Callable[[str, Any], None]) -> None: + self.unobserve() + self._subscriptions[self._ystate] = self._ystate.observe(partial(callback, "state")) + self._subscriptions[self._ysource] = self._ysource.observe(partial(callback, "source")) diff --git a/plugins/yjs/fps_yjs/ydocs/yfile.py b/plugins/yjs/fps_yjs/ydocs/yfile.py new file mode 100644 index 00000000..5f102de6 --- /dev/null +++ b/plugins/yjs/fps_yjs/ydocs/yfile.py @@ -0,0 +1,5 @@ +from .yunicode import YUnicode + + +class YFile(YUnicode): # for backwards-compatibility + pass diff --git a/plugins/yjs/fps_yjs/ydocs/ynotebook.py b/plugins/yjs/fps_yjs/ydocs/ynotebook.py new file mode 100644 index 00000000..6316ba19 --- /dev/null +++ b/plugins/yjs/fps_yjs/ydocs/ynotebook.py @@ -0,0 +1,144 @@ +import copy +import json +from functools import partial +from typing import Any, Callable, Dict, Optional +from uuid import uuid4 + +from pycrdt import Array, Doc, Map, Text + +from .utils import cast_all +from .ybasedoc import YBaseDoc + +# The default major version of the notebook format. +NBFORMAT_MAJOR_VERSION = 4 +# The default minor version of the notebook format. +NBFORMAT_MINOR_VERSION = 5 + + +class YNotebook(YBaseDoc): + def __init__(self, ydoc: Optional[Doc] = None): + super().__init__(ydoc) + self._ymeta = Map() + self._ycells = Array() + self._ydoc["meta"] = self._ymeta + self._ydoc["cells"] = self._ycells + + @property + def version(self) -> str: + return "1.0.0" + + @property + def ycells(self): + return self._ycells + + @property + def cell_number(self) -> int: + return len(self._ycells) + + def get_cell(self, index: int) -> Dict[str, Any]: + meta = json.loads(str(self._ymeta)) + cell = json.loads(str(self._ycells[index])) + cast_all(cell, float, int) # cells coming from Yjs have e.g. execution_count as float + if "id" in cell and meta["nbformat"] == 4 and meta["nbformat_minor"] <= 4: + # strip cell IDs if we have notebook format 4.0-4.4 + del cell["id"] + if ( + "attachments" in cell + and cell["cell_type"] in ("raw", "markdown") + and not cell["attachments"] + ): + del cell["attachments"] + return cell + + def append_cell(self, value: Dict[str, Any]) -> None: + ycell = self.create_ycell(value) + self._ycells.append(ycell) + + def set_cell(self, index: int, value: Dict[str, Any]) -> None: + ycell = self.create_ycell(value) + self.set_ycell(index, ycell) + + def create_ycell(self, value: Dict[str, Any]) -> Map: + cell = copy.deepcopy(value) + if "id" not in cell: + cell["id"] = str(uuid4()) + cell_type = cell["cell_type"] + cell_source = cell["source"] + cell_source = "".join(cell_source) if isinstance(cell_source, list) else cell_source + cell["source"] = Text(cell_source) + cell["metadata"] = Map(cell.get("metadata", {})) + + if cell_type in ("raw", "markdown"): + if "attachments" in cell and not cell["attachments"]: + del cell["attachments"] + elif cell_type == "code": + cell["outputs"] = Array(cell.get("outputs", [])) + + return Map(cell) + + def set_ycell(self, index: int, ycell: Map) -> None: + self._ycells[index] = ycell + + def get(self) -> Dict: + meta = json.loads(str(self._ymeta)) + cast_all(meta, float, int) # notebook coming from Yjs has e.g. nbformat as float + cells = [] + for i in range(len(self._ycells)): + cell = self.get_cell(i) + if "id" in cell and meta["nbformat"] == 4 and meta["nbformat_minor"] <= 4: + # strip cell IDs if we have notebook format 4.0-4.4 + del cell["id"] + if ( + "attachments" in cell + and cell["cell_type"] in ["raw", "markdown"] + and not cell["attachments"] + ): + del cell["attachments"] + cells.append(cell) + + return dict( + cells=cells, + metadata=meta.get("metadata", {}), + nbformat=int(meta.get("nbformat", 0)), + nbformat_minor=int(meta.get("nbformat_minor", 0)), + ) + + def set(self, value: Dict) -> None: + nb_without_cells = {key: value[key] for key in value.keys() if key != "cells"} + nb = copy.deepcopy(nb_without_cells) + cast_all(nb, int, float) # Yjs expects numbers to be floating numbers + cells = value["cells"] or [ + { + "cell_type": "code", + "execution_count": None, + # auto-created empty code cell without outputs ought be trusted + "metadata": {"trusted": True}, + "outputs": [], + "source": "", + "id": str(uuid4()), + } + ] + + with self._ydoc.transaction(): + # clear document + self._ymeta.clear() + self._ycells.clear() + for key in [k for k in self._ystate.keys() if k not in ("dirty", "path")]: + del self._ystate[key] + + # initialize document + self._ycells.extend([self.create_ycell(cell) for cell in cells]) + self._ymeta["nbformat"] = nb.get("nbformat", NBFORMAT_MAJOR_VERSION) + self._ymeta["nbformat_minor"] = nb.get("nbformat_minor", NBFORMAT_MINOR_VERSION) + + metadata = nb.get("metadata", {}) + metadata.setdefault("language_info", {"name": ""}) + metadata.setdefault("kernelspec", {"name": "", "display_name": ""}) + + self._ymeta["metadata"] = Map(metadata) + + def observe(self, callback: Callable[[str, Any], None]) -> None: + self.unobserve() + self._subscriptions[self._ystate] = self._ystate.observe(partial(callback, "state")) + self._subscriptions[self._ymeta] = self._ymeta.observe_deep(partial(callback, "meta")) + self._subscriptions[self._ycells] = self._ycells.observe_deep(partial(callback, "cells")) diff --git a/plugins/yjs/fps_yjs/ydocs/yunicode.py b/plugins/yjs/fps_yjs/ydocs/yunicode.py new file mode 100644 index 00000000..5b83a1e6 --- /dev/null +++ b/plugins/yjs/fps_yjs/ydocs/yunicode.py @@ -0,0 +1,33 @@ +from functools import partial +from typing import Any, Callable, Optional + +from pycrdt import Doc, Text + +from .ybasedoc import YBaseDoc + + +class YUnicode(YBaseDoc): + def __init__(self, ydoc: Optional[Doc] = None): + super().__init__(ydoc) + self._ysource = Text() + self._ydoc["source"] = self._ysource + + @property + def version(self) -> str: + return "1.0.0" + + def get(self) -> str: + return str(self._ysource) + + def set(self, value: str) -> None: + with self._ydoc.transaction(): + # clear document + del self._ysource[:] + # initialize document + if value: + self._ysource += value + + def observe(self, callback: Callable[[str, Any], None]) -> None: + self.unobserve() + self._subscriptions[self._ystate] = self._ystate.observe(partial(callback, "state")) + self._subscriptions[self._ysource] = self._ysource.observe(partial(callback, "source")) diff --git a/plugins/yjs/fps_yjs/ywebsocket/__init__.py b/plugins/yjs/fps_yjs/ywebsocket/__init__.py new file mode 100644 index 00000000..bbb4f37f --- /dev/null +++ b/plugins/yjs/fps_yjs/ywebsocket/__init__.py @@ -0,0 +1,4 @@ +from .asgi_server import ASGIServer as ASGIServer +from .websocket_provider import WebsocketProvider as WebsocketProvider +from .websocket_server import WebsocketServer as WebsocketServer, YRoom as YRoom +from .yutils import YMessageType as YMessageType diff --git a/plugins/yjs/fps_yjs/ywebsocket/asgi_server.py b/plugins/yjs/fps_yjs/ywebsocket/asgi_server.py new file mode 100644 index 00000000..cff64d22 --- /dev/null +++ b/plugins/yjs/fps_yjs/ywebsocket/asgi_server.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from inspect import isawaitable +from typing import Any, Awaitable, Callable + +from .websocket_server import WebsocketServer + + +class ASGIWebsocket: + def __init__( + self, + receive: Callable[[], Awaitable[dict[str, Any]]], + send: Callable[[dict[str, Any]], Awaitable[None]], + path: str, + on_disconnect: Callable[[dict[str, Any]], Awaitable[None] | None] | None = None, + ): + self._receive = receive + self._send = send + self._path = path + self._on_disconnect = on_disconnect + + @property + def path(self) -> str: + return self._path + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + return await self.recv() + + async def send(self, message: bytes) -> None: + await self._send( + dict( + type="websocket.send", + bytes=message, + ) + ) + + async def recv(self) -> bytes: + message = await self._receive() + if message["type"] == "websocket.receive": + return message["bytes"] + if message["type"] == "websocket.disconnect": + if self._on_disconnect is not None: + res = self._on_disconnect(message) + if isawaitable(res): + await res + raise StopAsyncIteration() + return b"" + + +class ASGIServer: + """ASGI server.""" + + def __init__( + self, + websocket_server: WebsocketServer, + on_connect: Callable[[dict[str, Any], dict[str, Any]], Awaitable[bool] | bool] + | None = None, + on_disconnect: Callable[[dict[str, Any]], Awaitable[None] | None] | None = None, + ): + """Initialize the object. + + Arguments: + websocket_server: An instance of WebsocketServer. + on_connect: An optional callback to call when connecting the WebSocket. + If the callback returns True, the WebSocket is not accepted. + on_disconnect: An optional callback called when disconnecting the WebSocket. + """ + self._websocket_server = websocket_server + self._on_connect = on_connect + self._on_disconnect = on_disconnect + + async def __call__( + self, + scope: dict[str, Any], + receive: Callable[[], Awaitable[dict[str, Any]]], + send: Callable[[dict[str, Any]], Awaitable[None]], + ): + msg = await receive() + if msg["type"] == "websocket.connect": + if self._on_connect is not None: + close = self._on_connect(msg, scope) + if isawaitable(close): + close = await close + if close: + return + + await send({"type": "websocket.accept"}) + websocket = ASGIWebsocket(receive, send, scope["path"], self._on_disconnect) + await self._websocket_server.serve(websocket) diff --git a/plugins/yjs/fps_yjs/ywebsocket/awareness.py b/plugins/yjs/fps_yjs/ywebsocket/awareness.py new file mode 100644 index 00000000..4b8d542a --- /dev/null +++ b/plugins/yjs/fps_yjs/ywebsocket/awareness.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import json +import time +from typing import Any + +from .yutils import Decoder, read_message + + +class Awareness: + def __init__(self, ydoc): + self.client_id = ydoc.client_id + self.meta = {} + self.states = {} + + def get_changes(self, message: bytes) -> dict[str, Any]: + message = read_message(message) + decoder = Decoder(message) + timestamp = int(time.time() * 1000) + added = [] + updated = [] + filtered_updated = [] + removed = [] + states = [] + length = decoder.read_var_uint() + for _ in range(length): + client_id = decoder.read_var_uint() + clock = decoder.read_var_uint() + state_str = decoder.read_var_string() + state = None if not state_str else json.loads(state_str) + if state is not None: + states.append(state) + client_meta = self.meta.get(client_id) + prev_state = self.states.get(client_id) + curr_clock = 0 if client_meta is None else client_meta["clock"] + if curr_clock < clock or ( + curr_clock == clock and state is None and client_id in self.states + ): + if state is None: + if client_id == self.client_id and self.states.get(client_id) is not None: + clock += 1 + else: + if client_id in self.states: + del self.states[client_id] + else: + self.states[client_id] = state + self.meta[client_id] = { + "clock": clock, + "last_updated": timestamp, + } + if client_meta is None and state is not None: + added.append(client_id) + elif client_meta is not None and state is None: + removed.append(client_id) + elif state is not None: + if state != prev_state: + filtered_updated.append(client_id) + updated.append(client_id) + return { + "added": added, + "updated": updated, + "filtered_updated": filtered_updated, + "removed": removed, + "states": states, + } diff --git a/plugins/yjs/fps_yjs/ywebsocket/django_channels_consumer.py b/plugins/yjs/fps_yjs/ywebsocket/django_channels_consumer.py new file mode 100644 index 00000000..2942db7d --- /dev/null +++ b/plugins/yjs/fps_yjs/ywebsocket/django_channels_consumer.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from logging import getLogger +from typing import TypedDict + +from pycrdt import Doc +from channels.generic.websocket import AsyncWebsocketConsumer # type: ignore + +from .websocket import Websocket +from .yutils import YMessageType, process_sync_message, sync + +logger = getLogger(__name__) + + +class _WebsocketShim(Websocket): + def __init__(self, path, send_func) -> None: + self._path = path + self._send_func = send_func + + @property + def path(self) -> str: + return self._path + + def __aiter__(self): + raise NotImplementedError() + + async def __anext__(self) -> bytes: + raise NotImplementedError() + + async def send(self, message: bytes) -> None: + await self._send_func(message) + + async def recv(self) -> bytes: + raise NotImplementedError() + + +class YjsConsumer(AsyncWebsocketConsumer): + """A working consumer for [Django Channels](https://github.com/django/channels). + + This consumer can be used out of the box simply by adding: + ```py + path("ws/", YjsConsumer.as_asgi()) + ``` + to your `urls.py` file. In practice, once you + [set up Channels](https://channels.readthedocs.io/en/1.x/getting-started.html), + you might have something like: + ```py + # urls.py + from django.urls import path + from backend.consumer import DocConsumer, UpdateConsumer + + urlpatterns = [ + path("ws/", YjsConsumer.as_asgi()), + ] + + # asgi.py + import os + from channels.routing import ProtocolTypeRouter, URLRouter + from urls import urlpatterns + + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings") + + application = ProtocolTypeRouter({ + "websocket": URLRouter(urlpatterns_ws), + }) + ``` + + Additionally, the consumer can be subclassed to customize its behavior. + + In particular, + + - Override `make_room_name` to customize the room name. + - Override `make_ydoc` to initialize the YDoc. This is useful to initialize it with data + from your database, or to add observers to it). + - Override `connect` to do custom validation (like auth) on connect, + but be sure to call `await super().connect()` in the end. + - Call `group_send_message` to send a message to an entire group/room. + - Call `send_message` to send a message to a single client, although this is not recommended. + + A full example of a custom consumer showcasing all of these options is: + ```py + from pycrdt import Doc + from asgiref.sync import async_to_sync + from channels.layers import get_channel_layer + from ypy_websocket.django_channels_consumer import YjsConsumer + from ypy_websocket.yutils import create_update_message + + + class DocConsumer(YjsConsumer): + def make_room_name(self) -> str: + # modify the room name here + return self.scope["url_route"]["kwargs"]["room"] + + async def make_ydoc(self) -> Doc: + doc = Doc() + # fill doc with data from DB here + doc.observe(self.on_update_event) + return doc + + async def connect(self): + user = self.scope["user"] + if user is None or user.is_anonymous: + await self.close() + return + await super().connect() + + def on_update_event(self, event): + # process event here + ... + + async def doc_update(self, update_wrapper): + update = update_wrapper["update"] + self.ydoc.apply_update(update) + await self.group_send_message(create_update_message(update)) + + + def send_doc_update(room_name, update): + layer = get_channel_layer() + async_to_sync(layer.group_send)(room_name, {"type": "doc_update", "update": update}) + ``` + + """ + + def __init__(self): + super().__init__() + self.room_name = None + self.ydoc = None + self._websocket_shim = None + + def make_room_name(self) -> str: + """Make the room name for a new channel. + + Override to customize the room name when a channel is created. + + Returns: + The room name for a new channel. Defaults to the room name from the URL route. + """ + return self.scope["url_route"]["kwargs"]["room"] + + async def make_ydoc(self) -> Doc: + """Make the YDoc for a new channel. + + Override to customize the YDoc when a channel is created + (useful to initialize it with data from your database, or to add observers to it). + + Returns: + The YDoc for a new channel. Defaults to a new empty YDoc. + """ + return Doc() + + def _make_websocket_shim(self, path: str) -> _WebsocketShim: + return _WebsocketShim(path, self.group_send_message) + + async def connect(self) -> None: + self.room_name = self.make_room_name() + self.ydoc = await self.make_ydoc() + self._websocket_shim = self._make_websocket_shim(self.scope["path"]) + + await self.channel_layer.group_add(self.room_name, self.channel_name) + await self.accept() + + await sync(self.ydoc, self._websocket_shim, logger) + + async def disconnect(self, code) -> None: + await self.channel_layer.group_discard(self.room_name, self.channel_name) + + async def receive(self, text_data=None, bytes_data=None): + if bytes_data is None: + return + await self.group_send_message(bytes_data) + if bytes_data[0] != YMessageType.SYNC: + return + await process_sync_message(bytes_data[1:], self.ydoc, self._websocket_shim, logger) + + class WrappedMessage(TypedDict): + """A wrapped message to send to the client.""" + + message: bytes + + async def send_message(self, message_wrapper: WrappedMessage) -> None: + """Send a message to the client. + + Arguments: + message_wrapper: The message to send, wrapped. + """ + await self.send(bytes_data=message_wrapper["message"]) + + async def group_send_message(self, message: bytes) -> None: + """Send a message to the group. + + Arguments: + message: The message to send. + """ + await self.channel_layer.group_send( + self.room_name, {"type": "send_message", "message": message} + ) diff --git a/plugins/yjs/fps_yjs/ywebsocket/websocket.py b/plugins/yjs/fps_yjs/ywebsocket/websocket.py new file mode 100644 index 00000000..e487c11a --- /dev/null +++ b/plugins/yjs/fps_yjs/ywebsocket/websocket.py @@ -0,0 +1,58 @@ +import sys + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + + +class Websocket(Protocol): + """WebSocket. + + The Websocket instance can receive messages using an async iterator, + until the connection is closed: + ```py + async for message in websocket: + ... + ``` + Or directly by calling `recv()`: + ```py + message = await websocket.recv() + ``` + Sending messages is done with `send()`: + ```py + await websocket.send(message) + ``` + """ + + @property + def path(self) -> str: + """The WebSocket path.""" + ... + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + try: + message = await self.recv() + except Exception: + raise StopAsyncIteration() + + return message + + async def send(self, message: bytes) -> None: + """Send a message. + + Arguments: + message: The message to send. + """ + ... + + async def recv(self) -> bytes: + """Receive a message. + + Returns: + The received message. + """ + ... diff --git a/plugins/yjs/fps_yjs/ywebsocket/websocket_provider.py b/plugins/yjs/fps_yjs/ywebsocket/websocket_provider.py new file mode 100644 index 00000000..1f942b19 --- /dev/null +++ b/plugins/yjs/fps_yjs/ywebsocket/websocket_provider.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from contextlib import AsyncExitStack +from functools import partial +from logging import Logger, getLogger + +from pycrdt import Doc +from anyio import ( + TASK_STATUS_IGNORED, + Event, + create_memory_object_stream, + create_task_group, +) +from anyio.abc import TaskGroup, TaskStatus +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +from .websocket import Websocket +from .yutils import ( + YMessageType, + create_update_message, + process_sync_message, + put_updates, + sync, +) + + +class WebsocketProvider: + """WebSocket provider.""" + + _ydoc: Doc + _update_send_stream: MemoryObjectSendStream + _update_receive_stream: MemoryObjectReceiveStream + _started: Event | None + _starting: bool + _task_group: TaskGroup | None + + def __init__(self, ydoc: Doc, websocket: Websocket, log: Logger | None = None) -> None: + """Initialize the object. + + The WebsocketProvider instance should preferably be used as an async context manager: + ```py + async with websocket_provider: + ... + ``` + However, a lower-level API can also be used: + ```py + task = asyncio.create_task(websocket_provider.start()) + await websocket_provider.started.wait() + ... + websocket_provider.stop() + ``` + + Arguments: + ydoc: The YDoc to connect through the WebSocket. + websocket: The WebSocket through which to connect the YDoc. + log: An optional logger. + """ + self._ydoc = ydoc + self._websocket = websocket + self.log = log or getLogger(__name__) + self._update_send_stream, self._update_receive_stream = create_memory_object_stream( + max_buffer_size=65536 + ) + self._started = None + self._starting = False + self._task_group = None + ydoc.observe(partial(put_updates, self._update_send_stream)) + + @property + def started(self) -> Event: + """An async event that is set when the WebSocket provider has started.""" + if self._started is None: + self._started = Event() + return self._started + + async def __aenter__(self) -> WebsocketProvider: + if self._task_group is not None: + raise RuntimeError("WebsocketProvider already running") + + async with AsyncExitStack() as exit_stack: + tg = create_task_group() + self._task_group = await exit_stack.enter_async_context(tg) + self._exit_stack = exit_stack.pop_all() + tg.start_soon(self._run) + self.started.set() + + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + if self._task_group is None: + raise RuntimeError("WebsocketProvider not running") + + self._task_group.cancel_scope.cancel() + self._task_group = None + return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) + + async def _run(self): + await sync(self._ydoc, self._websocket, self.log) + self._task_group.start_soon(self._send) + async for message in self._websocket: + if message[0] == YMessageType.SYNC: + await process_sync_message(message[1:], self._ydoc, self._websocket, self.log) + + async def _send(self): + async with self._update_receive_stream: + async for update in self._update_receive_stream: + message = create_update_message(update) + try: + await self._websocket.send(message) + except Exception: + pass + + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): + """Start the WebSocket provider. + + Arguments: + task_status: The status to set when the task has started. + """ + if self._starting: + return + else: + self._starting = True + + if self._task_group is not None: + raise RuntimeError("WebsocketProvider already running") + + async with create_task_group() as self._task_group: + self._task_group.start_soon(self._run) + self.started.set() + self._starting = False + task_status.started() + + def stop(self): + """Stop the WebSocket provider.""" + if self._task_group is None: + raise RuntimeError("WebsocketProvider not running") + + self._task_group.cancel_scope.cancel() + self._task_group = None diff --git a/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py b/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py new file mode 100644 index 00000000..6bff2a1c --- /dev/null +++ b/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +from contextlib import AsyncExitStack +from logging import Logger, getLogger + +from anyio import TASK_STATUS_IGNORED, Event, create_task_group +from anyio.abc import TaskGroup, TaskStatus + +from .websocket import Websocket +from .yroom import YRoom + + +class WebsocketServer: + """WebSocket server.""" + + auto_clean_rooms: bool + rooms: dict[str, YRoom] + _started: Event | None + _starting: bool + _task_group: TaskGroup | None + + def __init__( + self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log: Logger | None = None + ) -> None: + """Initialize the object. + + The WebsocketServer instance should preferably be used as an async context manager: + ```py + async with websocket_server: + ... + ``` + However, a lower-level API can also be used: + ```py + task = asyncio.create_task(websocket_server.start()) + await websocket_server.started.wait() + ... + websocket_server.stop() + ``` + + Arguments: + rooms_ready: Whether rooms are ready to be synchronized when opened. + auto_clean_rooms: Whether rooms should be deleted when no client is there anymore. + log: An optional logger. + """ + self.rooms_ready = rooms_ready + self.auto_clean_rooms = auto_clean_rooms + self.log = log or getLogger(__name__) + self.rooms = {} + self._started = None + self._starting = False + self._task_group = None + + @property + def started(self) -> Event: + """An async event that is set when the WebSocket server has started.""" + if self._started is None: + self._started = Event() + return self._started + + async def get_room(self, name: str) -> YRoom: + """Get or create a room with the given name, and start it. + + Arguments: + name: The room name. + + Returns: + The room with the given name, or a new one if no room with that name was found. + """ + if name not in self.rooms.keys(): + self.rooms[name] = YRoom(ready=self.rooms_ready, log=self.log) + room = self.rooms[name] + await self.start_room(room) + return room + + async def start_room(self, room: YRoom) -> None: + """Start a room, if not already started. + + Arguments: + room: The room to start. + """ + if self._task_group is None: + raise RuntimeError( + "The WebsocketServer is not running: use `async with websocket_server:` " + "or `await websocket_server.start()`" + ) + + if not room.started.is_set(): + await self._task_group.start(room.start) + + def get_room_name(self, room: YRoom) -> str: + """Get the name of a room. + + Arguments: + room: The room to get the name from. + + Returns: + The room name. + """ + return list(self.rooms.keys())[list(self.rooms.values()).index(room)] + + def rename_room( + self, to_name: str, *, from_name: str | None = None, from_room: YRoom | None = None + ) -> None: + """Rename a room. + + Arguments: + to_name: The new name of the room. + from_name: The previous name of the room (if `from_room` is not passed). + from_room: The room to be renamed (if `from_name` is not passed). + """ + if from_name is not None and from_room is not None: + raise RuntimeError("Cannot pass from_name and from_room") + if from_name is None: + assert from_room is not None + from_name = self.get_room_name(from_room) + self.rooms[to_name] = self.rooms.pop(from_name) + + def delete_room(self, *, name: str | None = None, room: YRoom | None = None) -> None: + """Delete a room. + + Arguments: + name: The name of the room to delete (if `room` is not passed). + room: The room to delete ( if `name` is not passed). + """ + if name is not None and room is not None: + raise RuntimeError("Cannot pass name and room") + if name is None: + assert room is not None + name = self.get_room_name(room) + room = self.rooms.pop(name) + room.stop() + + async def serve(self, websocket: Websocket) -> None: + """Serve a client through a WebSocket. + + Arguments: + websocket: The WebSocket through which to serve the client. + """ + if self._task_group is None: + raise RuntimeError( + "The WebsocketServer is not running: use `async with websocket_server:` " + "or `await websocket_server.start()`" + ) + + async with create_task_group() as tg: + tg.start_soon(self._serve, websocket, tg) + + async def _serve(self, websocket: Websocket, tg: TaskGroup): + room = await self.get_room(websocket.path) + await self.start_room(room) + await room.serve(websocket) + + if self.auto_clean_rooms and not room.clients: + self.delete_room(room=room) + tg.cancel_scope.cancel() + + async def __aenter__(self) -> WebsocketServer: + if self._task_group is not None: + raise RuntimeError("WebsocketServer already running") + + async with AsyncExitStack() as exit_stack: + tg = create_task_group() + self._task_group = await exit_stack.enter_async_context(tg) + self._exit_stack = exit_stack.pop_all() + self.started.set() + + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + if self._task_group is None: + raise RuntimeError("WebsocketServer not running") + + self._task_group.cancel_scope.cancel() + self._task_group = None + return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) + + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): + """Start the WebSocket server. + + Arguments: + task_status: The status to set when the task has started. + """ + if self._starting: + return + else: + self._starting = True + + if self._task_group is not None: + raise RuntimeError("WebsocketServer already running") + + # create the task group and wait forever + async with create_task_group() as self._task_group: + self._task_group.start_soon(Event().wait) + self.started.set() + self._starting = False + task_status.started() + + def stop(self) -> None: + """Stop the WebSocket server.""" + if self._task_group is None: + raise RuntimeError("WebsocketServer not running") + + self._task_group.cancel_scope.cancel() + self._task_group = None diff --git a/plugins/yjs/fps_yjs/ywebsocket/yroom.py b/plugins/yjs/fps_yjs/ywebsocket/yroom.py new file mode 100644 index 00000000..7c5dfc57 --- /dev/null +++ b/plugins/yjs/fps_yjs/ywebsocket/yroom.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +from contextlib import AsyncExitStack +from functools import partial +from inspect import isawaitable +from logging import Logger, getLogger +from typing import Awaitable, Callable + +from pycrdt import Doc +from anyio import ( + TASK_STATUS_IGNORED, + Event, + create_memory_object_stream, + create_task_group, +) +from anyio.abc import TaskGroup, TaskStatus +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +from .awareness import Awareness +from .websocket import Websocket +from .ystore import BaseYStore +from .yutils import ( + YMessageType, + create_update_message, + process_sync_message, + put_updates, + sync, +) + + +class YRoom: + clients: list + ydoc: Doc + ystore: BaseYStore | None + _on_message: Callable[[bytes], Awaitable[bool] | bool] | None + _update_send_stream: MemoryObjectSendStream + _update_receive_stream: MemoryObjectReceiveStream + _ready: bool + _task_group: TaskGroup | None + _started: Event | None + _starting: bool + + def __init__( + self, ready: bool = True, ystore: BaseYStore | None = None, log: Logger | None = None + ): + """Initialize the object. + + The YRoom instance should preferably be used as an async context manager: + ```py + async with room: + ... + ``` + However, a lower-level API can also be used: + ```py + task = asyncio.create_task(room.start()) + await room.started.wait() + ... + room.stop() + ``` + + Arguments: + ready: Whether the internal YDoc is ready to be synchronized right away. + ystore: An optional store in which to persist document updates. + log: An optional logger. + """ + self.ydoc = Doc() + self.awareness = Awareness(self.ydoc) + self._update_send_stream, self._update_receive_stream = create_memory_object_stream( + max_buffer_size=65536 + ) + self._ready = False + self.ready = ready + self.ystore = ystore + self.log = log or getLogger(__name__) + self.clients = [] + self._on_message = None + self._started = None + self._starting = False + self._task_group = None + + @property + def started(self): + """An async event that is set when the YRoom provider has started.""" + if self._started is None: + self._started = Event() + return self._started + + @property + def ready(self) -> bool: + """ + Returns: + True is the internal YDoc is ready to be synchronized. + """ + return self._ready + + @ready.setter + def ready(self, value: bool) -> None: + """ + Arguments: + value: True if the internal YDoc is ready to be synchronized, False otherwise.""" + self._ready = value + if value: + self.ydoc.observe(partial(put_updates, self._update_send_stream)) + + @property + def on_message(self) -> Callable[[bytes], Awaitable[bool] | bool] | None: + """ + Returns: + The optional callback to call when a message is received. + """ + return self._on_message + + @on_message.setter + def on_message(self, value: Callable[[bytes], Awaitable[bool] | bool] | None): + """ + Arguments: + value: An optional callback to call when a message is received. + If the callback returns True, the message is skipped. + """ + self._on_message = value + + async def _broadcast_updates(self): + if self.ystore is not None and not self.ystore.started.is_set(): + self._task_group.start_soon(self.ystore.start) + + async with self._update_receive_stream: + async for update in self._update_receive_stream: + if self._task_group.cancel_scope.cancel_called: + return + # broadcast internal ydoc's update to all clients, that includes changes from the + # clients and changes from the backend (out-of-band changes) + for client in self.clients: + self.log.debug("Sending Y update to client with endpoint: %s", client.path) + message = create_update_message(update) + self._task_group.start_soon(client.send, message) + if self.ystore: + self.log.debug("Writing Y update to YStore") + self._task_group.start_soon(self.ystore.write, update) + + async def __aenter__(self) -> YRoom: + if self._task_group is not None: + raise RuntimeError("YRoom already running") + + async with AsyncExitStack() as exit_stack: + tg = create_task_group() + self._task_group = await exit_stack.enter_async_context(tg) + self._exit_stack = exit_stack.pop_all() + tg.start_soon(self._broadcast_updates) + self.started.set() + + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + if self._task_group is None: + raise RuntimeError("YRoom not running") + + self._task_group.cancel_scope.cancel() + self._task_group = None + return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) + + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): + """Start the room. + + Arguments: + task_status: The status to set when the task has started. + """ + if self._starting: + return + else: + self._starting = True + + if self._task_group is not None: + raise RuntimeError("YRoom already running") + + async with create_task_group() as self._task_group: + self._task_group.start_soon(self._broadcast_updates) + self.started.set() + self._starting = False + task_status.started() + + def stop(self): + """Stop the room.""" + if self._task_group is None: + raise RuntimeError("YRoom not running") + + self._task_group.cancel_scope.cancel() + self._task_group = None + + async def serve(self, websocket: Websocket): + """Serve a client. + + Arguments: + websocket: The WebSocket through which to serve the client. + """ + async with create_task_group() as tg: + self.clients.append(websocket) + await sync(self.ydoc, websocket, self.log) + try: + async for message in websocket: + # filter messages (e.g. awareness) + skip = False + if self.on_message: + _skip = self.on_message(message) + skip = await _skip if isawaitable(_skip) else _skip + if skip: + continue + message_type = message[0] + if message_type == YMessageType.SYNC: + # update our internal state in the background + # changes to the internal state are then forwarded to all clients + # and stored in the YStore (if any) + tg.start_soon( + process_sync_message, message[1:], self.ydoc, websocket, self.log + ) + elif message_type == YMessageType.AWARENESS: + # forward awareness messages from this client to all clients, + # including itself, because it's used to keep the connection alive + self.log.debug( + "Received %s message from endpoint: %s", + YMessageType.AWARENESS.name, + websocket.path, + ) + for client in self.clients: + self.log.debug( + "Sending Y awareness from client with endpoint " + "%s to client with endpoint: %s", + websocket.path, + client.path, + ) + tg.start_soon(client.send, message) + except Exception as e: + self.log.debug("Error serving endpoint: %s", websocket.path, exc_info=e) + + # remove this client + self.clients = [c for c in self.clients if c != websocket] diff --git a/plugins/yjs/fps_yjs/ywebsocket/ystore.py b/plugins/yjs/fps_yjs/ywebsocket/ystore.py new file mode 100644 index 00000000..d71a1d37 --- /dev/null +++ b/plugins/yjs/fps_yjs/ywebsocket/ystore.py @@ -0,0 +1,447 @@ +from __future__ import annotations + +import struct +import tempfile +import time +from abc import ABC, abstractmethod +from contextlib import AsyncExitStack +from inspect import isawaitable +from logging import Logger, getLogger +from pathlib import Path +from typing import AsyncIterator, Awaitable, Callable, cast + +import aiosqlite +import anyio +from pycrdt import Doc +from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group +from anyio.abc import TaskGroup, TaskStatus + +from .yutils import Decoder, get_new_path, write_var_uint + + +class YDocNotFound(Exception): + pass + + +class BaseYStore(ABC): + metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None + version = 2 + _started: Event | None = None + _starting: bool = False + _task_group: TaskGroup | None = None + + @abstractmethod + def __init__( + self, path: str, metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None + ): + ... + + @abstractmethod + async def write(self, data: bytes) -> None: + ... + + @abstractmethod + async def read(self) -> AsyncIterator[tuple[bytes, bytes]]: + ... + + @property + def started(self) -> Event: + if self._started is None: + self._started = Event() + return self._started + + async def __aenter__(self) -> BaseYStore: + if self._task_group is not None: + raise RuntimeError("YStore already running") + + async with AsyncExitStack() as exit_stack: + tg = create_task_group() + self._task_group = await exit_stack.enter_async_context(tg) + self._exit_stack = exit_stack.pop_all() + tg.start_soon(self.start) + + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + if self._task_group is None: + raise RuntimeError("YStore not running") + + self._task_group.cancel_scope.cancel() + self._task_group = None + return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) + + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): + """Start the store. + + Arguments: + task_status: The status to set when the task has started. + """ + if self._starting: + return + else: + self._starting = True + + if self._task_group is not None: + raise RuntimeError("YStore already running") + + self.started.set() + self._starting = False + task_status.started() + + def stop(self) -> None: + """Stop the store.""" + if self._task_group is None: + raise RuntimeError("YStore not running") + + self._task_group.cancel_scope.cancel() + self._task_group = None + + async def get_metadata(self) -> bytes: + """ + Returns: + The metadata. + """ + if self.metadata_callback is None: + return b"" + + metadata = self.metadata_callback() + if isawaitable(metadata): + metadata = await metadata + metadata = cast(bytes, metadata) + return metadata + + async def encode_state_as_update(self, ydoc: Doc) -> None: + """Store a YDoc state. + + Arguments: + ydoc: The YDoc from which to store the state. + """ + update = ydoc.get_update() + await self.write(update) + + async def apply_updates(self, ydoc: Doc) -> None: + """Apply all stored updates to the YDoc. + + Arguments: + ydoc: The YDoc on which to apply the updates. + """ + async for update, *rest in self.read(): # type: ignore + ydoc.apply_update(update) + + +class FileYStore(BaseYStore): + """A YStore which uses one file per document.""" + + path: str + metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None + lock: Lock + + def __init__( + self, + path: str, + metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None, + log: Logger | None = None, + ) -> None: + """Initialize the object. + + Arguments: + path: The file path used to store the updates. + metadata_callback: An optional callback to call to get the metadata. + log: An optional logger. + """ + self.path = path + self.metadata_callback = metadata_callback + self.log = log or getLogger(__name__) + self.lock = Lock() + + async def check_version(self) -> int: + """Check the version of the store format. + + Returns: + The offset where the data is located in the file. + """ + if not await anyio.Path(self.path).exists(): + version_mismatch = True + else: + version_mismatch = False + move_file = False + async with await anyio.open_file(self.path, "rb") as f: + header = await f.read(8) + if header == b"VERSION:": + version = int(await f.readline()) + if version == self.version: + offset = await f.tell() + else: + version_mismatch = True + else: + version_mismatch = True + if version_mismatch: + move_file = True + if move_file: + new_path = await get_new_path(self.path) + self.log.warning(f"YStore version mismatch, moving {self.path} to {new_path}") + await anyio.Path(self.path).rename(new_path) + if version_mismatch: + async with await anyio.open_file(self.path, "wb") as f: + version_bytes = f"VERSION:{self.version}\n".encode() + await f.write(version_bytes) + offset = len(version_bytes) + return offset + + async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: # type: ignore + """Async iterator for reading the store content. + + Returns: + A tuple of (update, metadata, timestamp) for each update. + """ + async with self.lock: + if not await anyio.Path(self.path).exists(): + raise YDocNotFound + offset = await self.check_version() + async with await anyio.open_file(self.path, "rb") as f: + await f.seek(offset) + data = await f.read() + if not data: + raise YDocNotFound + i = 0 + for d in Decoder(data).read_messages(): + if i == 0: + update = d + elif i == 1: + metadata = d + else: + timestamp = struct.unpack(" None: + """Store an update. + + Arguments: + data: The update to store. + """ + parent = Path(self.path).parent + async with self.lock: + await anyio.Path(parent).mkdir(parents=True, exist_ok=True) + await self.check_version() + async with await anyio.open_file(self.path, "ab") as f: + data_len = write_var_uint(len(data)) + await f.write(data_len + data) + metadata = await self.get_metadata() + metadata_len = write_var_uint(len(metadata)) + await f.write(metadata_len + metadata) + timestamp = struct.pack(" str: + """Get the base directory where the update file is written. + + Returns: + The base directory path. + """ + if self.base_dir is None: + self.make_directory() + assert self.base_dir is not None + return self.base_dir + + def make_directory(self): + """Create the base directory where the update file is written.""" + type(self).base_dir = tempfile.mkdtemp(prefix=self.prefix_dir) + + +class SQLiteYStore(BaseYStore): + """A YStore which uses an SQLite database. + Unlike file-based YStores, the Y updates of all documents are stored in the same database. + + Subclass to point to your database file: + + ```py + class MySQLiteYStore(SQLiteYStore): + db_path = "path/to/my_ystore.db" + ``` + """ + + db_path: str = "ystore.db" + # Determines the "time to live" for all documents, i.e. how recent the + # latest update of a document must be before purging document history. + # Defaults to never purging document history (None). + document_ttl: int | None = None + path: str + lock: Lock + db_initialized: Event + + def __init__( + self, + path: str, + metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None, + log: Logger | None = None, + ) -> None: + """Initialize the object. + + Arguments: + path: The file path used to store the updates. + metadata_callback: An optional callback to call to get the metadata. + log: An optional logger. + """ + self.path = path + self.metadata_callback = metadata_callback + self.log = log or getLogger(__name__) + self.lock = Lock() + self.db_initialized = Event() + + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): + """Start the SQLiteYStore. + + Arguments: + task_status: The status to set when the task has started. + """ + if self._starting: + return + else: + self._starting = True + + if self._task_group is not None: + raise RuntimeError("YStore already running") + + async with create_task_group() as self._task_group: + self._task_group.start_soon(self._init_db) + self.started.set() + self._starting = False + task_status.started() + + async def _init_db(self): + create_db = False + move_db = False + if not await anyio.Path(self.db_path).exists(): + create_db = True + else: + async with self.lock: + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute( + "SELECT count(name) FROM sqlite_master " + "WHERE type='table' and name='yupdates'" + ) + table_exists = (await cursor.fetchone())[0] + if table_exists: + cursor = await db.execute("pragma user_version") + version = (await cursor.fetchone())[0] + if version != self.version: + move_db = True + create_db = True + else: + create_db = True + if move_db: + new_path = await get_new_path(self.db_path) + self.log.warning(f"YStore version mismatch, moving {self.db_path} to {new_path}") + await anyio.Path(self.db_path).rename(new_path) + if create_db: + async with self.lock: + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + "CREATE TABLE yupdates " + "(path TEXT NOT NULL, yupdate BLOB, metadata BLOB, timestamp REAL NOT NULL)" + ) + await db.execute( + "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)" + ) + await db.execute(f"PRAGMA user_version = {self.version}") + await db.commit() + self.db_initialized.set() + + async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: # type: ignore + """Async iterator for reading the store content. + + Returns: + A tuple of (update, metadata, timestamp) for each update. + """ + await self.db_initialized.wait() + try: + async with self.lock: + async with aiosqlite.connect(self.db_path) as db: + async with db.execute( + "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", + (self.path,), + ) as cursor: + found = False + async for update, metadata, timestamp in cursor: + found = True + yield update, metadata, timestamp + if not found: + raise YDocNotFound + except Exception: + raise YDocNotFound + + async def write(self, data: bytes) -> None: + """Store an update. + + Arguments: + data: The update to store. + """ + await self.db_initialized.wait() + async with self.lock: + async with aiosqlite.connect(self.db_path) as db: + # first, determine time elapsed since last update + cursor = await db.execute( + "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1", + (self.path,), + ) + row = await cursor.fetchone() + diff = (time.time() - row[0]) if row else 0 + + if self.document_ttl is not None and diff > self.document_ttl: + # squash updates + ydoc = Doc() + async with db.execute( + "SELECT yupdate FROM yupdates WHERE path = ?", (self.path,) + ) as cursor: + async for update, in cursor: + ydoc.apply_update(update) + # delete history + await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) + # insert squashed updates + squashed_update = ydoc.get_update() + metadata = await self.get_metadata() + await db.execute( + "INSERT INTO yupdates VALUES (?, ?, ?, ?)", + (self.path, squashed_update, metadata, time.time()), + ) + + # finally, write this update to the DB + metadata = await self.get_metadata() + await db.execute( + "INSERT INTO yupdates VALUES (?, ?, ?, ?)", + (self.path, data, metadata, time.time()), + ) + await db.commit() diff --git a/plugins/yjs/fps_yjs/ywebsocket/yutils.py b/plugins/yjs/fps_yjs/ywebsocket/yutils.py new file mode 100644 index 00000000..a8ca4d46 --- /dev/null +++ b/plugins/yjs/fps_yjs/ywebsocket/yutils.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from enum import IntEnum +from pathlib import Path + +import anyio +from pycrdt import Doc, TransactionEvent +from anyio.streams.memory import MemoryObjectSendStream + + +class YMessageType(IntEnum): + SYNC = 0 + AWARENESS = 1 + + +class YSyncMessageType(IntEnum): + SYNC_STEP1 = 0 + SYNC_STEP2 = 1 + SYNC_UPDATE = 2 + + +def write_var_uint(num: int) -> bytes: + res = [] + while num > 127: + res.append(128 | (127 & num)) + num >>= 7 + res.append(num) + return bytes(res) + + +def create_message(data: bytes, msg_type: int) -> bytes: + return bytes([YMessageType.SYNC, msg_type]) + write_var_uint(len(data)) + data + + +def create_sync_step1_message(data: bytes) -> bytes: + return create_message(data, YSyncMessageType.SYNC_STEP1) + + +def create_sync_step2_message(data: bytes) -> bytes: + return create_message(data, YSyncMessageType.SYNC_STEP2) + + +def create_update_message(data: bytes) -> bytes: + return create_message(data, YSyncMessageType.SYNC_UPDATE) + + +def read_message(stream: bytes) -> bytes: + message = Decoder(stream).read_message() + assert message is not None + return message + + +class Decoder: + def __init__(self, stream: bytes): + self.stream = stream + self.length = len(stream) + self.i0 = 0 + + def read_var_uint(self) -> int: + if self.length <= 0: + raise RuntimeError("Y protocol error") + uint = 0 + i = 0 + while True: + byte = self.stream[self.i0] + uint += (byte & 127) << i + i += 7 + self.i0 += 1 + self.length -= 1 + if byte < 128: + break + return uint + + def read_message(self) -> bytes | None: + if self.length == 0: + return None + length = self.read_var_uint() + if length == 0: + return b"" + i1 = self.i0 + length + message = self.stream[self.i0 : i1] # noqa + self.i0 = i1 + self.length -= length + return message + + def read_messages(self): + while True: + message = self.read_message() + if message is None: + return + yield message + + def read_var_string(self): + message = self.read_message() + if message is None: + return "" + return message.decode("utf-8") + + +def put_updates(update_send_stream: MemoryObjectSendStream, event: TransactionEvent) -> None: + update = event.get_update() # type: ignore + update_send_stream.send_nowait(update) + + +async def process_sync_message(message: bytes, ydoc: Doc, websocket, log) -> None: + message_type = message[0] + msg = message[1:] + log.debug( + "Received %s message from endpoint: %s", + YSyncMessageType(message_type).name, + websocket.path, + ) + if message_type == YSyncMessageType.SYNC_STEP1: + state = read_message(msg) + update = ydoc.get_update(state) + reply = create_sync_step2_message(update) + log.debug( + "Sending %s message to endpoint: %s", + YSyncMessageType.SYNC_STEP2.name, + websocket.path, + ) + await websocket.send(reply) + elif message_type in ( + YSyncMessageType.SYNC_STEP2, + YSyncMessageType.SYNC_UPDATE, + ): + update = read_message(msg) + # Ignore empty updates + if update != b"\x00\x00": + ydoc.apply_update(update) + + +async def sync(ydoc: Doc, websocket, log): + state = ydoc.get_state() + msg = create_sync_step1_message(state) + log.debug( + "Sending %s message to endpoint: %s", + YSyncMessageType.SYNC_STEP1.name, + websocket.path, + ) + await websocket.send(msg) + + +async def get_new_path(path: str) -> str: + p = Path(path) + ext = p.suffix + p_noext = p.with_suffix("") + i = 1 + dir_list = [p async for p in anyio.Path().iterdir()] + while True: + new_path = f"{p_noext}({i}){ext}" + if new_path not in dir_list: + break + i += 1 + return str(new_path) diff --git a/plugins/yjs/pyproject.toml b/plugins/yjs/pyproject.toml index e6bd2d81..d395ae81 100644 --- a/plugins/yjs/pyproject.toml +++ b/plugins/yjs/pyproject.toml @@ -8,9 +8,7 @@ description = "An FPS plugin for the Yjs API" keywords = [ "jupyter", "server", "fastapi", "plugins" ] requires-python = ">=3.8" dependencies = [ - "jupyter_ydoc >=1,<2", - "ypy-websocket >=0.12.1,<0.13.0", - "y-py >=0.6.0,<0.7.0", + "pycrdt >=0.3.4,<0.4.0", "jupyverse-api >=0.1.2,<1", ] dynamic = [ "version",] @@ -38,5 +36,11 @@ skip = [ "check-links",] "asphalt.components" = {yjs = "fps_yjs.main:YjsComponent"} "jupyverse.components" = {yjs = "fps_yjs.main:YjsComponent"} +[project.entry-points.jupyverse_ydoc] +blob = "fps_yjs.ydocs.yblob:YBlob" +file = "fps_yjs.ydocs.yfile:YFile" +unicode = "fps_yjs.ydocs.yunicode:YUnicode" +notebook = "fps_yjs.ydocs.ynotebook:YNotebook" + [tool.hatch.version] path = "fps_yjs/__init__.py" diff --git a/tests/test_server.py b/tests/test_server.py index 07993dbc..8adb7cae 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,9 +4,9 @@ import pytest import requests -import y_py as Y +from pycrdt import Doc from websockets import connect -from ypy_websocket import WebsocketProvider +from fps_yjs.ywebsocket import WebsocketProvider prev_theme = {} test_theme = {"raw": '{// jupyverse test\n"theme": "JupyterLab Dark"}'} @@ -79,7 +79,7 @@ async def test_rest_api(start_jupyverse): ) file_id = response.json()["fileId"] document_id = f"json:notebook:{file_id}" - ydoc = Y.YDoc() + ydoc = Doc() async with connect( f"{ws_url}/api/collaboration/room/{document_id}" ) as websocket, WebsocketProvider(ydoc, websocket): @@ -100,7 +100,7 @@ async def test_rest_api(start_jupyverse): # wait for Y model to be updated await asyncio.sleep(0.5) # retrieve cells - cells = json.loads(ydoc.get_array("cells").to_json()) + cells = json.loads(str(ydoc.get_array("cells"))) assert cells[0]["outputs"] == [ { "data": {"text/plain": ["3"]},