
Try "warm up" phase (#41)
* Add linear_bump_lr
* Add get_memory_usage_mb function for memory monitoring
matsen authored Jun 19, 2024
1 parent 731258f commit d88b855
Showing 2 changed files with 31 additions and 1 deletion.
27 changes: 27 additions & 0 deletions netam/common.py
@@ -1,6 +1,7 @@
import math
import inspect
import itertools
import resource
import subprocess

import numpy as np
@@ -228,6 +229,12 @@ def print_tensor_devices(scope="local"):
print(f"{var_name}: {var_value.device}")


def get_memory_usage_mb():
    # Returns the peak memory usage in MB
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_maxrss / 1024  # Convert from KB to MB


# Reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
@@ -253,3 +260,23 @@ def forward(self, x: Tensor) -> Tensor:
"""
x = x + self.pe[: x.size(0)]
return self.dropout(x)


def linear_bump_lr(epoch, warmup_epochs, total_epochs, max_lr, min_lr):
    """
    Linearly increase the learning rate from min_lr to max_lr over warmup_epochs,
    then linearly decrease the learning rate from max_lr to min_lr.
    See https://github.com/matsengrp/netam/pull/41 for more details.

    pd.Series([
        linear_bump_lr(epoch, warmup_epochs=20, total_epochs=200, max_lr=0.01, min_lr=1e-5)
        for epoch in range(200)]).plot()
    """
    if epoch < warmup_epochs:
        lr = min_lr + ((max_lr - min_lr) / warmup_epochs) * epoch
    else:
        lr = max_lr - ((max_lr - min_lr) / (total_epochs - warmup_epochs)) * (
            epoch - warmup_epochs
        )
    return lr
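
For context, the sketch below (not part of this commit; the model, data, and epoch counts are illustrative placeholders) shows how the two helpers added to netam/common.py might be used together: linear_bump_lr sets the learning rate on the optimizer each epoch, warming up for the first warmup_epochs and then decaying, while get_memory_usage_mb reports peak memory after each epoch.

import torch

from netam.common import get_memory_usage_mb, linear_bump_lr

# Hypothetical toy setup; netam's real models and data loaders are not shown here.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()
inputs, targets = torch.randn(64, 10), torch.randn(64, 1)

warmup_epochs, total_epochs = 20, 200
for epoch in range(total_epochs):
    # Warm up from min_lr to max_lr over the first 20 epochs, then decay back down.
    lr = linear_bump_lr(epoch, warmup_epochs, total_epochs, max_lr=0.01, min_lr=1e-5)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

    print(f"epoch {epoch}: lr={lr:.2e}, peak memory={get_memory_usage_mb():.1f} MB")

Note that at epoch 0 this schedule starts at min_lr, which is presumably why the framework.py change below restricts the stop-on-small-learning-rate check to plateau scheduling.
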
5 changes: 4 additions & 1 deletion netam/framework.py
@@ -581,7 +581,10 @@ def record_losses(train_loss, val_loss):
        with tqdm(range(1, epochs + 1), desc="Epoch") as pbar:
            for epoch in pbar:
                current_lr = self.optimizer.param_groups[0]["lr"]
-               if current_lr < self.min_learning_rate:
+               if (
+                   isinstance(self.scheduler, ReduceLROnPlateau)
+                   and current_lr < self.min_learning_rate
+               ):
                    break

                if self.device.type == "cuda":
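
The isinstance(self.scheduler, ReduceLROnPlateau) guard added above restricts the stop-on-tiny-learning-rate check to plateau scheduling, where the learning rate only ever shrinks and falling below min_learning_rate is a natural convergence signal; under the new warm-up schedule the learning rate intentionally starts near its minimum, so the unguarded check could end training before warm-up even begins. A minimal standalone sketch of the preserved plateau behavior (hypothetical model, thresholds, and epoch count; not the netam training loop):

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
min_learning_rate = 1e-6

for epoch in range(500):
    val_loss = 1.0  # stand-in for a validation loss that has stopped improving
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]["lr"]
    # The learning rate is halved each time the loss plateaus, so once it falls
    # below the floor there is nothing left to try and training can stop.
    if isinstance(scheduler, ReduceLROnPlateau) and current_lr < min_learning_rate:
        print(f"stopping at epoch {epoch}, lr={current_lr:.2e}")
        break
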
