From 73a16c326e333838aaa8ec2321c40d5d13a96f05 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 15 Dec 2023 17:49:00 +0100 Subject: [PATCH] Implement server-side ypywidgets rendering (#364) * Implement server-side ypywidgets rendering * Fix types * Use ypywidgets-textual in tests * Update with ypywidgets v0.6.1 and ypywidgets-textual 0.2.1 * Use cell ID instead of cell index in execute API * Add JupyterLab server_side_execution flag * Set shared document file_id --- .gitignore | 3 + .../jupyverse_api/jupyterlab/__init__.py | 1 + jupyverse_api/jupyverse_api/kernels/models.py | 2 +- plugins/jupyterlab/fps_jupyterlab/routes.py | 6 +- .../fps_kernels/kernel_driver/driver.py | 174 +++++++++++++----- .../fps_kernels/kernel_driver/message.py | 3 +- plugins/kernels/fps_kernels/routes.py | 12 +- plugins/noauth/pyproject.toml | 4 +- plugins/yjs/fps_yjs/routes.py | 10 +- plugins/yjs/fps_yjs/ydocs/ybasedoc.py | 8 + plugins/yjs/fps_yjs/ydocs/ynotebook.py | 4 +- .../fps_yjs/ywebsocket/websocket_server.py | 5 +- plugins/yjs/fps_yjs/ywebsocket/yroom.py | 8 +- plugins/yjs/fps_yjs/ywidgets/__init__.py | 1 + plugins/yjs/fps_yjs/ywidgets/widgets.py | 52 ++++++ pyproject.toml | 2 + tests/data/notebook1.ipynb | 55 ++++++ tests/test_execute.py | 150 +++++++++++++++ tests/test_server.py | 106 ++++++++++- 19 files changed, 541 insertions(+), 65 deletions(-) create mode 100644 plugins/yjs/fps_yjs/ywidgets/__init__.py create mode 100644 plugins/yjs/fps_yjs/ywidgets/widgets.py create mode 100644 tests/data/notebook1.ipynb create mode 100644 tests/test_execute.py diff --git a/.gitignore b/.gitignore index d4751320..70e4192c 100644 --- a/.gitignore +++ b/.gitignore @@ -344,3 +344,6 @@ $RECYCLE.BIN/ .jupyter_ystore.db .jupyter_ystore.db-journal fps_cli_args.toml + +# pixi environments +.pixi diff --git a/jupyverse_api/jupyverse_api/jupyterlab/__init__.py b/jupyverse_api/jupyverse_api/jupyterlab/__init__.py index 36f18b56..c9bdf98c 100644 --- a/jupyverse_api/jupyverse_api/jupyterlab/__init__.py +++ b/jupyverse_api/jupyverse_api/jupyterlab/__init__.py @@ -89,3 +89,4 @@ async def get_workspace( class JupyterLabConfig(Config): dev_mode: bool = False + server_side_execution: bool = False diff --git a/jupyverse_api/jupyverse_api/kernels/models.py b/jupyverse_api/jupyverse_api/kernels/models.py index acd573c7..08ac52e6 100644 --- a/jupyverse_api/jupyverse_api/kernels/models.py +++ b/jupyverse_api/jupyverse_api/kernels/models.py @@ -39,4 +39,4 @@ class Session(BaseModel): class Execution(BaseModel): document_id: str - cell_idx: int + cell_id: str diff --git a/plugins/jupyterlab/fps_jupyterlab/routes.py b/plugins/jupyterlab/fps_jupyterlab/routes.py index 50e8b802..6f600d22 100644 --- a/plugins/jupyterlab/fps_jupyterlab/routes.py +++ b/plugins/jupyterlab/fps_jupyterlab/routes.py @@ -58,6 +58,7 @@ async def get_lab( self.get_index( "default", self.frontend_config.collaborative, + self.jupyterlab_config.server_side_execution, self.jupyterlab_config.dev_mode, self.frontend_config.base_url, ) @@ -71,6 +72,7 @@ async def load_workspace( self.get_index( "default", self.frontend_config.collaborative, + self.jupyterlab_config.server_side_execution, self.jupyterlab_config.dev_mode, self.frontend_config.base_url, ) @@ -99,11 +101,12 @@ async def get_workspace( return self.get_index( name, self.frontend_config.collaborative, + self.jupyterlab_config.server_side_execution, self.jupyterlab_config.dev_mode, self.frontend_config.base_url, ) - def get_index(self, workspace, collaborative, dev_mode, base_url="/"): + def get_index(self, workspace, collaborative, server_side_execution, dev_mode, base_url="/"): for path in (self.static_lab_dir).glob("main.*.js"): main_id = path.name.split(".")[1] break @@ -121,6 +124,7 @@ def get_index(self, workspace, collaborative, dev_mode, base_url="/"): "baseUrl": base_url, "cacheFiles": False, "collaborative": collaborative, + "serverSideExecution": server_side_execution, "devMode": dev_mode, "disabledExtensions": self.disabled_extension, "exposeAppInBrowser": False, diff --git a/plugins/kernels/fps_kernels/kernel_driver/driver.py b/plugins/kernels/fps_kernels/kernel_driver/driver.py index 4e41061c..cc65dcc3 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/driver.py +++ b/plugins/kernels/fps_kernels/kernel_driver/driver.py @@ -4,6 +4,10 @@ import uuid from typing import Any, Dict, List, Optional, cast +from pycrdt import Array, Map + +from jupyverse_api.yjs import Yjs + from .connect import cfg_t, connect_channel, launch_kernel, read_connection_file from .connect import write_connection_file as _write_connection_file from .kernelspec import find_kernelspec @@ -23,10 +27,12 @@ def __init__( connection_file: str = "", write_connection_file: bool = True, capture_kernel_output: bool = True, + yjs: Optional[Yjs] = None, ) -> None: self.capture_kernel_output = capture_kernel_output self.kernelspec_path = kernelspec_path or find_kernelspec(kernel_name) self.kernel_cwd = kernel_cwd + self.yjs = yjs if not self.kernelspec_path: raise RuntimeError("Could not find a kernel, maybe you forgot to install one?") if write_connection_file: @@ -37,11 +43,12 @@ def __init__( self.key = cast(str, self.connection_cfg["key"]) self.session_id = uuid.uuid4().hex self.msg_cnt = 0 - self.execute_requests: Dict[str, Dict[str, asyncio.Future]] = {} - self.channel_tasks: List[asyncio.Task] = [] + self.execute_requests: Dict[str, Dict[str, asyncio.Queue]] = {} + self.comm_messages: asyncio.Queue = asyncio.Queue() + self.tasks: List[asyncio.Task] = [] async def restart(self, startup_timeout: float = float("inf")) -> None: - for task in self.channel_tasks: + for task in self.tasks: task.cancel() msg = create_message("shutdown_request", content={"restart": True}) await send_message(msg, self.control_channel, self.key, change_date_to_str=True) @@ -52,7 +59,7 @@ async def restart(self, startup_timeout: float = float("inf")) -> None: if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]: break await self._wait_for_ready(startup_timeout) - self.channel_tasks = [] + self.tasks = [] self.listen_channels() async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None: @@ -69,6 +76,7 @@ async def connect(self, startup_timeout: float = float("inf")) -> None: self.connect_channels() await self._wait_for_ready(startup_timeout) self.listen_channels() + self.tasks.append(asyncio.create_task(self._handle_comms())) def connect_channels(self, connection_cfg: Optional[cfg_t] = None): connection_cfg = connection_cfg or self.connection_cfg @@ -77,40 +85,43 @@ def connect_channels(self, connection_cfg: Optional[cfg_t] = None): self.iopub_channel = connect_channel("iopub", connection_cfg) def listen_channels(self): - self.channel_tasks.append(asyncio.create_task(self.listen_iopub())) - self.channel_tasks.append(asyncio.create_task(self.listen_shell())) + self.tasks.append(asyncio.create_task(self.listen_iopub())) + self.tasks.append(asyncio.create_task(self.listen_shell())) async def stop(self) -> None: self.kernel_process.kill() await self.kernel_process.wait() os.remove(self.connection_file_path) - for task in self.channel_tasks: + for task in self.tasks: task.cancel() async def listen_iopub(self): while True: msg = await receive_message(self.iopub_channel, change_str_to_date=True) - msg_id = msg["parent_header"].get("msg_id") - if msg_id in self.execute_requests.keys(): - self.execute_requests[msg_id]["iopub_msg"].set_result(msg) + parent_id = msg["parent_header"].get("msg_id") + if msg["msg_type"] in ("comm_open", "comm_msg"): + self.comm_messages.put_nowait(msg) + elif parent_id in self.execute_requests.keys(): + self.execute_requests[parent_id]["iopub_msg"].put_nowait(msg) async def listen_shell(self): while True: msg = await receive_message(self.shell_channel, change_str_to_date=True) msg_id = msg["parent_header"].get("msg_id") if msg_id in self.execute_requests.keys(): - self.execute_requests[msg_id]["shell_msg"].set_result(msg) + self.execute_requests[msg_id]["shell_msg"].put_nowait(msg) async def execute( self, - cell: Dict[str, Any], + ycell: Map, timeout: float = float("inf"), msg_id: str = "", wait_for_executed: bool = True, ) -> None: - if cell["cell_type"] != "code": + if ycell["cell_type"] != "code": return - content = {"code": cell["source"], "silent": False} + ycell["execution_state"] = "busy" + content = {"code": str(ycell["source"]), "silent": False} msg = create_message( "execute_request", content, session_id=self.session_id, msg_id=str(self.msg_cnt) ) @@ -120,40 +131,68 @@ async def execute( msg_id = msg["header"]["msg_id"] self.msg_cnt += 1 await send_message(msg, self.shell_channel, self.key, change_date_to_str=True) + self.execute_requests[msg_id] = { + "iopub_msg": asyncio.Queue(), + "shell_msg": asyncio.Queue(), + } if wait_for_executed: deadline = time.time() + timeout - self.execute_requests[msg_id] = { - "iopub_msg": asyncio.Future(), - "shell_msg": asyncio.Future(), - } while True: try: - await asyncio.wait_for( - self.execute_requests[msg_id]["iopub_msg"], + msg = await asyncio.wait_for( + self.execute_requests[msg_id]["iopub_msg"].get(), deadline_to_timeout(deadline), ) except asyncio.TimeoutError: error_message = f"Kernel didn't respond in {timeout} seconds" raise RuntimeError(error_message) - msg = self.execute_requests[msg_id]["iopub_msg"].result() - self._handle_outputs(cell["outputs"], msg) + await self._handle_outputs(ycell["outputs"], msg) if ( - msg["header"]["msg_type"] == "status" - and msg["content"]["execution_state"] == "idle" + (msg["header"]["msg_type"] == "status" + and msg["content"]["execution_state"] == "idle") ): break - self.execute_requests[msg_id]["iopub_msg"] = asyncio.Future() try: - await asyncio.wait_for( - self.execute_requests[msg_id]["shell_msg"], + msg = await asyncio.wait_for( + self.execute_requests[msg_id]["shell_msg"].get(), deadline_to_timeout(deadline), ) except asyncio.TimeoutError: error_message = f"Kernel didn't respond in {timeout} seconds" raise RuntimeError(error_message) - msg = self.execute_requests[msg_id]["shell_msg"].result() - cell["execution_count"] = msg["content"]["execution_count"] + with ycell.doc.transaction(): + ycell["execution_count"] = msg["content"]["execution_count"] + ycell["execution_state"] = "idle" del self.execute_requests[msg_id] + else: + self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell))) + + async def _handle_iopub(self, msg_id: str, ycell: Map) -> None: + while True: + msg = await self.execute_requests[msg_id]["iopub_msg"].get() + await self._handle_outputs(ycell["outputs"], msg) + if ( + (msg["header"]["msg_type"] == "status" + and msg["content"]["execution_state"] == "idle") + ): + msg = await self.execute_requests[msg_id]["shell_msg"].get() + with ycell.doc.transaction(): + ycell["execution_count"] = msg["content"]["execution_count"] + ycell["execution_state"] = "idle" + + async def _handle_comms(self) -> None: + if self.yjs is None: + return + + while True: + msg = await self.comm_messages.get() + msg_type = msg["header"]["msg_type"] + if msg_type == "comm_open": + comm_id = msg["content"]["comm_id"] + comm = Comm(comm_id, self.shell_channel, self.session_id, self.key) + self.yjs.widgets.comm_open(msg, comm) # type: ignore + elif msg_type == "comm_msg": + self.yjs.widgets.comm_msg(msg) # type: ignore async def _wait_for_ready(self, timeout): deadline = time.time() + timeout @@ -178,22 +217,51 @@ async def _wait_for_ready(self, timeout): break new_timeout = deadline_to_timeout(deadline) - def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]): + async def _handle_outputs(self, outputs: Array, msg: Dict[str, Any]): msg_type = msg["header"]["msg_type"] content = msg["content"] if msg_type == "stream": - if (not outputs) or (outputs[-1]["name"] != content["name"]): - outputs.append({"name": content["name"], "output_type": msg_type, "text": []}) - outputs[-1]["text"].append(content["text"]) + with outputs.doc.transaction(): + # TODO: uncomment when changes are made in jupyter-ydoc + if (not outputs) or (outputs[-1]["name"] != content["name"]): # type: ignore + outputs.append( + #Map( + # { + # "name": content["name"], + # "output_type": msg_type, + # "text": Array([content["text"]]), + # } + #) + { + "name": content["name"], + "output_type": msg_type, + "text": [content["text"]], + } + ) + else: + #outputs[-1]["text"].append(content["text"]) # type: ignore + last_output = outputs[-1] + last_output["text"].append(content["text"]) # type: ignore + outputs[-1] = last_output elif msg_type in ("display_data", "execute_result"): - outputs.append( - { - "data": {"text/plain": [content["data"].get("text/plain", "")]}, - "execution_count": content["execution_count"], - "metadata": {}, - "output_type": msg_type, - } - ) + if "application/vnd.jupyter.ywidget-view+json" in content["data"]: + # this is a collaborative widget + model_id = content["data"]["application/vnd.jupyter.ywidget-view+json"]["model_id"] + if self.yjs is not None: + if model_id in self.yjs.widgets.widgets: # type: ignore + doc = self.yjs.widgets.widgets[model_id]["model"].ydoc # type: ignore + path = f"ywidget:{doc.guid}" + await self.yjs.room_manager.websocket_server.get_room(path, ydoc=doc) # type: ignore + outputs.append(doc) + else: + outputs.append( + { + "data": {"text/plain": [content["data"].get("text/plain", "")]}, + "execution_count": content["execution_count"], + "metadata": {}, + "output_type": msg_type, + } + ) elif msg_type == "error": outputs.append( { @@ -203,5 +271,25 @@ def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]): "traceback": content["traceback"], } ) - else: - return + + +class Comm: + def __init__(self, comm_id: str, shell_channel, session_id: str, key: str): + self.comm_id = comm_id + self.shell_channel = shell_channel + self.session_id = session_id + self.key = key + self.msg_cnt = 0 + + def send(self, buffers): + msg = create_message( + "comm_msg", + content={"comm_id": self.comm_id}, + session_id=self.session_id, + msg_id=self.msg_cnt, + buffers=buffers, + ) + self.msg_cnt += 1 + asyncio.create_task( + send_message(msg, self.shell_channel, self.key, change_date_to_str=True) + ) diff --git a/plugins/kernels/fps_kernels/kernel_driver/message.py b/plugins/kernels/fps_kernels/kernel_driver/message.py index 6ca6117b..6946c73c 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/message.py +++ b/plugins/kernels/fps_kernels/kernel_driver/message.py @@ -56,6 +56,7 @@ def create_message( content: Dict = {}, session_id: str = "", msg_id: str = "", + buffers: List = [], ) -> Dict[str, Any]: header = create_message_header(msg_type, session_id, msg_id) msg = { @@ -65,7 +66,7 @@ def create_message( "parent_header": {}, "content": content, "metadata": {}, - "buffers": [], + "buffers": buffers, } return msg diff --git a/plugins/kernels/fps_kernels/routes.py b/plugins/kernels/fps_kernels/routes.py index 2f9a2a9a..e15eab80 100644 --- a/plugins/kernels/fps_kernels/routes.py +++ b/plugins/kernels/fps_kernels/routes.py @@ -259,8 +259,12 @@ async def execute_cell( execution = Execution(**r) if kernel_id in kernels: ynotebook = self.yjs.get_document(execution.document_id) - cell = ynotebook.get_cell(execution.cell_idx) - cell["outputs"] = [] + ycells = [ycell for ycell in ynotebook.ycells if ycell["id"] == execution.cell_id] + if not ycells: + return # FIXME + + ycell = ycells[0] + del ycell["outputs"][:] kernel = kernels[kernel_id] if not kernel["driver"]: @@ -268,12 +272,12 @@ async def execute_cell( kernelspec_path=Path(find_kernelspec(kernel["name"])).as_posix(), write_connection_file=False, connection_file=kernel["server"].connection_file_path, + yjs=self.yjs, ) await driver.connect() driver = kernel["driver"] - await driver.execute(cell) - ynotebook.set_cell(execution.cell_idx, cell) + await driver.execute(ycell, wait_for_executed=False) async def get_kernel( self, diff --git a/plugins/noauth/pyproject.toml b/plugins/noauth/pyproject.toml index d308dc11..56828eab 100644 --- a/plugins/noauth/pyproject.toml +++ b/plugins/noauth/pyproject.toml @@ -27,8 +27,8 @@ text = "BSD 3-Clause License" Homepage = "https://jupyter.org" [project.entry-points] -"asphalt.components" = {noauth = "fps_noauth.main:NoAuthComponent"} -"jupyverse.components" = {noauth = "fps_noauth.main:NoAuthComponent"} +"asphalt.components" = {auth = "fps_noauth.main:NoAuthComponent"} +"jupyverse.components" = {auth = "fps_noauth.main:NoAuthComponent"} [tool.check-manifest] ignore = [ ".*",] diff --git a/plugins/yjs/fps_yjs/routes.py b/plugins/yjs/fps_yjs/routes.py index ef0e30fc..56852885 100644 --- a/plugins/yjs/fps_yjs/routes.py +++ b/plugins/yjs/fps_yjs/routes.py @@ -14,6 +14,7 @@ WebSocketDisconnect, status, ) +from pycrdt import Doc from websockets.exceptions import ConnectionClosedOK from jupyverse_api.app import App @@ -27,6 +28,7 @@ from .ywebsocket.websocket_server import WebsocketServer, YRoom from .ywebsocket.ystore import SQLiteYStore, YDocNotFound from .ywebsocket.yutils import YMessageType, YSyncMessageType +from .ywidgets import Widgets YFILE = YDOCS["file"] AWARENESS = 1 @@ -48,6 +50,7 @@ def __init__( super().__init__(app=app, auth=auth) self.contents = contents self.room_manager = RoomManager(contents) + self.widgets = Widgets() async def collaboration_room_websocket( self, @@ -178,6 +181,7 @@ async def serve(self, websocket: YWebsocket, permissions) -> None: file_path = await self.contents.file_id_manager.get_path(file_id) logger.info(f"Opening collaboration room: {websocket.path} ({file_path})") document = YDOCS.get(file_type, YFILE)(room.ydoc) + document.file_id = file_id self.documents[websocket.path] = document async with self.lock: model = await self.contents.read_content(file_path, True, file_format) @@ -359,17 +363,17 @@ async def maybe_clean_room(self, room, ws_path: str) -> None: class JupyterWebsocketServer(WebsocketServer): - async def get_room(self, ws_path: str) -> YRoom: + async def get_room(self, ws_path: str, ydoc: Doc | None = None) -> YRoom: if ws_path not in self.rooms: if ws_path.count(":") >= 2: # it is a stored document (e.g. a notebook) file_format, file_type, file_id = ws_path.split(":", 2) updates_file_path = f".{file_type}:{file_id}.y" ystore = JupyterSQLiteYStore(path=updates_file_path) # FIXME: pass in config - self.rooms[ws_path] = YRoom(ready=False, ystore=ystore) + self.rooms[ws_path] = YRoom(ydoc=ydoc, ready=False, ystore=ystore) else: # it is a transient document (e.g. awareness) - self.rooms[ws_path] = YRoom() + self.rooms[ws_path] = YRoom(ydoc=ydoc) room = self.rooms[ws_path] await self.start_room(room) return room diff --git a/plugins/yjs/fps_yjs/ydocs/ybasedoc.py b/plugins/yjs/fps_yjs/ydocs/ybasedoc.py index 8b68dad4..3d75d35e 100644 --- a/plugins/yjs/fps_yjs/ydocs/ybasedoc.py +++ b/plugins/yjs/fps_yjs/ydocs/ybasedoc.py @@ -51,6 +51,14 @@ def path(self) -> Optional[str]: def path(self, value: str) -> None: self._ystate["path"] = value + @property + def file_id(self) -> Optional[str]: + return self._ystate.get("file_id") + + @file_id.setter + def file_id(self, value: str) -> None: + self._ystate["file_id"] = value + @abstractmethod def get(self) -> Any: ... diff --git a/plugins/yjs/fps_yjs/ydocs/ynotebook.py b/plugins/yjs/fps_yjs/ydocs/ynotebook.py index 6316ba19..0cadd698 100644 --- a/plugins/yjs/fps_yjs/ydocs/ynotebook.py +++ b/plugins/yjs/fps_yjs/ydocs/ynotebook.py @@ -38,6 +38,7 @@ def cell_number(self) -> int: def get_cell(self, index: int) -> Dict[str, Any]: meta = json.loads(str(self._ymeta)) cell = json.loads(str(self._ycells[index])) + cell.pop("execution_status", None) cast_all(cell, float, int) # cells coming from Yjs have e.g. execution_count as float if "id" in cell and meta["nbformat"] == 4 and meta["nbformat_minor"] <= 4: # strip cell IDs if we have notebook format 4.0-4.4 @@ -73,6 +74,7 @@ def create_ycell(self, value: Dict[str, Any]) -> Map: del cell["attachments"] elif cell_type == "code": cell["outputs"] = Array(cell.get("outputs", [])) + cell["execution_status"] = "idle" return Map(cell) @@ -123,7 +125,7 @@ def set(self, value: Dict) -> None: # clear document self._ymeta.clear() self._ycells.clear() - for key in [k for k in self._ystate.keys() if k not in ("dirty", "path")]: + for key in [k for k in self._ystate.keys() if k not in ("dirty", "path", "file_id")]: del self._ystate[key] # initialize document diff --git a/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py b/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py index 6bff2a1c..40100211 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py +++ b/plugins/yjs/fps_yjs/ywebsocket/websocket_server.py @@ -5,6 +5,7 @@ from anyio import TASK_STATUS_IGNORED, Event, create_task_group from anyio.abc import TaskGroup, TaskStatus +from pycrdt import Doc from .websocket import Websocket from .yroom import YRoom @@ -57,7 +58,7 @@ def started(self) -> Event: self._started = Event() return self._started - async def get_room(self, name: str) -> YRoom: + async def get_room(self, name: str, ydoc: Doc | None = None) -> YRoom: """Get or create a room with the given name, and start it. Arguments: @@ -67,7 +68,7 @@ async def get_room(self, name: str) -> YRoom: The room with the given name, or a new one if no room with that name was found. """ if name not in self.rooms.keys(): - self.rooms[name] = YRoom(ready=self.rooms_ready, log=self.log) + self.rooms[name] = YRoom(ydoc=ydoc, ready=self.rooms_ready, log=self.log) room = self.rooms[name] await self.start_room(room) return room diff --git a/plugins/yjs/fps_yjs/ywebsocket/yroom.py b/plugins/yjs/fps_yjs/ywebsocket/yroom.py index 0d167dcd..15fd41de 100644 --- a/plugins/yjs/fps_yjs/ywebsocket/yroom.py +++ b/plugins/yjs/fps_yjs/ywebsocket/yroom.py @@ -41,7 +41,11 @@ class YRoom: _starting: bool def __init__( - self, ready: bool = True, ystore: BaseYStore | None = None, log: Logger | None = None + self, + ydoc: Doc | None = None, + ready: bool = True, + ystore: BaseYStore | None = None, + log: Logger | None = None, ): """Initialize the object. @@ -63,7 +67,7 @@ def __init__( ystore: An optional store in which to persist document updates. log: An optional logger. """ - self.ydoc = Doc() + self.ydoc = Doc() if ydoc is None else ydoc self.awareness = Awareness(self.ydoc) self._update_send_stream, self._update_receive_stream = create_memory_object_stream( max_buffer_size=65536 diff --git a/plugins/yjs/fps_yjs/ywidgets/__init__.py b/plugins/yjs/fps_yjs/ywidgets/__init__.py new file mode 100644 index 00000000..b7d9174b --- /dev/null +++ b/plugins/yjs/fps_yjs/ywidgets/__init__.py @@ -0,0 +1 @@ +from .widgets import Widgets as Widgets diff --git a/plugins/yjs/fps_yjs/ywidgets/widgets.py b/plugins/yjs/fps_yjs/ywidgets/widgets.py new file mode 100644 index 00000000..c0c5ed93 --- /dev/null +++ b/plugins/yjs/fps_yjs/ywidgets/widgets.py @@ -0,0 +1,52 @@ +import pkg_resources +from pycrdt import TransactionEvent +from ypywidgets.utils import ( # type: ignore + YMessageType, + YSyncMessageType, + create_update_message, + process_sync_message, + sync, +) + + +class Widgets: + def __init__(self): + self.ydocs = { + ep.name: ep.load() for ep in pkg_resources.iter_entry_points(group="ypywidgets") + } + self.widgets = {} + + def comm_open(self, msg, comm) -> None: + target_name = msg["content"]["target_name"] + if target_name != "ywidget": + return + + name = msg["metadata"]["ymodel_name"] + comm_id = msg["content"]["comm_id"] + self.comm = comm + model = self.ydocs[f"{name}Model"]() + self.widgets[comm_id] = {"model": model, "comm": comm} + msg = sync(model.ydoc) + comm.send(**msg) + + def comm_msg(self, msg) -> None: + comm_id = msg["content"]["comm_id"] + message = bytes(msg["buffers"][0]) + if message[0] == YMessageType.SYNC: + ydoc = self.widgets[comm_id]["model"].ydoc + reply = process_sync_message( + message[1:], + ydoc, + ) + if reply: + self.widgets[comm_id]["comm"].send(buffers=[reply]) + if message[1] == YSyncMessageType.SYNC_STEP2: + ydoc.observe(self._send) + + def _send(self, event: TransactionEvent): + update = event.update # type: ignore + message = create_update_message(update) + try: + self.comm.send(buffers=[message]) + except Exception: + pass diff --git a/pyproject.toml b/pyproject.toml index 9942a6f2..c86aacbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ test = [ "requests", "websockets", "ipykernel", + "ypywidgets >=0.6.4,<0.7.0", + "ypywidgets-textual >=0.2.2,<0.3.0", ] docs = [ "mkdocs", "mkdocs-material" ] diff --git a/tests/data/notebook1.ipynb b/tests/data/notebook1.ipynb new file mode 100644 index 00000000..e1c94429 --- /dev/null +++ b/tests/data/notebook1.ipynb @@ -0,0 +1,55 @@ +{ + "cells": [ + { + "execution_count": null, + "outputs": [], + "id": "a7243792-6f06-4462-a6b5-7e9ec604348e", + "source": "from ypywidgets_textual.switch import Switch", + "cell_type": "code", + "metadata": { + "trusted": false + } + }, + { + "id": "a7243792-6f06-4462-a6b5-7e9ec604348f", + "source": "switch = Switch()\nswitch", + "execution_count": null, + "metadata": { + "trusted": false + }, + "outputs": [], + "cell_type": "code" + }, + { + "outputs": [], + "id": "a7243792-6f06-4462-a6b5-7e9ec604349f", + "source": "switch.toggle()", + "cell_type": "code", + "metadata": { + "trusted": false + }, + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "language": "python", + "name": "python3", + "display_name": "Python 3 (ipykernel)" + }, + "language_info": { + "version": "3.7.12", + "codemirror_mode": { + "version": 3, + "name": "ipython" + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "pygments_lexer": "ipython3", + "nbconvert_exporter": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/test_execute.py b/tests/test_execute.py new file mode 100644 index 00000000..d423f1a1 --- /dev/null +++ b/tests/test_execute.py @@ -0,0 +1,150 @@ +import asyncio +import os +from functools import partial +from pathlib import Path + +import pytest +from asphalt.core import Context +from fps_yjs.ydocs import ydocs +from fps_yjs.ywebsocket import WebsocketProvider +from httpx import AsyncClient +from httpx_ws import aconnect_ws +from pycrdt import Doc, Map, Text +from utils import configure + +from jupyverse_api.main import JupyverseComponent + +os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1" + +COMPONENTS = { + "app": {"type": "app"}, + "auth": {"type": "auth", "test": True}, + "contents": {"type": "contents"}, + "frontend": {"type": "frontend"}, + "lab": {"type": "lab"}, + "jupyterlab": {"type": "jupyterlab"}, + "kernels": {"type": "kernels"}, + "yjs": {"type": "yjs"}, +} + + +class Websocket: + def __init__(self, websocket, roomid: str): + self.websocket = websocket + self.roomid = roomid + + @property + def path(self) -> str: + return self.roomid + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + try: + message = await self.recv() + except Exception: + raise StopAsyncIteration() + return message + + async def send(self, message: bytes): + await self.websocket.send_bytes(message) + + async def recv(self) -> bytes: + b = await self.websocket.receive_bytes() + return bytes(b) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("auth_mode", ("noauth",)) +async def test_execute(auth_mode, unused_tcp_port): + url = f"http://127.0.0.1:{unused_tcp_port}" + components = configure(COMPONENTS, { + "auth": {"mode": auth_mode}, + "kernels": {"require_yjs": True}, + }) + async with Context() as ctx, AsyncClient() as http: + await JupyverseComponent( + components=components, + port=unused_tcp_port, + ).start(ctx) + + ws_url = url.replace("http", "ws", 1) + name = "notebook1.ipynb" + path = (Path("tests") / "data" / name).as_posix() + # create a session to launch a kernel + response = await http.post( + f"{url}/api/sessions", + json={ + "kernel": {"name": "python3"}, + "name": name, + "path": path, + "type": "notebook", + }, + ) + r = response.json() + kernel_id = r["kernel"]["id"] + # get the room ID for the document + response = await http.put( + f"{url}/api/collaboration/session/{path}", + json={ + "format": "json", + "type": "notebook", + } + ) + file_id = response.json()["fileId"] + document_id = f"json:notebook:{file_id}" + ynb = ydocs["notebook"]() + def callback(aevent, events, event): + events.append(event) + aevent.set() + aevent = asyncio.Event() + events = [] + ynb.ydoc.observe_subdocs(partial(callback, aevent, events)) + async with aconnect_ws( + f"{ws_url}/api/collaboration/room/{document_id}" + ) as websocket, WebsocketProvider(ynb.ydoc, Websocket(websocket, document_id)): + # connect to the shared notebook document + # wait for file to be loaded and Y model to be created in server and client + await asyncio.sleep(0.5) + # execute notebook + for cell_idx in range(2): + response = await http.post( + f"{url}/api/kernels/{kernel_id}/execute", + json={ + "document_id": document_id, + "cell_id": ynb.ycells[cell_idx]["id"], + } + ) + while True: + await aevent.wait() + aevent.clear() + guid = None + for event in events: + if event.added: + guid = event.added[0] + if guid is not None: + break + task = asyncio.create_task(connect_ywidget(ws_url, guid)) + response = await http.post( + f"{url}/api/kernels/{kernel_id}/execute", + json={ + "document_id": document_id, + "cell_id": ynb.ycells[2]["id"], + } + ) + await task + + +async def connect_ywidget(ws_url, guid): + ywidget_doc = Doc() + async with aconnect_ws( + f"{ws_url}/api/collaboration/room/ywidget:{guid}" + ) as websocket, WebsocketProvider(ywidget_doc, Websocket(websocket, guid)): + await asyncio.sleep(0.5) + attrs = Map() + model_name = Text() + ywidget_doc["_attrs"] = attrs + ywidget_doc["_model_name"] = model_name + assert str(model_name) == "Switch" + assert str(attrs) == '{"value":true}' diff --git a/tests/test_server.py b/tests/test_server.py index 041df492..bc2325d4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,11 +1,13 @@ import asyncio import json +from functools import partial from pathlib import Path import pytest import requests +from fps_yjs.ydocs import ydocs from fps_yjs.ywebsocket import WebsocketProvider -from pycrdt import Array, Doc +from pycrdt import Array, Doc, Map, Text from websockets import connect prev_theme = {} @@ -86,6 +88,7 @@ async def test_rest_api(start_jupyverse): # connect to the shared notebook document # wait for file to be loaded and Y model to be created in server and client await asyncio.sleep(0.5) + ydoc["cells"] = ycells = Array() # execute notebook for cell_idx in range(3): response = requests.post( @@ -93,16 +96,14 @@ async def test_rest_api(start_jupyverse): data=json.dumps( { "document_id": document_id, - "cell_idx": cell_idx, + "cell_id": ycells[cell_idx]["id"], } ), ) # wait for Y model to be updated await asyncio.sleep(0.5) # retrieve cells - array = Array() - ydoc["cells"] = array - cells = json.loads(str(array)) + cells = json.loads(str(ycells)) assert cells[0]["outputs"] == [ { "data": {"text/plain": ["3"]}, @@ -122,3 +123,98 @@ async def test_rest_api(start_jupyverse): "output_type": "execute_result", } ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("auth_mode", ("noauth",)) +@pytest.mark.parametrize("clear_users", (False,)) +async def test_ywidgets(start_jupyverse): + url = start_jupyverse + ws_url = url.replace("http", "ws", 1) + name = "notebook1.ipynb" + path = (Path("tests") / "data" / name).as_posix() + # create a session to launch a kernel + response = requests.post( + f"{url}/api/sessions", + data=json.dumps( + { + "kernel": {"name": "python3"}, + #"kernel": {"name": "akernel"}, + "name": name, + "path": path, + "type": "notebook", + } + ), + ) + r = response.json() + kernel_id = r["kernel"]["id"] + # get the room ID for the document + response = requests.put( + f"{url}/api/collaboration/session/{path}", + data=json.dumps( + { + "format": "json", + "type": "notebook", + } + ), + ) + file_id = response.json()["fileId"] + document_id = f"json:notebook:{file_id}" + ynb = ydocs["notebook"]() + def callback(aevent, events, event): + events.append(event) + aevent.set() + aevent = asyncio.Event() + events = [] + ynb.ydoc.observe_subdocs(partial(callback, aevent, events)) + async with connect( + f"{ws_url}/api/collaboration/room/{document_id}" + ) as websocket, WebsocketProvider(ynb.ydoc, websocket): + # connect to the shared notebook document + # wait for file to be loaded and Y model to be created in server and client + await asyncio.sleep(0.5) + # execute notebook + for cell_idx in range(2): + response = requests.post( + f"{url}/api/kernels/{kernel_id}/execute", + data=json.dumps( + { + "document_id": document_id, + "cell_id": ynb.ycells[cell_idx]["id"], + } + ), + ) + while True: + await aevent.wait() + aevent.clear() + guid = None + for event in events: + if event.added: + guid = event.added[0] + if guid is not None: + break + task = asyncio.create_task(connect_ywidget(ws_url, guid)) + response = requests.post( + f"{url}/api/kernels/{kernel_id}/execute", + data=json.dumps( + { + "document_id": document_id, + "cell_id": ynb.ycells[2]["id"], + } + ), + ) + await task + + +async def connect_ywidget(ws_url, guid): + ywidget_doc = Doc() + async with connect( + f"{ws_url}/api/collaboration/room/ywidget:{guid}" + ) as websocket, WebsocketProvider(ywidget_doc, websocket): + await asyncio.sleep(0.5) + attrs = Map() + model_name = Text() + ywidget_doc["_attrs"] = attrs + ywidget_doc["_model_name"] = model_name + assert str(model_name) == "Switch" + assert str(attrs) == '{"value":true}'