From 6b5fb29f258c70f7dc322b6b667204e7461746fe Mon Sep 17 00:00:00 2001
From: Filippo Bigi <98903385+frostedoyster@users.noreply.github.com>
Date: Fri, 18 Oct 2024 14:28:32 +0200
Subject: [PATCH] Print learning rate decreases (#365)

---
 .../alchemical_model/default-hypers.yaml      |  4 ++--
 .../experimental/alchemical_model/trainer.py  | 12 +++++++++--
 .../soap_bpnn/default-hypers.yaml             |  4 ++--
 .../experimental/soap_bpnn/trainer.py         | 21 +++++++++++++------
 4 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/src/metatrain/experimental/alchemical_model/default-hypers.yaml b/src/metatrain/experimental/alchemical_model/default-hypers.yaml
index 9ffe7be8..4c7c14fb 100644
--- a/src/metatrain/experimental/alchemical_model/default-hypers.yaml
+++ b/src/metatrain/experimental/alchemical_model/default-hypers.yaml
@@ -19,8 +19,8 @@ training:
   batch_size: 8
   num_epochs: 100
   learning_rate: 0.001
-  early_stopping_patience: 50
-  scheduler_patience: 10
+  early_stopping_patience: 200
+  scheduler_patience: 100
   scheduler_factor: 0.8
   log_interval: 5
   checkpoint_interval: 25
diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py
index 7b722911..3ed190c0 100644
--- a/src/metatrain/experimental/alchemical_model/trainer.py
+++ b/src/metatrain/experimental/alchemical_model/trainer.py
@@ -216,6 +216,10 @@ def train(
         # per-atom targets:
         per_structure_targets = self.hypers["per_structure_targets"]

+        # Log the initial learning rate:
+        old_lr = optimizer.param_groups[0]["lr"]
+        logger.info(f"Initial learning rate: {old_lr}")
+
         start_epoch = 0 if self.epoch is None else self.epoch + 1

         # Train the model:
@@ -322,8 +326,6 @@ def train(
                 )
             )

-            lr_scheduler.step(val_loss)
-
             # Now we log the information:
             finalized_train_info = {"loss": train_loss, **finalized_train_info}
             finalized_val_info = {
                 "loss": val_loss,
                 **finalized_val_info,
             }
@@ -344,6 +346,12 @@ def train(
                 epoch=epoch,
             )

+            lr_scheduler.step(val_loss)
+            new_lr = lr_scheduler.get_last_lr()[0]
+            if new_lr != old_lr:
+                logger.info(f"Changing learning rate from {old_lr} to {new_lr}")
+                old_lr = new_lr
+
             if epoch % self.hypers["checkpoint_interval"] == 0:
                 self.optimizer_state_dict = optimizer.state_dict()
                 self.scheduler_state_dict = lr_scheduler.state_dict()
diff --git a/src/metatrain/experimental/soap_bpnn/default-hypers.yaml b/src/metatrain/experimental/soap_bpnn/default-hypers.yaml
index 33198a80..74bf2301 100644
--- a/src/metatrain/experimental/soap_bpnn/default-hypers.yaml
+++ b/src/metatrain/experimental/soap_bpnn/default-hypers.yaml
@@ -27,8 +27,8 @@ training:
   batch_size: 8
   num_epochs: 100
   learning_rate: 0.001
-  early_stopping_patience: 50
-  scheduler_patience: 10
+  early_stopping_patience: 200
+  scheduler_patience: 100
   scheduler_factor: 0.8
   log_interval: 5
   checkpoint_interval: 25
diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py
index d4f28acb..aed858bc 100644
--- a/src/metatrain/experimental/soap_bpnn/trainer.py
+++ b/src/metatrain/experimental/soap_bpnn/trainer.py
@@ -234,6 +234,10 @@ def train(
         # per-atom targets:
         per_structure_targets = self.hypers["per_structure_targets"]

+        # Log the initial learning rate:
+        old_lr = optimizer.param_groups[0]["lr"]
+        logger.info(f"Initial learning rate: {old_lr}")
+
         start_epoch = 0 if self.epoch is None else self.epoch + 1

         # Train the model:
@@ -357,8 +361,6 @@ def train(
                 )
             )

-            lr_scheduler.step(val_loss)
-
             # Now we log the information:
             finalized_train_info = {"loss": train_loss, **finalized_train_info}
             finalized_val_info = {"loss": val_loss, **finalized_val_info}
@@ -377,16 +379,23 @@ def train(
                 rank=rank,
             )

+            lr_scheduler.step(val_loss)
+            new_lr = lr_scheduler.get_last_lr()[0]
+            if new_lr != old_lr:
+                logger.info(f"Changing learning rate from {old_lr} to {new_lr}")
+                old_lr = new_lr
+
             if epoch % self.hypers["checkpoint_interval"] == 0:
                 if is_distributed:
                     torch.distributed.barrier()
                 self.optimizer_state_dict = optimizer.state_dict()
                 self.scheduler_state_dict = lr_scheduler.state_dict()
                 self.epoch = epoch
-                self.save_checkpoint(
-                    (model.module if is_distributed else model),
-                    Path(checkpoint_dir) / f"model_{epoch}.ckpt",
-                )
+                if rank == 0:
+                    self.save_checkpoint(
+                        (model.module if is_distributed else model),
+                        Path(checkpoint_dir) / f"model_{epoch}.ckpt",
+                    )

             # early stopping criterion:
             if val_loss < best_val_loss:
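
Note: the core pattern introduced by this patch is to step the ReduceLROnPlateau scheduler only after the epoch metrics have been logged, and to print a message whenever the learning rate actually changes. The sketch below illustrates that pattern in isolation; it is not the metatrain trainer itself. The model, the constant val_loss, and the loop bounds are placeholders, and it reads the current rate from optimizer.param_groups[0]["lr"] instead of lr_scheduler.get_last_lr()[0] used in the patch, which yields the same value after step() and also works on older PyTorch versions.

# Minimal sketch of the learning-rate logging pattern (placeholders, not metatrain code).
import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.8, patience=100
)

# Log the initial learning rate, as the patch does:
old_lr = optimizer.param_groups[0]["lr"]
logger.info(f"Initial learning rate: {old_lr}")

for epoch in range(1000):
    val_loss = 1.0  # placeholder: compute the real validation loss here

    # Step the scheduler after the epoch metrics are logged, then report any decrease:
    lr_scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]["lr"]  # current (possibly reduced) learning rate
    if new_lr != old_lr:
        logger.info(f"Changing learning rate from {old_lr} to {new_lr}")
        old_lr = new_lr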