Skip to content

Commit

Permalink
[MNT] show_versions utility (#1688)
Browse files Browse the repository at this point in the history
Adds a `show_versions` utility for users to easily share version
information in bug reports or questions.

Adapted from `sktime`, which in turn is an evolution of the `sklearn`
utility of the same name.
  • Loading branch information
fkiraly authored Sep 30, 2024
1 parent 199c0b4 commit fdf8a7f
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 2 deletions.
6 changes: 4 additions & 2 deletions pytorch_forecasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
PyTorch Forecasting package for timeseries forecasting with PyTorch.
"""

__version__ = "1.1.1"

from pytorch_forecasting.data import (
EncoderNormalizer,
GroupNormalizer,
Expand Down Expand Up @@ -59,6 +61,7 @@
to_list,
unpack_sequence,
)
from pytorch_forecasting.utils._maint._show_versions import show_versions

__all__ = [
"TimeSeriesDataSet",
Expand Down Expand Up @@ -109,7 +112,6 @@
"integer_histogram",
"groupby_apply",
"profile",
"show_versions",
"unpack_sequence",
]

__version__ = "1.1.1"
Empty file.
142 changes: 142 additions & 0 deletions pytorch_forecasting/utils/_maint/_show_versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# License: BSD 3 clause
"""Utility methods to print system info for debugging.
adapted from
:func: `sklearn.show_versions` and `sktime.show_versions`
"""

__all__ = ["show_versions"]

import importlib
import platform
import sys


def _get_sys_info():
"""System information.
Return
------
sys_info : dict
system and Python version information
"""
python = sys.version.replace("\n", " ")

blob = [
("python", python),
("executable", sys.executable),
("machine", platform.platform()),
]

return dict(blob)


# dependencies to print versions of, by default
DEFAULT_DEPS_TO_SHOW = [
"pip",
"pytorch-forecasting",
"torch",
"lightning",
"numpy",
"scipy",
"pandas",
"cpflows",
"matplotlib",
"optuna",
"optuna-integration",
"pytorch_optimizer",
"scikit-learn",
"scikit-base",
"statsmodels",
]


def _get_deps_info(deps=None, source="distributions"):
"""Overview of the installed version of main dependencies.
Parameters
----------
deps : optional, list of strings with package names
if None, behaves as deps = ["sktime"].
source : str, optional one of "distributions" (default) or "import"
source of version information
* "distributions" - uses importlib.distributions. In this case,
strings in deps are assumed to be PEP 440 package strings,
e.g., scikit-learn, not sklearn.
* "import" - uses the __version__ attribute of the module.
In this case, strings in deps are assumed to be import names,
e.g., sklearn, not scikit-learn.
Returns
-------
deps_info: dict
version information on libraries in `deps`
keys are package names, import names if source is "import",
and PEP 440 package strings if source is "distributions";
values are PEP 440 version strings
of the import as present in the current python environment
"""
if deps is None:
deps = ["pytorch-forecasting"]

if source == "distributions":
from pytorch_forecasting.utils._dependencies import _get_installed_packages

KEY_ALIAS = {"sklearn": "scikit-learn", "skbase": "scikit-base"}

pkgs = _get_installed_packages()

deps_info = {}
for modname in deps:
pkg_name = KEY_ALIAS.get(modname, modname)
deps_info[modname] = pkgs.get(pkg_name, None)

return deps_info

def get_version(module):
return getattr(module, "__version__", None)

deps_info = {}

for modname in deps:
try:
if modname in sys.modules:
mod = sys.modules[modname]
else:
mod = importlib.import_module(modname)
except ImportError:
deps_info[modname] = None
else:
ver = get_version(mod)
deps_info[modname] = ver

return deps_info


def show_versions():
"""Print python version, OS version, sktime version, selected dependency versions.
Pretty prints:
* python version of environment
* python executable location
* OS version
* list of import name and version number for selected python dependencies
Developer note:
Python version/executable and OS version are from `_get_sys_info`
Package versions are retrieved by `_get_deps_info`
Selected dependencies are as in the DEFAULT_DEPS_TO_SHOW variable
"""
sys_info = _get_sys_info()
deps_info = _get_deps_info(deps=DEFAULT_DEPS_TO_SHOW)

print("\nSystem:") # noqa: T001, T201
for k, stat in sys_info.items():
print(f"{k:>10}: {stat}") # noqa: T001, T201

print("\nPython dependencies:") # noqa: T001, T201
for k, stat in deps_info.items():
print(f"{k:>13}: {stat}") # noqa: T001, T201
42 changes: 42 additions & 0 deletions tests/test_utils/test_show_versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Tests for the show_versions utility."""

import pathlib
import uuid

from pytorch_forecasting.utils._maint._show_versions import DEFAULT_DEPS_TO_SHOW, _get_deps_info, show_versions


def test_show_versions_runs():
"""Test that show_versions runs without exceptions."""
# only prints, should return None
assert show_versions() is None


def test_show_versions_import_loc():
"""Test that show_version can be imported from root."""
from pytorch_forecasting import show_versions as show_versions_imported

assert show_versions == show_versions_imported


def test_deps_info():
"""Test that _get_deps_info returns package/version dict as per contract."""
deps_info = _get_deps_info()
assert isinstance(deps_info, dict)
assert set(deps_info.keys()) == {"pytorch-forecasting"}

deps_info_default = _get_deps_info(DEFAULT_DEPS_TO_SHOW)
assert isinstance(deps_info_default, dict)
assert set(deps_info_default.keys()) == set(DEFAULT_DEPS_TO_SHOW)


def test_deps_info_deps_missing_package_present_directory():
"""Test that _get_deps_info does not fail if a dependency is missing."""
dummy_package_name = uuid.uuid4().hex

dummy_folder_path = pathlib.Path(dummy_package_name)
dummy_folder_path.mkdir()

assert _get_deps_info([dummy_package_name]) == {dummy_package_name: None}

dummy_folder_path.rmdir()

0 comments on commit fdf8a7f

Please sign in to comment.