From 9fbccbf29cac681dcec3c30011d672e1daca6f4f Mon Sep 17 00:00:00 2001
From: clementchadebec
Date: Thu, 13 Apr 2023 18:06:36 +0200
Subject: [PATCH 1/2] integrate ray

---
 examples/hp_tuning_with_ray.py             | 73 +++++++++++++++++++
 .../trainers/base_trainer/base_trainer.py  |  5 +-
 src/pythae/trainers/training_callbacks.py  |  4 +-
 3 files changed, 79 insertions(+), 3 deletions(-)
 create mode 100644 examples/hp_tuning_with_ray.py

diff --git a/examples/hp_tuning_with_ray.py b/examples/hp_tuning_with_ray.py
new file mode 100644
index 00000000..89e05857
--- /dev/null
+++ b/examples/hp_tuning_with_ray.py
@@ -0,0 +1,73 @@
+from pythae.pipelines import TrainingPipeline
+from pythae.models import VAE, VAEConfig
+from pythae.trainers import BaseTrainerConfig, BaseTrainer
+from pythae.data.datasets import BaseDataset
+import torch
+import numpy as np
+
+import torchvision.datasets as datasets
+
+from ray import air, tune
+from ray.tune.schedulers import ASHAScheduler
+from pythae.trainers.training_callbacks import TrainingCallback
+
+class RayCallback(TrainingCallback):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def on_epoch_end(self, training_config: BaseTrainerConfig, **kwargs):
+        metrics = kwargs.pop("metrics")
+        tune.report(eval_epoch_loss=metrics["eval_epoch_loss"])
+
+def train_ray(config):
+
+    mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)
+
+    train_dataset = BaseDataset(mnist_trainset.data[:1000].reshape(-1, 1, 28, 28) / 255., torch.ones(1000))
+    eval_dataset = BaseDataset(mnist_trainset.data[-1000:].reshape(-1, 1, 28, 28) / 255., torch.ones(1000))
+
+    my_training_config = BaseTrainerConfig(
+        output_dir='my_model',
+        num_epochs=50,
+        learning_rate=config["lr"],
+        per_device_train_batch_size=200,
+        per_device_eval_batch_size=200,
+        steps_saving=None,
+        optimizer_cls="AdamW",
+        optimizer_params={"weight_decay": 0.05, "betas": (0.91, 0.995)},
+        scheduler_cls="ReduceLROnPlateau",
+        scheduler_params={"patience": 5, "factor": 0.5}
+    )
+
+    my_vae_config = model_config = VAEConfig(
+        input_dim=(1, 28, 28),
+        latent_dim=10
+    )
+
+    my_vae_model = VAE(
+        model_config=my_vae_config
+    )
+
+    callbacks = [RayCallback()]
+
+    trainer = BaseTrainer(my_vae_model, train_dataset, eval_dataset, my_training_config, callbacks=callbacks)
+
+    trainer.train()
+
+
+search_space = {
+    "lr": tune.sample_from(lambda spec: 10 ** (-10 * np.random.rand())),
+    "momentum": tune.uniform(0.1, 0.9),
+}
+
+tuner = tune.Tuner(
+    train_ray,
+    tune_config=tune.TuneConfig(
+        num_samples=20,
+        scheduler=ASHAScheduler(metric="eval_epoch_loss", mode="min"),
+    ),
+    param_space=search_space,
+)
+
+results = tuner.fit()
\ No newline at end of file
diff --git a/src/pythae/trainers/base_trainer/base_trainer.py b/src/pythae/trainers/base_trainer/base_trainer.py
index 434eac1d..6f39c5d4 100644
--- a/src/pythae/trainers/base_trainer/base_trainer.py
+++ b/src/pythae/trainers/base_trainer/base_trainer.py
@@ -476,7 +476,10 @@ def train(self, log_output_dir: str = None):
                     global_step=epoch,
                 )
 
-            self.callback_handler.on_epoch_end(training_config=self.training_config)
+            self.callback_handler.on_epoch_end(
+                training_config=self.training_config,
+                metrics=metrics
+            )
 
             # save checkpoints
             if (
diff --git a/src/pythae/trainers/training_callbacks.py b/src/pythae/trainers/training_callbacks.py
index fe6bde78..09ddaa06 100644
--- a/src/pythae/trainers/training_callbacks.py
+++ b/src/pythae/trainers/training_callbacks.py
@@ -169,7 +169,7 @@ def on_epoch_begin(self, training_config: BaseTrainerConfig, **kwargs):
         self.call_event("on_epoch_begin", training_config, **kwargs)
 
     def on_epoch_end(self, training_config: BaseTrainerConfig, **kwargs):
-        self.call_event("on_epoch_end", training_config)
+        self.call_event("on_epoch_end", training_config, **kwargs)
 
     def on_evaluate(self, training_config: BaseTrainerConfig, **kwargs):
         self.call_event("on_evaluate", **kwargs)
@@ -285,7 +285,7 @@ def on_eval_step_end(self, training_config: BaseTrainerConfig, **kwargs):
         if self.eval_progress_bar is not None:
             self.eval_progress_bar.update(1)
 
-    def on_epoch_end(self, training_config: BaseTrainerConfig, **kwags):
+    def on_epoch_end(self, training_config: BaseTrainerConfig, **kwargs):
 
         if self.train_progress_bar is not None:
             self.train_progress_bar.close()

From cc9ee609dcfa8197abefbc8a706b30f152af19e1 Mon Sep 17 00:00:00 2001
From: clementchadebec
Date: Thu, 13 Apr 2023 18:28:51 +0200
Subject: [PATCH 2/2] put example in good folder

---
 examples/{ => scripts}/hp_tuning_with_ray.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename examples/{ => scripts}/hp_tuning_with_ray.py (100%)

diff --git a/examples/hp_tuning_with_ray.py b/examples/scripts/hp_tuning_with_ray.py
similarity index 100%
rename from examples/hp_tuning_with_ray.py
rename to examples/scripts/hp_tuning_with_ray.py
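
A minimal usage sketch, assuming Ray 2.x, where tuner.fit() returns a ResultGrid
exposing get_best_result(), Result.config, and Result.metrics: after the tuning
run in hp_tuning_with_ray.py completes, the best trial can be read back like so.

# Sketch only: continues from the `results` returned by tuner.fit() above,
# assuming the Ray 2.x ResultGrid/Result API.
best_result = results.get_best_result(metric="eval_epoch_loss", mode="min")
print("best config:", best_result.config)  # sampled values, e.g. "lr" and "momentum"
print("best eval loss:", best_result.metrics["eval_epoch_loss"])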