Print learning rate decreases (#365)
frostedoyster authored Oct 18, 2024
1 parent fc17b49 commit 6b5fb29
Showing 4 changed files with 29 additions and 12 deletions.
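
The core change in this commit is that both trainers now report when the learning-rate scheduler (presumably a torch.optim.lr_scheduler.ReduceLROnPlateau, given the scheduler_patience/scheduler_factor hypers and the step(val_loss) call) lowers the learning rate: the initial rate is logged once, lr_scheduler.step(val_loss) is moved after the per-epoch logging, and a message is emitted whenever the rate changes. Below is a minimal, self-contained sketch of the same pattern; the model, optimizer, and synthetic validation loss are placeholders rather than metatrain code, and the current rate is read from optimizer.param_groups instead of get_last_lr() so the snippet runs on any PyTorch version.

import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Placeholder model and optimizer; the real trainers build these from hypers.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Small patience here so the decrease actually fires in a short demo run;
# the new defaults use patience=100.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.8, patience=10
)

# Log the initial learning rate once, before the epoch loop.
old_lr = optimizer.param_groups[0]["lr"]
logger.info(f"Initial learning rate: {old_lr}")

for epoch in range(200):
    # Synthetic validation loss: improves for 50 epochs, then plateaus.
    val_loss = max(1.0 - 0.01 * epoch, 0.5)

    # Step the scheduler on the validation loss, then report only when the
    # learning rate actually changed.
    lr_scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]["lr"]
    if new_lr != old_lr:
        logger.info(f"Changing learning rate from {old_lr} to {new_lr}")
        old_lr = new_lr
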
4 changes: 2 additions & 2 deletions src/metatrain/experimental/alchemical_model/default-hypers.yaml

@@ -19,8 +19,8 @@ training:
   batch_size: 8
   num_epochs: 100
   learning_rate: 0.001
-  early_stopping_patience: 50
-  scheduler_patience: 10
+  early_stopping_patience: 200
+  scheduler_patience: 100
   scheduler_factor: 0.8
   log_interval: 5
   checkpoint_interval: 25
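
The patience and factor hypers above presumably feed the ReduceLROnPlateau scheduler that the trainer hunks below step with the validation loss. A hedged sketch of how such hypers could be wired into the scheduler constructor follows; the plain dict and the Adam optimizer are illustrative assumptions, not the Trainer's actual setup.

import torch

# Values mirror the new YAML defaults above; treating them as a plain dict
# is an assumption made for illustration.
hypers = {
    "learning_rate": 0.001,
    "scheduler_patience": 100,
    "scheduler_factor": 0.8,
}

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=hypers["learning_rate"])

# With patience raised from 10 to 100, the learning rate is multiplied by 0.8
# only after 100 consecutive epochs without a validation-loss improvement.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=hypers["scheduler_factor"],
    patience=hypers["scheduler_patience"],
)

The early_stopping_patience hyper (raised from 50 to 200) is used by the trainer's separate early-stopping check, visible at the end of the soap_bpnn trainer hunk below, rather than by the scheduler.
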
12 changes: 10 additions & 2 deletions src/metatrain/experimental/alchemical_model/trainer.py

@@ -216,6 +216,10 @@ def train(
         # per-atom targets:
         per_structure_targets = self.hypers["per_structure_targets"]

+        # Log the initial learning rate:
+        old_lr = optimizer.param_groups[0]["lr"]
+        logger.info(f"Initial learning rate: {old_lr}")
+
         start_epoch = 0 if self.epoch is None else self.epoch + 1

         # Train the model:
@@ -322,8 +326,6 @@ def train(
                 )
             )

-            lr_scheduler.step(val_loss)
-
             # Now we log the information:
             finalized_train_info = {"loss": train_loss, **finalized_train_info}
             finalized_val_info = {
@@ -344,6 +346,12 @@
                 epoch=epoch,
             )

+            lr_scheduler.step(val_loss)
+            new_lr = lr_scheduler.get_last_lr()[0]
+            if new_lr != old_lr:
+                logger.info(f"Changing learning rate from {old_lr} to {new_lr}")
+                old_lr = new_lr
+
             if epoch % self.hypers["checkpoint_interval"] == 0:
                 self.optimizer_state_dict = optimizer.state_dict()
                 self.scheduler_state_dict = lr_scheduler.state_dict()
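
The added logging reads the new rate with lr_scheduler.get_last_lr()[0]. A small self-contained sketch (placeholder model and optimizer) of two ways to read the current learning rate; older PyTorch releases did not expose get_last_lr() on ReduceLROnPlateau, whereas reading it off the optimizer works everywhere.

import torch

model = torch.nn.Linear(4, 1)  # placeholder
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.8, patience=100
)

# Always available: the optimizer itself holds the current learning rate.
current_lr = optimizer.param_groups[0]["lr"]

# get_last_lr(), as used in the diff, exists on schedulers that inherit from
# LRScheduler; guard for it when the PyTorch version is not pinned.
if hasattr(lr_scheduler, "get_last_lr"):
    current_lr = lr_scheduler.get_last_lr()[0]

print(f"Current learning rate: {current_lr}")
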
4 changes: 2 additions & 2 deletions src/metatrain/experimental/soap_bpnn/default-hypers.yaml

@@ -27,8 +27,8 @@ training:
   batch_size: 8
   num_epochs: 100
   learning_rate: 0.001
-  early_stopping_patience: 50
-  scheduler_patience: 10
+  early_stopping_patience: 200
+  scheduler_patience: 100
   scheduler_factor: 0.8
   log_interval: 5
   checkpoint_interval: 25
21 changes: 15 additions & 6 deletions src/metatrain/experimental/soap_bpnn/trainer.py

@@ -234,6 +234,10 @@ def train(
         # per-atom targets:
         per_structure_targets = self.hypers["per_structure_targets"]

+        # Log the initial learning rate:
+        old_lr = optimizer.param_groups[0]["lr"]
+        logger.info(f"Initial learning rate: {old_lr}")
+
         start_epoch = 0 if self.epoch is None else self.epoch + 1

         # Train the model:
@@ -357,8 +361,6 @@ def train(
                 )
             )

-            lr_scheduler.step(val_loss)
-
             # Now we log the information:
             finalized_train_info = {"loss": train_loss, **finalized_train_info}
             finalized_val_info = {"loss": val_loss, **finalized_val_info}
@@ -377,16 +379,23 @@
                 rank=rank,
             )

+            lr_scheduler.step(val_loss)
+            new_lr = lr_scheduler.get_last_lr()[0]
+            if new_lr != old_lr:
+                logger.info(f"Changing learning rate from {old_lr} to {new_lr}")
+                old_lr = new_lr
+
             if epoch % self.hypers["checkpoint_interval"] == 0:
                 if is_distributed:
                     torch.distributed.barrier()
                 self.optimizer_state_dict = optimizer.state_dict()
                 self.scheduler_state_dict = lr_scheduler.state_dict()
                 self.epoch = epoch
-                self.save_checkpoint(
-                    (model.module if is_distributed else model),
-                    Path(checkpoint_dir) / f"model_{epoch}.ckpt",
-                )
+                if rank == 0:
+                    self.save_checkpoint(
+                        (model.module if is_distributed else model),
+                        Path(checkpoint_dir) / f"model_{epoch}.ckpt",
+                    )

             # early stopping criterion:
             if val_loss < best_val_loss:
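
Beyond the learning-rate logging, the soap_bpnn trainer's checkpoint block now writes only from rank 0 when running distributed. Below is a minimal sketch of that pattern with a stand-in for Trainer.save_checkpoint; the function name and checkpoint contents are illustrative assumptions, and the model is assumed to be wrapped in DistributedDataParallel when distributed.

from pathlib import Path

import torch
import torch.distributed


def save_checkpoint_rank_zero(model, optimizer, lr_scheduler, epoch, checkpoint_dir):
    """Write one checkpoint per epoch from rank 0 only (illustrative helper)."""
    is_distributed = (
        torch.distributed.is_available() and torch.distributed.is_initialized()
    )
    rank = torch.distributed.get_rank() if is_distributed else 0

    if is_distributed:
        # Let every rank finish the epoch before rank 0 writes to disk.
        torch.distributed.barrier()

    if rank == 0:
        # Unwrap DistributedDataParallel so the checkpoint holds the bare model.
        to_save = model.module if is_distributed else model
        torch.save(
            {
                "model_state_dict": to_save.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "lr_scheduler_state_dict": lr_scheduler.state_dict(),
                "epoch": epoch,
            },
            Path(checkpoint_dir) / f"model_{epoch}.ckpt",
        )

Without the rank guard, every replica would write the same checkpoint file; the barrier plus rank check in the diff avoids that.
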
