Skip to content

Commit

Permalink
[FIX] AttributeError: 'ExperimentWriter' object has no attribute 'add…
Browse files Browse the repository at this point in the history
…_figure'

This fixes the AttributeError failure by checking that `add_figure` exists before attempting to call it.

See #1256
  • Loading branch information
ewth committed Oct 9, 2024
1 parent 53a1c41 commit 3cc81b3
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 0 deletions.
8 changes: 8 additions & 0 deletions pytorch_forecasting/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,10 @@ def log_prediction(
if not mpl_available:
return None # don't log matplotlib plots if not available

# Don't log figures if add_figure is not available
if not hasattr(self.logger.experiment, "add_figure"):
return None

for idx in log_indices:
fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
tag = f"{self.current_stage} prediction"
Expand Down Expand Up @@ -1149,6 +1153,10 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None:
if not mpl_available:
return None

# Don't log figures if add_figure is not available
if not hasattr(self.logger.experiment, "add_figure"):
return None

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
Expand Down
4 changes: 4 additions & 0 deletions pytorch_forecasting/models/nbeats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ def log_interpretation(self, x, out, batch_idx):
if not mpl_available:
return None

# Don't log figures if add_figure is not available
if not hasattr(self.logger.experiment, "add_figure"):
return None

label = ["val", "train"][self.training]
if self.log_interval > 0 and batch_idx % self.log_interval == 0:
fig = self.plot_interpretation(x, out, idx=0)
Expand Down
4 changes: 4 additions & 0 deletions pytorch_forecasting/models/nhits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,10 @@ def log_interpretation(self, x, out, batch_idx):
if not mpl_available:
return None

# Don't log figures if add_figure is not available
if not hasattr(self.logger.experiment, "add_figure"):
return None

label = ["val", "train"][self.training]
if self.log_interval > 0 and batch_idx % self.log_interval == 0:
fig = self.plot_interpretation(x, out, idx=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,10 @@ def log_interpretation(self, outputs):
if not mpl_available:
return None

# Don't log figures if add_figure is not available
if not hasattr(self.logger.experiment, "add_figure"):
return None

import matplotlib.pyplot as plt

figs = self.plot_interpretation(interpretation) # make interpretation figures
Expand Down

0 comments on commit 3cc81b3

Please sign in to comment.