Skip to content

Commit

Permalink
Update event loop handling
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 committed Jan 8, 2024
1 parent 63bc61c commit f725764
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 86 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ classifiers = [
]
dependencies = [
"pytest",
"jupyter_core"
"jupyter_core>=5.7"
]
requires-python = ">=3.8"

Expand Down
1 change: 0 additions & 1 deletion pytest_jupyter/jupyter_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

# Bring in local plugins.
from pytest_jupyter.jupyter_core import * # noqa: F403
from pytest_jupyter.pytest_tornasync import * # noqa: F403


@pytest.fixture()
Expand Down
53 changes: 26 additions & 27 deletions pytest_jupyter/jupyter_core.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Fixtures for use with jupyter core and downstream."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import json
import os
import sys
import typing
from inspect import iscoroutinefunction
from pathlib import Path

import jupyter_core
import pytest
from jupyter_core.utils import ensure_event_loop

from .utils import mkdir

Expand All @@ -35,34 +35,33 @@
resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard))


@pytest.hookimpl(tryfirst=True)
def pytest_pycollect_makeitem(collector, name, obj):
"""Custom pytest collection hook."""
if collector.funcnamefilter(name) and iscoroutinefunction(obj):
return list(collector._genfunctions(name, obj))
return None


@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem):
"""Custom pytest function call hook."""
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}

if not iscoroutinefunction(pyfuncitem.obj):
pyfuncitem.obj(**testargs)
return True

loop = ensure_event_loop()
loop.run_until_complete(pyfuncitem.obj(**testargs))
return True


@pytest.fixture()
def jp_asyncio_loop():
"""Get an asyncio loop."""
if os.name == "nt":
asyncio.set_event_loop_policy(
asyncio.WindowsSelectorEventLoopPolicy() # type:ignore[attr-defined]
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
loop.close()


@pytest.fixture(autouse=True)
def io_loop(jp_asyncio_loop):
"""Override the io_loop for pytest_tornasync. This is a no-op
if tornado is not installed."""

async def get_tornado_loop() -> typing.Any:
"""Asynchronously get a tornado loop."""
try:
from tornado.ioloop import IOLoop

return IOLoop.current()
except ImportError:
pass

return jp_asyncio_loop.run_until_complete(get_tornado_loop())
return ensure_event_loop()


@pytest.fixture()
Expand Down
36 changes: 2 additions & 34 deletions pytest_jupyter/jupyter_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations

import asyncio
import importlib
import io
import logging
Expand Down Expand Up @@ -53,36 +52,6 @@
from pytest_jupyter.pytest_tornasync import * # noqa: F403
from pytest_jupyter.utils import mkdir

# Override some of the fixtures from pytest_tornasync
# The io_loop fixture is overridden in jupyter_core.py so it
# can be shared by other plugins that need it (e.g. jupyter_client.py).


@pytest.fixture()
def http_server(io_loop, http_server_port, jp_web_app):
"""Start a tornado HTTP server that listens on all available interfaces."""

async def get_server():
"""Get a server asynchronously."""
server = tornado.httpserver.HTTPServer(jp_web_app)
server.add_socket(http_server_port[0])
return server

server = io_loop.run_sync(get_server)
yield server
server.stop()

if hasattr(server, "close_all_connections"):
try:
io_loop.run_sync(server.close_all_connections)
except asyncio.TimeoutError:
pass

http_server_port[0].close()


# End pytest_tornasync overrides


@pytest.fixture()
def jp_server_config():
Expand Down Expand Up @@ -177,7 +146,6 @@ def jp_configurable_serverapp(
jp_root_dir,
jp_logging_stream,
jp_asyncio_loop,
io_loop,
):
"""Starts a Jupyter Server instance based on
the provided configuration values.
Expand Down Expand Up @@ -207,7 +175,6 @@ def _configurable_serverapp(
environ=jp_environ,
http_port=jp_http_port,
tmp_path=tmp_path,
io_loop=io_loop,
root_dir=jp_root_dir,
**kwargs,
):
Expand Down Expand Up @@ -345,7 +312,7 @@ async def my_test(jp_fetch, jp_ws_fetch):
...
"""

def client_fetch(*parts, headers=None, params=None, **kwargs): # noqa: ARG
def client_fetch(*parts, headers=None, params=None, **kwargs):
if not headers:
headers = {}
if not params:
Expand Down Expand Up @@ -414,6 +381,7 @@ async def _(url, **fetch_kwargs):
code = r.code
except HTTPClientError as err:
code = err.code
print(f"HTTPClientError ({err.code}): {err}") # noqa: T201
else:
if fetch is jp_ws_fetch:
r.close()
Expand Down
49 changes: 26 additions & 23 deletions pytest_jupyter/pytest_tornasync.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Vendored fork of pytest_tornasync from
https://github.com/eukaryote/pytest-tornasync/blob/9f1bdeec3eb5816e0183f975ca65b5f6f29fbfbb/src/pytest_tornasync/plugin.py
"""
import asyncio
from contextlib import closing
from inspect import iscoroutinefunction

try:
import tornado.ioloop
Expand All @@ -14,33 +14,36 @@
import pytest

# mypy: disable-error-code="no-untyped-call"
# Bring in local plugins.
from pytest_jupyter.jupyter_core import * # noqa: F403


@pytest.hookimpl(tryfirst=True)
def pytest_pycollect_makeitem(collector, name, obj):
"""Custom pytest collection hook."""
if collector.funcnamefilter(name) and iscoroutinefunction(obj):
return list(collector._genfunctions(name, obj))
return None
@pytest.fixture()
def io_loop(jp_asyncio_loop):
return tornado.ioloop.IOLoop.current()


@pytest.fixture()
def http_server(jp_asyncio_loop, http_server_port, jp_web_app):
"""Start a tornado HTTP server that listens on all available interfaces."""

@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem):
"""Custom pytest function call hook."""
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
async def get_server():
"""Get a server asynchronously."""
server = tornado.httpserver.HTTPServer(jp_web_app)
server.add_socket(http_server_port[0])
return server

if not iscoroutinefunction(pyfuncitem.obj):
pyfuncitem.obj(**testargs)
return True
server = jp_asyncio_loop.run_until_complete(get_server())
yield server
server.stop()

try:
loop = funcargs["io_loop"]
except KeyError:
loop = tornado.ioloop.IOLoop.current()
if hasattr(server, "close_all_connections"):
try:
jp_asyncio_loop.run_until_complete(server.close_all_connections())
except asyncio.TimeoutError:
pass

loop.run_sync(lambda: pyfuncitem.obj(**testargs))
return True
http_server_port[0].close()


@pytest.fixture()
Expand All @@ -52,7 +55,7 @@ def http_server_port():


@pytest.fixture()
def http_server_client(http_server, io_loop):
def http_server_client(http_server, jp_asyncio_loop):
"""
Create an asynchronous HTTP client that can fetch from `http_server`.
"""
Expand All @@ -61,7 +64,7 @@ async def get_client():
"""Get a client."""
return AsyncHTTPServerClient(http_server=http_server)

client = io_loop.run_sync(get_client)
client = jp_asyncio_loop.run_until_complete(get_client())
with closing(client) as context:
yield context

Expand Down
4 changes: 4 additions & 0 deletions tests/test_jupyter_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,7 @@ def test_template_dir(jp_template_dir):

def test_extension_environ(jp_extension_environ):
pass


def test_ioloop_fixture(io_loop):
pass

0 comments on commit f725764

Please sign in to comment.