Skip to content

Commit

Permalink
added tune.py to tune hyperparameters | started to work on a lr_scheduler | now the hyperparameters should be logged
Browse files Browse the repository at this point in the history
  • Loading branch information
MaloOLIVIER committed Dec 6, 2024
1 parent 590e04b commit b3b3660
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 1 deletion.
1 change: 1 addition & 0 deletions configs/optimizer/lr_scheduler.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
4 changes: 3 additions & 1 deletion configs/run.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,7 @@ defaults:
- rich_model_summary
- logging: tensorboard
- metrics: f1
- optimizer: adam
- optimizer:
- adam
- lr_scheduler
- _self_ # priority is given to run.yaml for overrides
4 changes: 4 additions & 0 deletions hungarian_net/lightning_modules/hnet_gru_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def __init__(
optimizer: partial[optim.Optimizer] = partial(optim.Adam),
):
super().__init__()

# Automatically save hyperparameters except for non-serializable objects
self.save_hyperparameters(ignore=["metrics", "device", "optimizer"])

self._device = device
self.model = HNetGRU(max_len=max_len).to(self._device)

Expand Down
147 changes: 147 additions & 0 deletions tune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from functools import partial
import os
import random
import warnings
from typing import List

import hydra
import lightning as L
import numpy as np
import ray
import torch
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
from ray import tune
from ray.tune import CLIReporter
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.tune.schedulers import ASHAScheduler
from torchmetrics import MetricCollection


@hydra.main(
    config_path="configs",
    config_name="run.yaml",
    version_base="1.3",
)
def main(cfg: DictConfig):
    """
    Run a Ray Tune hyperparameter search over the Hungarian-Net training job.

    Samples learning rate and batch size per trial, trains the Lightning
    module for each sample, and reports the best configuration found.

    Args:
        cfg (DictConfig): Hydra configuration object, passed in by the
            @hydra.main decorator.
    """

    # TODO: leverage Ray Tune, Docker

    # Initialize Ray
    ray.init()

    # Define the hyperparameter search space. Ray Tune draws a fresh sample
    # from this space for every trial and passes it to the trainable.
    search_space = {
        "learning_rate": tune.loguniform(1e-5, 1e-1),
        "batch_size": tune.choice([64, 128, 256]),
    }

    # ASHA stops under-performing trials early based on validation loss.
    scheduler = ASHAScheduler(
        metric="validation_loss",
        mode="min",
        max_t=cfg.nb_epochs,
        grace_period=1,
        reduction_factor=2,
    )

    reporter = CLIReporter(
        metric_columns=["validation_loss", "training_iteration"]
    )

    def train_tune(trial_config):
        """Train one trial with the hyperparameters Ray Tune sampled for it.

        Everything is instantiated inside this function so each trial gets
        its own datamodule/module/trainer built from *trial_config* — the
        previous version sampled the search space once up front and reused a
        single shared Trainer, so every trial trained with identical
        hyperparameters.
        """
        # Instantiate LightningDataModule with the trial's batch size
        lightning_datamodule: L.LightningDataModule = hydra.utils.instantiate(
            cfg.lightning_datamodule,
            batch_size=trial_config["batch_size"],
        )

        # Instantiate LightningModule with the trial's learning rate
        metrics: MetricCollection = MetricCollection(
            dict(hydra.utils.instantiate(cfg.metrics))
        )
        lightning_module: L.LightningModule = hydra.utils.instantiate(
            cfg.lightning_module,
            metrics=metrics,
            optimizer=partial(torch.optim.Adam, lr=trial_config["learning_rate"]),
        )

        # Instantiate Trainer with the Ray Tune reporting callback so each
        # validation epoch feeds validation_loss back to the scheduler.
        tune_callback = TuneReportCallback(
            {"validation_loss": "validation_loss"}, on="validation_end"
        )
        callbacks: List[L.Callback] = (
            list(hydra.utils.instantiate(cfg.callbacks).values()) + [tune_callback]
        )
        logger: Logger = hydra.utils.instantiate(cfg.logging.logger)
        trainer: L.Trainer = hydra.utils.instantiate(
            cfg.trainer,
            callbacks=callbacks,
            logger=logger,
            _convert_="partial",
        )

        trainer.fit(lightning_module, datamodule=lightning_datamodule)
        trainer.test(ckpt_path="best", datamodule=lightning_datamodule)

    # Run hyperparameter tuning
    result = tune.run(
        train_tune,
        resources_per_trial={
            "cpu": cfg.num_workers,
            "gpu": 1 if torch.cuda.is_available() else 0,
        },
        config=search_space,
        num_samples=10,
        scheduler=scheduler,
        progress_reporter=reporter,
        name="tune_hnet_training",
    )

    best_trial = result.get_best_trial("validation_loss", "min", "last")
    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial final validation loss: {best_trial.last_result['validation_loss']}")

    # Shutdown Ray
    ray.shutdown()


def set_seed(seed=42):
    """Seed every RNG (Lightning, Python, NumPy, Torch, CUDA) for reproducibility."""
    # Lightning's helper also propagates the seed to dataloader workers.
    L.seed_everything(seed, workers=True)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def setup_environment():
    """
    Prepare the process for training: fixed seeds, device selection,
    full Hydra error traces, and CUDNN autotuning.

    Returns:
        torch.device: "cuda" if a GPU is available, otherwise "cpu".
    """
    set_seed()

    # Check whether to run on cpu or gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    warnings.filterwarnings("ignore")

    # Ask Hydra for complete stack traces instead of the condensed ones.
    os.environ["HYDRA_FULL_ERROR"] = "1"

    # Enable CUDNN and let it benchmark to pick the fastest algorithm per new
    # input size (e.g. Winograd vs GEMM-based vs FFT for convolutions).
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    return device


if __name__ == "__main__":
    # Configure seeds/device/CUDNN first, then let Hydra build the config
    # and launch the Ray Tune search.
    setup_environment()
    main()

0 comments on commit b3b3660

Please sign in to comment.