From 5b14f753669c1b270dd771f9a37e497f7a39f09f Mon Sep 17 00:00:00 2001
From: "Carlos D. Escobar-Valbuena"
Date: Sun, 5 Mar 2023 18:10:14 +0000
Subject: [PATCH 1/5] Updated logger to work with mlflow log_figure

---
 pytorch_forecasting/models/base_model.py     | 16 +++++++---------
 pytorch_forecasting/models/nhits/__init__.py | 16 +++++++---------
 .../temporal_fusion_transformer/__init__.py  | 10 ++++++----
 3 files changed, 20 insertions(+), 22 deletions(-)

diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py
index b3d39130..115eabe0 100644
--- a/pytorch_forecasting/models/base_model.py
+++ b/pytorch_forecasting/models/base_model.py
@@ -722,16 +722,14 @@ def log_prediction(
             tag += f" of item {idx} in batch {batch_idx}"
         if isinstance(fig, (list, tuple)):
             for idx, f in enumerate(fig):
-                self.logger.experiment.add_figure(
-                    f"{self.target_names[idx]} {tag}",
-                    f,
-                    global_step=self.global_step,
+                self.logger.experiment.log_figure(
+                    image=f,
+                    artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png"
                 )
         else:
-            self.logger.experiment.add_figure(
-                tag,
-                fig,
-                global_step=self.global_step,
+            self.logger.experiment.log_figure(
+                image=f,
+                artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png"
             )
 
     def plot_prediction(
@@ -883,7 +881,7 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None:
         ax.set_ylabel("Average gradient")
         ax.set_yscale("log")
         ax.set_title("Gradient flow")
-        self.logger.experiment.add_figure("Gradient flow", fig, global_step=self.global_step)
+        self.logger.experiment.log_figure(image=fig, artifact_file=f"gradient_flow.png")
 
     def on_after_backward(self):
         """
diff --git a/pytorch_forecasting/models/nhits/__init__.py b/pytorch_forecasting/models/nhits/__init__.py
index 477a362d..e5150179 100644
--- a/pytorch_forecasting/models/nhits/__init__.py
+++ b/pytorch_forecasting/models/nhits/__init__.py
@@ -523,17 +523,15 @@ def log_interpretation(self, x, out, batch_idx):
             name += f"step {self.global_step}"
         else:
             name += f"batch {batch_idx}"
-        self.logger.experiment.add_figure(name, fig, global_step=self.global_step)
+        self.logger.experiment.log_figure(image=fig, artifact_file=f"{name}.png")
         if isinstance(fig, (list, tuple)):
             for idx, f in enumerate(fig):
-                self.logger.experiment.add_figure(
-                    f"{self.target_names[idx]} {name}",
-                    f,
-                    global_step=self.global_step,
+                self.logger.experiment.log_figure(
+                    image=f,
+                    artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png"
                 )
         else:
-            self.logger.experiment.add_figure(
-                name,
-                fig,
-                global_step=self.global_step,
+            self.logger.experiment.log_figure(
+                image=f,
+                artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png"
             )
diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
index 8c816dbe..0a88a0c4 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -817,8 +817,9 @@ def log_interpretation(self, outputs):
         label = self.current_stage
         # log to tensorboard
         for name, fig in figs.items():
-            self.logger.experiment.add_figure(
-                f"{label.capitalize()} {name} importance", fig, global_step=self.global_step
+            self.logger.experiment.log_figure(
+                image=fig,
+                artifact_file=f"{label.capitalize()}_{name}_step_{self.global_step}.png"
             )
 
         # log lengths of encoder/decoder
@@ -839,8 +840,9 @@ def log_interpretation(self, outputs):
             ax.set_ylabel("Number of samples")
             ax.set_title(f"{type.capitalize()} length distribution in {label} epoch")
-            self.logger.experiment.add_figure(
-                f"{label.capitalize()} {type} length distribution", fig, global_step=self.global_step
+            self.logger.experiment.log_figure(
+                image=fig,
+                artifact_file=f"{label.capitalize()}_{type}_length_distribution_step_{self.global_step}.png",
             )
 
     def log_embeddings(self):
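For reference, the MLflow client that backs pytorch-lightning's MLFlowLogger exposes log_figure for Matplotlib figures; it takes an explicit run_id, the figure object, and an artifact path. A minimal standalone sketch of that call, outside the patch (the local tracking URI and the default experiment "0" are illustrative):

# Hedged sketch, not part of the patch: MlflowClient.log_figure takes a
# Matplotlib figure directly, plus an explicit run_id and artifact file name.
import matplotlib.pyplot as plt
from mlflow.tracking import MlflowClient

client = MlflowClient(tracking_uri="file:./mlruns")  # illustrative local store
run = client.create_run(experiment_id="0")           # "0" is MLflow's default experiment

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [1, 0, 1])

client.log_figure(run_id=run.info.run_id, figure=fig, artifact_file="prediction.png")
client.set_terminated(run.info.run_id)
plt.close(fig)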
ax.set_ylabel("Number of samples") ax.set_title(f"{type.capitalize()} length distribution in {label} epoch") - self.logger.experiment.add_figure( - f"{label.capitalize()} {type} length distribution", fig, global_step=self.global_step + self.logger.experiment.log_figure( + image=fig, + artifact_file=f"{label.capitalize()}_{type}_length_distribution_step_{self.global_step}.png", ) def log_embeddings(self): From c13ca89c41d23bfc0336d56d7d8692d774241b61 Mon Sep 17 00:00:00 2001 From: "Carlos D. Escobar-Valbuena" Date: Sun, 5 Mar 2023 18:43:47 +0000 Subject: [PATCH 2/5] Updated typo from log_figure to log_image --- pytorch_forecasting/models/base_model.py | 8 ++++---- pytorch_forecasting/models/nhits/__init__.py | 8 ++++---- .../models/temporal_fusion_transformer/__init__.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index 115eabe0..1b6a6aaa 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -722,13 +722,13 @@ def log_prediction( tag += f" of item {idx} in batch {batch_idx}" if isinstance(fig, (list, tuple)): for idx, f in enumerate(fig): - self.logger.experiment.log_figure( + self.logger.experiment.log_image( image=f, artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png" ) else: - self.logger.experiment.log_figure( - image=f, + self.logger.experiment.log_image( + image=fig, artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png" ) @@ -881,7 +881,7 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None: ax.set_ylabel("Average gradient") ax.set_yscale("log") ax.set_title("Gradient flow") - self.logger.experiment.log_figure(image=fig, artifact_file=f"gradient_flow.png") + self.logger.experiment.log_image(image=fig, artifact_file=f"gradient_flow.png") def on_after_backward(self): """ diff --git a/pytorch_forecasting/models/nhits/__init__.py b/pytorch_forecasting/models/nhits/__init__.py index e5150179..26229293 100644 --- a/pytorch_forecasting/models/nhits/__init__.py +++ b/pytorch_forecasting/models/nhits/__init__.py @@ -523,15 +523,15 @@ def log_interpretation(self, x, out, batch_idx): name += f"step {self.global_step}" else: name += f"batch {batch_idx}" - self.logger.experiment.log_figure(image=fig, artifact_file=f"{name}.png") + self.logger.experiment.log_image(image=fig, artifact_file=f"{name}.png") if isinstance(fig, (list, tuple)): for idx, f in enumerate(fig): - self.logger.experiment.log_figure( + self.logger.experiment.log_image( image=f, artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png" ) else: - self.logger.experiment.log_figure( - image=f, + self.logger.experiment.log_image( + image=fig, artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png" ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index 0a88a0c4..ab184293 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ -817,7 +817,7 @@ def log_interpretation(self, outputs): label = self.current_stage # log to tensorboard for name, fig in figs.items(): - self.logger.experiment.log_figure( + self.logger.experiment.log_image( image=fig, artifact_file=f"{label.capitalize()}_{name}_step_{self.global_step}.png" ) @@ -840,7 +840,7 @@ def log_interpretation(self, 
From 16fc734d1fca4178c91fe3236602678a8da7488b Mon Sep 17 00:00:00 2001
From: "Carlos D. Escobar-Valbuena"
Date: Mon, 6 Mar 2023 00:56:50 +0000
Subject: [PATCH 3/5] Updates to make it work with mlflow log_image

---
 pytorch_forecasting/models/base_model.py            | 4 +++-
 pytorch_forecasting/models/nhits/__init__.py        | 2 ++
 .../models/temporal_fusion_transformer/__init__.py  | 2 ++
 3 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py
index 1b6a6aaa..fce89bcc 100644
--- a/pytorch_forecasting/models/base_model.py
+++ b/pytorch_forecasting/models/base_model.py
@@ -723,11 +723,13 @@ def log_prediction(
         if isinstance(fig, (list, tuple)):
             for idx, f in enumerate(fig):
                 self.logger.experiment.log_image(
+                    run_id=self.logger.run_id,
                     image=f,
                     artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png"
                 )
         else:
             self.logger.experiment.log_image(
+                run_id=self.logger.run_id,
                 image=fig,
                 artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png"
             )
 
@@ -881,7 +883,7 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None:
         ax.set_ylabel("Average gradient")
         ax.set_yscale("log")
         ax.set_title("Gradient flow")
-        self.logger.experiment.log_image(image=fig, artifact_file=f"gradient_flow.png")
+        self.logger.experiment.log_image(run_id=self.logger.run_id, image=fig, artifact_file=f"gradient_flow.png")
 
     def on_after_backward(self):
         """
diff --git a/pytorch_forecasting/models/nhits/__init__.py b/pytorch_forecasting/models/nhits/__init__.py
index 26229293..746e7589 100644
--- a/pytorch_forecasting/models/nhits/__init__.py
+++ b/pytorch_forecasting/models/nhits/__init__.py
@@ -527,11 +527,13 @@ def log_interpretation(self, x, out, batch_idx):
         if isinstance(fig, (list, tuple)):
             for idx, f in enumerate(fig):
                 self.logger.experiment.log_image(
+                    run_id=self.logger.run_id,
                     image=f,
                     artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png"
                 )
         else:
             self.logger.experiment.log_image(
+                run_id=self.logger.run_id,
                 image=fig,
                 artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png"
             )
diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
index ab184293..cdcbf159 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -818,6 +818,7 @@ def log_interpretation(self, outputs):
         # log to tensorboard
         for name, fig in figs.items():
             self.logger.experiment.log_image(
+                run_id=self.logger.run_id,
                 image=fig,
                 artifact_file=f"{label.capitalize()}_{name}_step_{self.global_step}.png"
             )
@@ -841,6 +842,7 @@ def log_interpretation(self, outputs):
             ax.set_title(f"{type.capitalize()} length distribution in {label} epoch")
             self.logger.experiment.log_image(
+                run_id=self.logger.run_id,
                 image=fig,
                 artifact_file=f"{label.capitalize()}_{type}_length_distribution_step_{self.global_step}.png",
             )
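With MLFlowLogger, self.logger.experiment is an MlflowClient rather than a TensorBoard SummaryWriter, so every call needs an explicit run_id (exposed as self.logger.run_id), and log_image expects a numpy.ndarray or PIL.Image.Image rather than a figure object. A minimal standalone sketch of that call pattern, outside the patch (the array contents and tracking URI are placeholders):

# Hedged sketch, not part of the patch: MlflowClient.log_image needs an explicit
# run_id and an image given as an HxWxC uint8 numpy array or a PIL Image.
import numpy as np
from mlflow.tracking import MlflowClient

client = MlflowClient(tracking_uri="file:./mlruns")  # illustrative local store
run = client.create_run(experiment_id="0")           # "0" is MLflow's default experiment

image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)  # placeholder image
client.log_image(run_id=run.info.run_id, image=image, artifact_file="interpretation.png")
client.set_terminated(run.info.run_id)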
Escobar-Valbuena" Date: Mon, 6 Mar 2023 01:11:14 +0000 Subject: [PATCH 4/5] Added fig2img --- pytorch_forecasting/models/base_model.py | 16 +++++++++++++--- pytorch_forecasting/models/nhits/__init__.py | 13 +++++++++++-- .../temporal_fusion_transformer/__init__.py | 13 +++++++++++-- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index fce89bcc..39587982 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -152,6 +152,16 @@ def _concatenate_output( } +def fig2img(fig): + """Convert a Matplotlib figure to a PIL Image and return it""" + import io + buf = io.BytesIO() + fig.savefig(buf) + buf.seek(0) + img = Image.open(buf) + return img + + class BaseModel(InitialParameterRepresenterMixIn, LightningModule, TupleOutputMixIn): """ BaseModel from which new timeseries models should inherit from. @@ -724,13 +734,13 @@ def log_prediction( for idx, f in enumerate(fig): self.logger.experiment.log_image( run_id=self.logger.run_id, - image=f, + image=fig2img(f), artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png" ) else: self.logger.experiment.log_image( run_id=self.logger.run_id, - image=fig, + image=fig2img(fig), artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png" ) @@ -883,7 +893,7 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None: ax.set_ylabel("Average gradient") ax.set_yscale("log") ax.set_title("Gradient flow") - self.logger.experiment.log_image(run_id=self.logger.run_id, image=fig, artifact_file=f"gradient_flow.png") + self.logger.experiment.log_image(run_id=self.logger.run_id, image=fig2img(fig), artifact_file=f"gradient_flow.png") def on_after_backward(self): """ diff --git a/pytorch_forecasting/models/nhits/__init__.py b/pytorch_forecasting/models/nhits/__init__.py index 746e7589..3e7ba373 100644 --- a/pytorch_forecasting/models/nhits/__init__.py +++ b/pytorch_forecasting/models/nhits/__init__.py @@ -18,6 +18,15 @@ from pytorch_forecasting.utils import create_mask, detach, to_list +def fig2img(fig): + """Convert a Matplotlib figure to a PIL Image and return it""" + import io + buf = io.BytesIO() + fig.savefig(buf) + buf.seek(0) + img = Image.open(buf) + return img + class NHiTS(BaseModelWithCovariates): def __init__( self, @@ -528,12 +537,12 @@ def log_interpretation(self, x, out, batch_idx): for idx, f in enumerate(fig): self.logger.experiment.log_image( run_id=self.logger.run_id, - image=f, + image=fig2img(f), artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png" ) else: self.logger.experiment.log_image( run_id=self.logger.run_id, - image=fig, + image=fig2img(fig), artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png" ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index cdcbf159..b160f501 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ -25,6 +25,15 @@ ) from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list +def fig2img(fig): + """Convert a Matplotlib figure to a PIL Image and return it""" + import io + buf = io.BytesIO() + fig.savefig(buf) + buf.seek(0) + img = Image.open(buf) + return img + class TemporalFusionTransformer(BaseModelWithCovariates): def 
From 7720d07a69096b387f9ccaedf79ccaa564e268fc Mon Sep 17 00:00:00 2001
From: "Carlos D. Escobar-Valbuena"
Date: Mon, 6 Mar 2023 01:21:48 +0000
Subject: [PATCH 5/5] .

---
 pytorch_forecasting/models/base_model.py            | 3 +++
 pytorch_forecasting/models/nhits/__init__.py        | 3 +++
 .../models/temporal_fusion_transformer/__init__.py  | 3 +++
 3 files changed, 9 insertions(+)

diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py
index 39587982..e4384a7a 100644
--- a/pytorch_forecasting/models/base_model.py
+++ b/pytorch_forecasting/models/base_model.py
@@ -153,6 +153,9 @@ def _concatenate_output(
 
 
 def fig2img(fig):
+    import numpy as np
+    import matplotlib.pyplot as plt
+    from PIL import Image
     """Convert a Matplotlib figure to a PIL Image and return it"""
     import io
     buf = io.BytesIO()
diff --git a/pytorch_forecasting/models/nhits/__init__.py b/pytorch_forecasting/models/nhits/__init__.py
index 3e7ba373..113e7c28 100644
--- a/pytorch_forecasting/models/nhits/__init__.py
+++ b/pytorch_forecasting/models/nhits/__init__.py
@@ -19,6 +19,9 @@
 
 def fig2img(fig):
+    import numpy as np
+    import matplotlib.pyplot as plt
+    from PIL import Image
     """Convert a Matplotlib figure to a PIL Image and return it"""
     import io
     buf = io.BytesIO()
diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
index b160f501..78514f95 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -26,6 +26,9 @@
 from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list
 
 def fig2img(fig):
+    import numpy as np
+    import matplotlib.pyplot as plt
+    from PIL import Image
     """Convert a Matplotlib figure to a PIL Image and return it"""
     import io
     buf = io.BytesIO()
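Taken together, a minimal end-to-end sketch of how a model would be pointed at MLflow with these patches applied; the synthetic data, experiment name, and hyperparameters below are illustrative and not part of the patch:

# Hedged usage sketch, not part of the patch: attach an MLFlowLogger so the
# figure-logging calls patched above would write PNG artifacts to an MLflow run.
import numpy as np
import pandas as pd
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger

from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

# Tiny synthetic series: one group, 100 steps of a noisy sine wave.
data = pd.DataFrame(
    {
        "time_idx": np.arange(100),
        "target": np.sin(np.arange(100) / 10.0) + np.random.normal(scale=0.1, size=100),
        "group": "a",
    }
)

training = TimeSeriesDataSet(
    data,
    time_idx="time_idx",
    target="target",
    group_ids=["group"],
    max_encoder_length=24,
    max_prediction_length=6,
    time_varying_unknown_reals=["target"],
)
train_dataloader = training.to_dataloader(train=True, batch_size=16)
val_dataloader = training.to_dataloader(train=False, batch_size=16)

mlf_logger = MLFlowLogger(experiment_name="pytorch_forecasting", tracking_uri="file:./mlruns")
tft = TemporalFusionTransformer.from_dataset(training, log_interval=1)

trainer = Trainer(max_epochs=1, logger=mlf_logger, enable_checkpointing=False)
trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)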