Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add notebook_vars utility with example notebook and unit test depicting how it can be ised to construct a pytest fixture with the notebook state after execution #39

Merged
merged 11 commits into from
Jan 20, 2025
1 change: 1 addition & 0 deletions open_atmos_jupyter_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from open_atmos_jupyter_utils.pip_install_on_colab import pip_install_on_colab
from open_atmos_jupyter_utils.temporary_file import TemporaryFile
from open_atmos_jupyter_utils.show_plot import show_plot, save_and_make_link
from open_atmos_jupyter_utils.notebook_vars import notebook_vars
31 changes: 31 additions & 0 deletions open_atmos_jupyter_utils/notebook_vars.py
slayoo marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
""" helper routines for use in smoke tests """

from pathlib import Path

import nbformat


def notebook_vars(file: Path, plot: bool):
"""Executes the code from all cells of the Jupyter notebook `file` and
returns a dictionary with the notebook variables. If the `plot` argument
is set to `True`, any code line within the notebook starting with `show_plot(`
(see [open_atmos_jupyter_utils docs](https://pypi.org/p/open_atmos_jupyter_utils))
is replaced with `pyplot.show() #`, otherwise it is replaced with `pyplot.gca().clear() #`
to match the smoke-test conventions."""
notebook = nbformat.read(file, nbformat.NO_CONVERT)
context = {}
for cell in notebook.cells:
if cell.cell_type != "markdown":
lines = cell.source.splitlines()
for i, line in enumerate(lines):
if line.strip().startswith("!"):
lines[i] = line.replace("!", "pass #")
if line.strip().startswith("show_plot("):
lines[i] = line.replace(
"show_plot(",
"from matplotlib import pyplot; "
+ ("pyplot.show() #" if plot else "pyplot.gca().clear() #"),
)

exec("\n".join(lines), context) # pylint: disable=exec-used
return context
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"IPython",
"matplotlib",
"imageio",
"nbformat"
]
[project.urls]
"Homepage" = "https://github.com/open-atmos/jupyter-utils"
Expand Down
Loading