From 3cc81b3f0152da121d5fa184ebc5f8578adb46de Mon Sep 17 00:00:00 2001
From: Ewan Thompson
Date: Wed, 9 Oct 2024 16:07:55 +0800
Subject: [PATCH] [FIX] AttributeError: 'ExperimentWriter' object has no
 attribute 'add_figure'

This fixes the AttributeError failure by checking that `add_figure` exists
before attempting to call it.

See https://github.com/sktime/pytorch-forecasting/issues/1256
---
 pytorch_forecasting/models/base_model.py                       | 8 ++++++++
 pytorch_forecasting/models/nbeats/__init__.py                  | 4 ++++
 pytorch_forecasting/models/nhits/__init__.py                   | 4 ++++
 .../models/temporal_fusion_transformer/__init__.py             | 4 ++++
 4 files changed, 20 insertions(+)

diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py
index a82196cf..e1368e3d 100644
--- a/pytorch_forecasting/models/base_model.py
+++ b/pytorch_forecasting/models/base_model.py
@@ -976,6 +976,10 @@ def log_prediction(
         if not mpl_available:
             return None  # don't log matplotlib plots if not available
 
+        # Don't log figures if add_figure is not available
+        if not hasattr(self.logger.experiment, "add_figure"):
+            return None
+
         for idx in log_indices:
             fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
             tag = f"{self.current_stage} prediction"
@@ -1149,6 +1153,10 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None:
         if not mpl_available:
             return None
 
+        # Don't log figures if add_figure is not available
+        if not hasattr(self.logger.experiment, "add_figure"):
+            return None
+
         import matplotlib.pyplot as plt
 
         fig, ax = plt.subplots()
diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py
index 149b4fbc..b75d173a 100644
--- a/pytorch_forecasting/models/nbeats/__init__.py
+++ b/pytorch_forecasting/models/nbeats/__init__.py
@@ -268,6 +268,10 @@ def log_interpretation(self, x, out, batch_idx):
         if not mpl_available:
             return None
 
+        # Don't log figures if add_figure is not available
+        if not hasattr(self.logger.experiment, "add_figure"):
+            return None
+
         label = ["val", "train"][self.training]
         if self.log_interval > 0 and batch_idx % self.log_interval == 0:
             fig = self.plot_interpretation(x, out, idx=0)
diff --git a/pytorch_forecasting/models/nhits/__init__.py b/pytorch_forecasting/models/nhits/__init__.py
index 68816f22..58a8ad71 100644
--- a/pytorch_forecasting/models/nhits/__init__.py
+++ b/pytorch_forecasting/models/nhits/__init__.py
@@ -544,6 +544,10 @@ def log_interpretation(self, x, out, batch_idx):
         if not mpl_available:
             return None
 
+        # Don't log figures if add_figure is not available
+        if not hasattr(self.logger.experiment, "add_figure"):
+            return None
+
         label = ["val", "train"][self.training]
         if self.log_interval > 0 and batch_idx % self.log_interval == 0:
             fig = self.plot_interpretation(x, out, idx=0)
diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
index cc506612..83877151 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -821,6 +821,10 @@ def log_interpretation(self, outputs):
         if not mpl_available:
             return None
 
+        # Don't log figures if add_figure is not available
+        if not hasattr(self.logger.experiment, "add_figure"):
+            return None
+
         import matplotlib.pyplot as plt
 
         figs = self.plot_interpretation(interpretation)  # make interpretation figures