From aa7a5499755e164b8de493151f643c503966c378 Mon Sep 17 00:00:00 2001
From: Caleb Robinson
Date: Fri, 23 Feb 2024 21:11:34 +0000
Subject: [PATCH] Make mypy happy I think

---
 docs/tutorials/custom_segmentation_trainer.ipynb | 14 ++++++++++----
 docs/tutorials/custom_segmentation_trainer.py    | 14 ++++++++++----
 2 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/docs/tutorials/custom_segmentation_trainer.ipynb b/docs/tutorials/custom_segmentation_trainer.ipynb
index 6c50f6df227..a20c679d63e 100644
--- a/docs/tutorials/custom_segmentation_trainer.ipynb
+++ b/docs/tutorials/custom_segmentation_trainer.ipynb
@@ -68,6 +68,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "from typing import Any, Union, Sequence\n",
+    "from lightning.pytorch.callbacks.callback import Callback\n",
     "import lightning\n",
     "import lightning.pytorch as pl\n",
     "from lightning.pytorch.callbacks import ModelCheckpoint\n",
@@ -134,7 +136,7 @@
     "class CustomSemanticSegmentationTask(SemanticSegmentationTask):\n",
     "\n",
     "    # any keywords we add here between *args and **kwargs will be found in self.hparams\n",
-    "    def __init__(self, *args, tmax=50, eta_min=1e-6, **kwargs) -> None:\n",
+    "    def __init__(self, *args: Any, tmax: int=50, eta_min: float=1e-6, **kwargs: Any) -> None:\n",
     "        super().__init__(*args, **kwargs)  # pass args and kwargs to the parent class\n",
     "\n",
     "    def configure_optimizers(\n",
@@ -185,7 +187,7 @@
     "        self.val_metrics = self.train_metrics.clone(prefix=\"val_\")\n",
     "        self.test_metrics = self.train_metrics.clone(prefix=\"test_\")\n",
     "\n",
-    "    def configure_callbacks(self):\n",
+    "    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:\n",
     "        \"\"\"Initialize callbacks for saving the best and latest models.\n",
     "\n",
     "        Returns:\n",
@@ -198,8 +200,12 @@
     "\n",
     "    def on_train_epoch_start(self) -> None:\n",
     "        \"\"\"Log the learning rate at the start of each training epoch.\"\"\"\n",
-    "        lr = self.optimizers().param_groups[0][\"lr\"]\n",
-    "        self.logger.experiment.add_scalar(\"lr\", lr, self.current_epoch)"
+    "        optimizers = self.optimizers()\n",
+    "        if isinstance(optimizers, list):\n",
+    "            lr = optimizers[0].param_groups[0][\"lr\"]\n",
+    "        else:\n",
+    "            lr = optimizers.param_groups[0][\"lr\"]\n",
+    "        self.logger.experiment.add_scalar(\"lr\", lr, self.current_epoch)  # type: ignore"
    ]
   },
   {
diff --git a/docs/tutorials/custom_segmentation_trainer.py b/docs/tutorials/custom_segmentation_trainer.py
index 5efa90fadf1..95e45ac887e 100644
--- a/docs/tutorials/custom_segmentation_trainer.py
+++ b/docs/tutorials/custom_segmentation_trainer.py
@@ -40,6 +40,8 @@
 # UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.
 import warnings
 
+from typing import Any, Union, Sequence
+from lightning.pytorch.callbacks.callback import Callback
 import lightning
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -79,7 +81,7 @@ class CustomSemanticSegmentationTask(SemanticSegmentationTask):
 
     # any keywords we add here between *args and **kwargs will be found in self.hparams
-    def __init__(self, *args, tmax=50, eta_min=1e-6, **kwargs) -> None:
+    def __init__(self, *args: Any, tmax: int=50, eta_min: float=1e-6, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)  # pass args and kwargs to the parent class
 
     def configure_optimizers(
@@ -130,7 +132,7 @@ def configure_metrics(self) -> None:
         self.val_metrics = self.train_metrics.clone(prefix="val_")
         self.test_metrics = self.train_metrics.clone(prefix="test_")
 
-    def configure_callbacks(self):
+    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
         """Initialize callbacks for saving the best and latest models.
 
         Returns:
@@ -143,8 +145,12 @@ def configure_callbacks(self):
 
     def on_train_epoch_start(self) -> None:
         """Log the learning rate at the start of each training epoch."""
-        lr = self.optimizers().param_groups[0]["lr"]
-        self.logger.experiment.add_scalar("lr", lr, self.current_epoch)
+        optimizers = self.optimizers()
+        if isinstance(optimizers, list):
+            lr = optimizers[0].param_groups[0]["lr"]
+        else:
+            lr = optimizers.param_groups[0]["lr"]
+        self.logger.experiment.add_scalar("lr", lr, self.current_epoch)  # type: ignore
 
 
 # ## Train model
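
Note (not part of the patch above): the last hunk in both files hinges on the fact that LightningModule.optimizers() may return either a single optimizer or a list of optimizers, so the code narrows the type with isinstance before touching param_groups. A minimal standalone sketch of that pattern follows, assuming only torch; the helper name first_lr is hypothetical and not part of the change.

from typing import List, Union

import torch


def first_lr(
    optimizers: Union[torch.optim.Optimizer, List[torch.optim.Optimizer]],
) -> float:
    """Return the learning rate of the (first) optimizer.

    Indexing a bare Optimizer with [0] would fail type checking, so branch on
    isinstance first, mirroring the patched on_train_epoch_start.
    """
    if isinstance(optimizers, list):
        return optimizers[0].param_groups[0]["lr"]
    return optimizers.param_groups[0]["lr"]


# Hypothetical usage with a throwaway parameter and optimizer:
params = [torch.nn.Parameter(torch.zeros(1))]
print(first_lr(torch.optim.SGD(params, lr=1e-3)))    # 0.001
print(first_lr([torch.optim.SGD(params, lr=1e-4)]))  # 0.0001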