[MNT] isolate matplotlib as soft dependency #1636

Merged: 18 commits, Sep 4, 2024
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -25,7 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[dev,github-actions,mqf2]"
python -m pip install ".[dev,all_extras,github-actions]"

- name: Show dependencies
run: python -m pip list
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -61,7 +61,6 @@ dependencies = [
"scipy >=1.8,<2.0",
"pandas >=1.3.0,<3.0.0",
"scikit-learn >=1.2,<2.0",
"matplotlib",
"pytorch-optimizer >=2.5.1,<4.0.0",
]

@@ -84,6 +83,7 @@ dependencies = [
#
all_extras = [
"cpflows",
"matplotlib",
"optuna >=3.1.0,<4.0.0",
"optuna-integration",
"statsmodels",
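Note: with matplotlib dropped from the core dependency list and added to the all_extras group, plotting becomes opt-in; users who want figures install matplotlib directly (`pip install matplotlib`) or via `pip install pytorch-forecasting[all_extras]`. A minimal sketch of the soft-dependency idiom the rest of this PR applies, with an illustrative function name that is not part of the library:

    def plot_series(values):
        # Defer the matplotlib import until a plotting feature is actually used,
        # so that importing the package itself never requires matplotlib.
        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError(
                "plot_series requires matplotlib. "
                "Please install matplotlib with `pip install matplotlib`."
            ) from err

        fig, ax = plt.subplots()
        ax.plot(values)
        return fig
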
10 changes: 6 additions & 4 deletions pytorch_forecasting/data/timeseries.py
@@ -11,7 +11,6 @@
from typing import Any, Callable, Dict, List, Tuple, Union
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.exceptions import NotFittedError
@@ -32,6 +31,7 @@
)
from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler
from pytorch_forecasting.utils import repr_class
from pytorch_forecasting.utils._dependencies import _check_matplotlib


def _find_end_indices(diffs: np.ndarray, max_lengths: np.ndarray, min_length: int) -> Tuple[np.ndarray, np.ndarray]:
@@ -1357,9 +1357,7 @@ def decoded_index(self) -> pd.DataFrame:
)
return index

def plot_randomization(
self, betas: Tuple[float, float] = None, length: int = None, min_length: int = None
) -> Tuple[plt.Figure, torch.Tensor]:
def plot_randomization(self, betas: Tuple[float, float] = None, length: int = None, min_length: int = None):
"""
Plot expected randomized length distribution.

@@ -1372,6 +1370,10 @@ def plot_randomization(
Returns:
Tuple[plt.Figure, torch.Tensor]: tuple of figure and histogram based on 1000 samples
"""
_check_matplotlib("plot_randomization")

import matplotlib.pyplot as plt

if betas is None:
betas = self.randomize_length
if length is None:
30 changes: 26 additions & 4 deletions pytorch_forecasting/models/base_model.py
@@ -15,7 +15,6 @@
from lightning.pytorch.callbacks import BasePredictionWriter, LearningRateFinder
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.parsing import get_init_args
import matplotlib.pyplot as plt
import numpy as np
from numpy import iterable
import pandas as pd
@@ -55,6 +54,7 @@
groupby_apply,
to_list,
)
from pytorch_forecasting.utils._dependencies import _check_matplotlib

# todo: compile models

@@ -940,6 +940,12 @@ def log_prediction(
)
else:
log_indices = [0]

mpl_available = _check_matplotlib("plot_prediction", raise_error=False)

if not mpl_available:
return None # don't log matplotlib plots if not available

for idx in log_indices:
fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
tag = f"{self.current_stage} prediction"
@@ -971,7 +977,7 @@ def plot_prediction(
ax=None,
quantiles_kwargs: Dict[str, Any] = {},
prediction_kwargs: Dict[str, Any] = {},
) -> plt.Figure:
):
"""
Plot prediction of prediction vs actuals

@@ -990,6 +996,10 @@
Returns:
matplotlib figure
"""
_check_matplotlib("plot_prediction")

from matplotlib import pyplot as plt

# all true values for y of the first sample in batch
encoder_targets = to_list(x["encoder_target"])
decoder_targets = to_list(x["decoder_target"])
@@ -1103,6 +1113,14 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None:
layers.append(name)
ave_grads.append(p.grad.abs().cpu().mean())
self.logger.experiment.add_histogram(tag=name, values=p.grad, global_step=self.global_step)

mpl_available = _check_matplotlib("log_gradient_flow", raise_error=False)

if not mpl_available:
return None

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(ave_grads)
ax.set_xlabel("Layers")
@@ -1842,7 +1860,7 @@ def calculate_prediction_actual_by_variable(

def plot_prediction_actual_by_variable(
self, data: Dict[str, Dict[str, torch.Tensor]], name: str = None, ax=None, log_scale: bool = None
) -> Union[Dict[str, plt.Figure], plt.Figure]:
):
"""
Plot predictions and actual averages by variables

@@ -1860,6 +1878,10 @@ def plot_prediction_actual_by_variable(
Returns:
Union[Dict[str, plt.Figure], plt.Figure]: matplotlib figure
"""
_check_matplotlib("plot_prediction_actual_by_variable")

from matplotlib import pyplot as plt

if name is None: # run recursion for figures
figs = {name: self.plot_prediction_actual_by_variable(data, name) for name in data["support"].keys()}
return figs
@@ -2230,7 +2252,7 @@ def plot_prediction(
ax=None,
quantiles_kwargs: Dict[str, Any] = {},
prediction_kwargs: Dict[str, Any] = {},
) -> plt.Figure:
):
"""
Plot prediction of prediction vs actuals

13 changes: 11 additions & 2 deletions pytorch_forecasting/models/nbeats/__init__.py
@@ -4,7 +4,6 @@

from typing import Dict, List

import matplotlib.pyplot as plt
import torch
from torch import nn

@@ -13,6 +12,7 @@
from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric
from pytorch_forecasting.models.base_model import BaseModel
from pytorch_forecasting.models.nbeats.sub_modules import NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock
from pytorch_forecasting.utils._dependencies import _check_matplotlib


class NBeats(BaseModel):
@@ -263,6 +263,11 @@ def log_interpretation(self, x, out, batch_idx):
"""
Log interpretation of network predictions in tensorboard.
"""
mpl_available = _check_matplotlib("log_interpretation", raise_error=False)

if not mpl_available:
return None

label = ["val", "train"][self.training]
if self.log_interval > 0 and batch_idx % self.log_interval == 0:
fig = self.plot_interpretation(x, out, idx=0)
@@ -280,7 +285,7 @@ def plot_interpretation(
idx: int,
ax=None,
plot_seasonality_and_generic_on_secondary_axis: bool = False,
) -> plt.Figure:
):
"""
Plot interpretation.

@@ -299,6 +304,10 @@
Returns:
plt.Figure: matplotlib figure
"""
_check_matplotlib("plot_interpretation")

import matplotlib.pyplot as plt

if ax is None:
fig, ax = plt.subplots(2, 1, figsize=(6, 8))
else:
13 changes: 11 additions & 2 deletions pytorch_forecasting/models/nhits/__init__.py
@@ -5,7 +5,6 @@
from copy import copy
from typing import Dict, List, Optional, Tuple, Union

from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn
@@ -17,6 +16,7 @@
from pytorch_forecasting.models.nhits.sub_modules import NHiTS as NHiTSModule
from pytorch_forecasting.models.nn.embeddings import MultiEmbedding
from pytorch_forecasting.utils import create_mask, to_list
from pytorch_forecasting.utils._dependencies import _check_matplotlib


class NHiTS(BaseModelWithCovariates):
@@ -419,7 +419,7 @@ def plot_interpretation(
output: Dict[str, torch.Tensor],
idx: int,
ax=None,
) -> plt.Figure:
):
"""
Plot interpretation.

@@ -436,6 +436,10 @@
Returns:
plt.Figure: matplotlib figure
"""
_check_matplotlib("plot_interpretation")

from matplotlib import pyplot as plt

if not isinstance(self.loss, MultiLoss): # not multi-target
prediction = self.to_prediction(dict(prediction=output["prediction"][[idx]].detach()))[0].cpu()
block_forecasts = [
@@ -535,6 +539,11 @@ def log_interpretation(self, x, out, batch_idx):
"""
Log interpretation of network predictions in tensorboard.
"""
mpl_available = _check_matplotlib("log_interpretation", raise_error=False)

if not mpl_available:
return None

label = ["val", "train"][self.training]
if self.log_interval > 0 and batch_idx % self.log_interval == 0:
fig = self.plot_interpretation(x, out, idx=0)
18 changes: 14 additions & 4 deletions pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -5,7 +5,6 @@
from copy import copy
from typing import Dict, List, Tuple, Union

from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn
@@ -24,6 +23,7 @@
VariableSelectionNetwork,
)
from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list
from pytorch_forecasting.utils._dependencies import _check_matplotlib


class TemporalFusionTransformer(BaseModelWithCovariates):
@@ -690,7 +690,7 @@ def plot_prediction(
show_future_observed: bool = True,
ax=None,
**kwargs,
) -> plt.Figure:
):
"""
Plot actuals vs prediction and attention

@@ -706,7 +706,6 @@
Returns:
plt.Figure: matplotlib figure
"""

# plot prediction as normal
fig = super().plot_prediction(
x,
@@ -735,7 +734,7 @@ def plot_prediction(
f.tight_layout()
return fig

def plot_interpretation(self, interpretation: Dict[str, torch.Tensor]) -> Dict[str, plt.Figure]:
def plot_interpretation(self, interpretation: Dict[str, torch.Tensor]):
"""
Make figures that interpret model.

@@ -748,6 +747,10 @@
Returns:
dictionary of matplotlib figures
"""
_check_matplotlib("plot_interpretation")

import matplotlib.pyplot as plt

figs = {}

# attention
@@ -813,6 +816,13 @@ def log_interpretation(self, outputs):
interpretation["attention"] = interpretation["attention"] / attention_occurances.pow(2).clamp(1.0)
interpretation["attention"] = interpretation["attention"] / interpretation["attention"].sum()

mpl_available = _check_matplotlib("log_interpretation", raise_error=False)

if not mpl_available:
return None

import matplotlib.pyplot as plt

figs = self.plot_interpretation(interpretation) # make interpretation figures
label = self.current_stage
# log to tensorboard
22 changes: 22 additions & 0 deletions pytorch_forecasting/utils/_dependencies.py
@@ -38,3 +38,25 @@ def _get_installed_packages():
MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3"
"""
return _get_installed_packages_private().copy()


def _check_matplotlib(ref="This feature", raise_error=True):
"""Check if matplotlib is installed.

Parameters
----------
ref : str, optional (default="This feature")
reference to the feature that requires matplotlib, used in error message
raise_error : bool, optional (default=True)
whether to raise an error if matplotlib is not installed

Returns
-------
bool : whether matplotlib is installed
"""
pkgs = _get_installed_packages()

if raise_error and "matplotlib" not in pkgs:
raise ImportError(f"{ref} requires matplotlib. Please install matplotlib with `pip install matplotlib`.")

return "matplotlib" in pkgs
5 changes: 5 additions & 0 deletions tests/test_models/test_nbeats.py
@@ -7,6 +7,7 @@
import pytest

from pytorch_forecasting.models import NBeats
from pytorch_forecasting.utils._dependencies import _get_installed_packages


def test_integration(dataloaders_fixed_window_without_covariates, tmp_path):
@@ -76,6 +77,10 @@ def test_pickle(model):
pickle.loads(pkl)


@pytest.mark.skipif(
"matplotlib" not in _get_installed_packages(),
reason="skip test if required package matplotlib not installed",
)
def test_interpretation(model, dataloaders_fixed_window_without_covariates):
raw_predictions = model.predict(
dataloaders_fixed_window_without_covariates["val"], mode="raw", return_x=True, fast_dev_run=True
4 changes: 4 additions & 0 deletions tests/test_models/test_nhits.py
@@ -150,6 +150,10 @@ def test_pickle(model):
pickle.loads(pkl)


@pytest.mark.skipif(
"matplotlib" not in _get_installed_packages(),
reason="skip test if required package matplotlib not installed",
)
def test_interpretation(model, dataloaders_with_covariates):
raw_predictions = model.predict(dataloaders_with_covariates["val"], mode="raw", return_x=True, fast_dev_run=True)
model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0, add_loss_to_title=True)
4 changes: 4 additions & 0 deletions tests/test_models/test_temporal_fusion_transformer.py
@@ -298,6 +298,10 @@ def test_predict_dependency(model, dataloaders_with_covariates, data_with_covari
model.predict_dependency(dataset, variable="agency", values=data_with_covariates.agency.unique()[:2], **kwargs)


@pytest.mark.skipif(
"matplotlib" not in _get_installed_packages(),
reason="skip test if required package matplotlib not installed",
)
def test_actual_vs_predicted_plot(model, dataloaders_with_covariates):
prediction = model.predict(dataloaders_with_covariates["val"], return_x=True)
averages = model.calculate_prediction_actual_by_variable(prediction.x, prediction.output)