Skip to content

Commit

Permalink
Implement server-side ypywidgets rendering (#364)
Browse files Browse the repository at this point in the history
* Implement server-side ypywidgets rendering

* Fix types

* Use ypywidgets-textual in tests

* Update with ypywidgets v0.6.1 and ypywidgets-textual 0.2.1

* Use cell ID instead of cell index in execute API

* Add JupyterLab server_side_execution flag

* Set shared document file_id
  • Loading branch information
davidbrochart authored Dec 15, 2023
1 parent c1f68a3 commit 73a16c3
Show file tree
Hide file tree
Showing 19 changed files with 541 additions and 65 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,6 @@ $RECYCLE.BIN/
.jupyter_ystore.db
.jupyter_ystore.db-journal
fps_cli_args.toml

# pixi environments
.pixi
1 change: 1 addition & 0 deletions jupyverse_api/jupyverse_api/jupyterlab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,4 @@ async def get_workspace(

class JupyterLabConfig(Config):
dev_mode: bool = False
server_side_execution: bool = False
2 changes: 1 addition & 1 deletion jupyverse_api/jupyverse_api/kernels/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ class Session(BaseModel):

class Execution(BaseModel):
document_id: str
cell_idx: int
cell_id: str
6 changes: 5 additions & 1 deletion plugins/jupyterlab/fps_jupyterlab/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ async def get_lab(
self.get_index(
"default",
self.frontend_config.collaborative,
self.jupyterlab_config.server_side_execution,
self.jupyterlab_config.dev_mode,
self.frontend_config.base_url,
)
Expand All @@ -71,6 +72,7 @@ async def load_workspace(
self.get_index(
"default",
self.frontend_config.collaborative,
self.jupyterlab_config.server_side_execution,
self.jupyterlab_config.dev_mode,
self.frontend_config.base_url,
)
Expand Down Expand Up @@ -99,11 +101,12 @@ async def get_workspace(
return self.get_index(
name,
self.frontend_config.collaborative,
self.jupyterlab_config.server_side_execution,
self.jupyterlab_config.dev_mode,
self.frontend_config.base_url,
)

def get_index(self, workspace, collaborative, dev_mode, base_url="/"):
def get_index(self, workspace, collaborative, server_side_execution, dev_mode, base_url="/"):
for path in (self.static_lab_dir).glob("main.*.js"):
main_id = path.name.split(".")[1]
break
Expand All @@ -121,6 +124,7 @@ def get_index(self, workspace, collaborative, dev_mode, base_url="/"):
"baseUrl": base_url,
"cacheFiles": False,
"collaborative": collaborative,
"serverSideExecution": server_side_execution,
"devMode": dev_mode,
"disabledExtensions": self.disabled_extension,
"exposeAppInBrowser": False,
Expand Down
174 changes: 131 additions & 43 deletions plugins/kernels/fps_kernels/kernel_driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import uuid
from typing import Any, Dict, List, Optional, cast

from pycrdt import Array, Map

from jupyverse_api.yjs import Yjs

from .connect import cfg_t, connect_channel, launch_kernel, read_connection_file
from .connect import write_connection_file as _write_connection_file
from .kernelspec import find_kernelspec
Expand All @@ -23,10 +27,12 @@ def __init__(
connection_file: str = "",
write_connection_file: bool = True,
capture_kernel_output: bool = True,
yjs: Optional[Yjs] = None,
) -> None:
self.capture_kernel_output = capture_kernel_output
self.kernelspec_path = kernelspec_path or find_kernelspec(kernel_name)
self.kernel_cwd = kernel_cwd
self.yjs = yjs
if not self.kernelspec_path:
raise RuntimeError("Could not find a kernel, maybe you forgot to install one?")
if write_connection_file:
Expand All @@ -37,11 +43,12 @@ def __init__(
self.key = cast(str, self.connection_cfg["key"])
self.session_id = uuid.uuid4().hex
self.msg_cnt = 0
self.execute_requests: Dict[str, Dict[str, asyncio.Future]] = {}
self.channel_tasks: List[asyncio.Task] = []
self.execute_requests: Dict[str, Dict[str, asyncio.Queue]] = {}
self.comm_messages: asyncio.Queue = asyncio.Queue()
self.tasks: List[asyncio.Task] = []

async def restart(self, startup_timeout: float = float("inf")) -> None:
for task in self.channel_tasks:
for task in self.tasks:
task.cancel()
msg = create_message("shutdown_request", content={"restart": True})
await send_message(msg, self.control_channel, self.key, change_date_to_str=True)
Expand All @@ -52,7 +59,7 @@ async def restart(self, startup_timeout: float = float("inf")) -> None:
if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]:
break
await self._wait_for_ready(startup_timeout)
self.channel_tasks = []
self.tasks = []
self.listen_channels()

async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None:
Expand All @@ -69,6 +76,7 @@ async def connect(self, startup_timeout: float = float("inf")) -> None:
self.connect_channels()
await self._wait_for_ready(startup_timeout)
self.listen_channels()
self.tasks.append(asyncio.create_task(self._handle_comms()))

def connect_channels(self, connection_cfg: Optional[cfg_t] = None):
connection_cfg = connection_cfg or self.connection_cfg
Expand All @@ -77,40 +85,43 @@ def connect_channels(self, connection_cfg: Optional[cfg_t] = None):
self.iopub_channel = connect_channel("iopub", connection_cfg)

def listen_channels(self):
self.channel_tasks.append(asyncio.create_task(self.listen_iopub()))
self.channel_tasks.append(asyncio.create_task(self.listen_shell()))
self.tasks.append(asyncio.create_task(self.listen_iopub()))
self.tasks.append(asyncio.create_task(self.listen_shell()))

async def stop(self) -> None:
self.kernel_process.kill()
await self.kernel_process.wait()
os.remove(self.connection_file_path)
for task in self.channel_tasks:
for task in self.tasks:
task.cancel()

async def listen_iopub(self):
while True:
msg = await receive_message(self.iopub_channel, change_str_to_date=True)
msg_id = msg["parent_header"].get("msg_id")
if msg_id in self.execute_requests.keys():
self.execute_requests[msg_id]["iopub_msg"].set_result(msg)
parent_id = msg["parent_header"].get("msg_id")
if msg["msg_type"] in ("comm_open", "comm_msg"):
self.comm_messages.put_nowait(msg)
elif parent_id in self.execute_requests.keys():
self.execute_requests[parent_id]["iopub_msg"].put_nowait(msg)

async def listen_shell(self):
while True:
msg = await receive_message(self.shell_channel, change_str_to_date=True)
msg_id = msg["parent_header"].get("msg_id")
if msg_id in self.execute_requests.keys():
self.execute_requests[msg_id]["shell_msg"].set_result(msg)
self.execute_requests[msg_id]["shell_msg"].put_nowait(msg)

async def execute(
self,
cell: Dict[str, Any],
ycell: Map,
timeout: float = float("inf"),
msg_id: str = "",
wait_for_executed: bool = True,
) -> None:
if cell["cell_type"] != "code":
if ycell["cell_type"] != "code":
return
content = {"code": cell["source"], "silent": False}
ycell["execution_state"] = "busy"
content = {"code": str(ycell["source"]), "silent": False}
msg = create_message(
"execute_request", content, session_id=self.session_id, msg_id=str(self.msg_cnt)
)
Expand All @@ -120,40 +131,68 @@ async def execute(
msg_id = msg["header"]["msg_id"]
self.msg_cnt += 1
await send_message(msg, self.shell_channel, self.key, change_date_to_str=True)
self.execute_requests[msg_id] = {
"iopub_msg": asyncio.Queue(),
"shell_msg": asyncio.Queue(),
}
if wait_for_executed:
deadline = time.time() + timeout
self.execute_requests[msg_id] = {
"iopub_msg": asyncio.Future(),
"shell_msg": asyncio.Future(),
}
while True:
try:
await asyncio.wait_for(
self.execute_requests[msg_id]["iopub_msg"],
msg = await asyncio.wait_for(
self.execute_requests[msg_id]["iopub_msg"].get(),
deadline_to_timeout(deadline),
)
except asyncio.TimeoutError:
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
msg = self.execute_requests[msg_id]["iopub_msg"].result()
self._handle_outputs(cell["outputs"], msg)
await self._handle_outputs(ycell["outputs"], msg)
if (
msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle"
(msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle")
):
break
self.execute_requests[msg_id]["iopub_msg"] = asyncio.Future()
try:
await asyncio.wait_for(
self.execute_requests[msg_id]["shell_msg"],
msg = await asyncio.wait_for(
self.execute_requests[msg_id]["shell_msg"].get(),
deadline_to_timeout(deadline),
)
except asyncio.TimeoutError:
error_message = f"Kernel didn't respond in {timeout} seconds"
raise RuntimeError(error_message)
msg = self.execute_requests[msg_id]["shell_msg"].result()
cell["execution_count"] = msg["content"]["execution_count"]
with ycell.doc.transaction():
ycell["execution_count"] = msg["content"]["execution_count"]
ycell["execution_state"] = "idle"
del self.execute_requests[msg_id]
else:
self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell)))

async def _handle_iopub(self, msg_id: str, ycell: Map) -> None:
while True:
msg = await self.execute_requests[msg_id]["iopub_msg"].get()
await self._handle_outputs(ycell["outputs"], msg)
if (
(msg["header"]["msg_type"] == "status"
and msg["content"]["execution_state"] == "idle")
):
msg = await self.execute_requests[msg_id]["shell_msg"].get()
with ycell.doc.transaction():
ycell["execution_count"] = msg["content"]["execution_count"]
ycell["execution_state"] = "idle"

async def _handle_comms(self) -> None:
if self.yjs is None:
return

while True:
msg = await self.comm_messages.get()
msg_type = msg["header"]["msg_type"]
if msg_type == "comm_open":
comm_id = msg["content"]["comm_id"]
comm = Comm(comm_id, self.shell_channel, self.session_id, self.key)
self.yjs.widgets.comm_open(msg, comm) # type: ignore
elif msg_type == "comm_msg":
self.yjs.widgets.comm_msg(msg) # type: ignore

async def _wait_for_ready(self, timeout):
deadline = time.time() + timeout
Expand All @@ -178,22 +217,51 @@ async def _wait_for_ready(self, timeout):
break
new_timeout = deadline_to_timeout(deadline)

def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
async def _handle_outputs(self, outputs: Array, msg: Dict[str, Any]):
msg_type = msg["header"]["msg_type"]
content = msg["content"]
if msg_type == "stream":
if (not outputs) or (outputs[-1]["name"] != content["name"]):
outputs.append({"name": content["name"], "output_type": msg_type, "text": []})
outputs[-1]["text"].append(content["text"])
with outputs.doc.transaction():
# TODO: uncomment when changes are made in jupyter-ydoc
if (not outputs) or (outputs[-1]["name"] != content["name"]): # type: ignore
outputs.append(
#Map(
# {
# "name": content["name"],
# "output_type": msg_type,
# "text": Array([content["text"]]),
# }
#)
{
"name": content["name"],
"output_type": msg_type,
"text": [content["text"]],
}
)
else:
#outputs[-1]["text"].append(content["text"]) # type: ignore
last_output = outputs[-1]
last_output["text"].append(content["text"]) # type: ignore
outputs[-1] = last_output
elif msg_type in ("display_data", "execute_result"):
outputs.append(
{
"data": {"text/plain": [content["data"].get("text/plain", "")]},
"execution_count": content["execution_count"],
"metadata": {},
"output_type": msg_type,
}
)
if "application/vnd.jupyter.ywidget-view+json" in content["data"]:
# this is a collaborative widget
model_id = content["data"]["application/vnd.jupyter.ywidget-view+json"]["model_id"]
if self.yjs is not None:
if model_id in self.yjs.widgets.widgets: # type: ignore
doc = self.yjs.widgets.widgets[model_id]["model"].ydoc # type: ignore
path = f"ywidget:{doc.guid}"
await self.yjs.room_manager.websocket_server.get_room(path, ydoc=doc) # type: ignore
outputs.append(doc)
else:
outputs.append(
{
"data": {"text/plain": [content["data"].get("text/plain", "")]},
"execution_count": content["execution_count"],
"metadata": {},
"output_type": msg_type,
}
)
elif msg_type == "error":
outputs.append(
{
Expand All @@ -203,5 +271,25 @@ def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
"traceback": content["traceback"],
}
)
else:
return


class Comm:
def __init__(self, comm_id: str, shell_channel, session_id: str, key: str):
self.comm_id = comm_id
self.shell_channel = shell_channel
self.session_id = session_id
self.key = key
self.msg_cnt = 0

def send(self, buffers):
msg = create_message(
"comm_msg",
content={"comm_id": self.comm_id},
session_id=self.session_id,
msg_id=self.msg_cnt,
buffers=buffers,
)
self.msg_cnt += 1
asyncio.create_task(
send_message(msg, self.shell_channel, self.key, change_date_to_str=True)
)
3 changes: 2 additions & 1 deletion plugins/kernels/fps_kernels/kernel_driver/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def create_message(
content: Dict = {},
session_id: str = "",
msg_id: str = "",
buffers: List = [],
) -> Dict[str, Any]:
header = create_message_header(msg_type, session_id, msg_id)
msg = {
Expand All @@ -65,7 +66,7 @@ def create_message(
"parent_header": {},
"content": content,
"metadata": {},
"buffers": [],
"buffers": buffers,
}
return msg

Expand Down
12 changes: 8 additions & 4 deletions plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,21 +259,25 @@ async def execute_cell(
execution = Execution(**r)
if kernel_id in kernels:
ynotebook = self.yjs.get_document(execution.document_id)
cell = ynotebook.get_cell(execution.cell_idx)
cell["outputs"] = []
ycells = [ycell for ycell in ynotebook.ycells if ycell["id"] == execution.cell_id]
if not ycells:
return # FIXME

ycell = ycells[0]
del ycell["outputs"][:]

kernel = kernels[kernel_id]
if not kernel["driver"]:
kernel["driver"] = driver = KernelDriver(
kernelspec_path=Path(find_kernelspec(kernel["name"])).as_posix(),
write_connection_file=False,
connection_file=kernel["server"].connection_file_path,
yjs=self.yjs,
)
await driver.connect()
driver = kernel["driver"]

await driver.execute(cell)
ynotebook.set_cell(execution.cell_idx, cell)
await driver.execute(ycell, wait_for_executed=False)

async def get_kernel(
self,
Expand Down
Loading

0 comments on commit 73a16c3

Please sign in to comment.