Bug fixes and feature patches from @SMERKY
pic-o committed Jan 25, 2024
1 parent a1b1046 commit 6e89d63
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions RWKV-v5/src/model.py
@@ -183,6 +183,9 @@ def __init__(self,
             lr_final: float = -1.0,
             lr_period: int = -1,
             lr_period_type: str = 'epoch',
+            # Use either "cosine" or "linear"
+            lr_type: str = 'cosine',
+
             # Dropout rate
             dropout: float = 0.0,
             # Adam optimizer settings
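The new lr_type keyword sits alongside the existing learning-rate options, so the schedule shape is chosen at construction time. A minimal usage sketch, assuming the Lightning module defined in this file is named RWKV and instantiated directly; every value shown is illustrative, only lr_type is the option this commit adds:

    # Hypothetical values; lr_type is the new option, "linear" was the previous (and only) behavior.
    model = RWKV(
        lr_init=6e-4,
        lr_final=1e-5,
        lr_type="cosine",
    )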
@@ -271,6 +274,7 @@ def __init__(self,
         self.lr_final = lr_final
         self.lr_period = lr_period
         self.lr_period_type = lr_period_type
+        self.lr_type = lr_type
         self.dropout = dropout
         self.warmup_steps = warmup_steps
         self.beta1 = beta1
@@ -521,12 +525,21 @@ def configure_optimizers(self):
             raise ValueError(f"lr_period_type {self.lr_period_type} not supported.")
 
         # Let's initialize the lr_scheduler
-        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
-            optimizer,
-            start_factor=1.0,
-            end_factor=lr_final / lr_init,
-            total_iters=lr_total_step
-        )
+        if self.lr_type == "cosine":
+            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                optimizer,
+                T_max=lr_total_step,
+                eta_min=lr_final
+            )
+        elif self.lr_type == "linear":
+            lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+                optimizer,
+                start_factor=1.0,
+                end_factor=lr_final / lr_init,
+                total_iters=lr_total_step
+            )
+        else:
+            raise ValueError(f"lr_type {self.lr_type} not supported.")
 
         return {
             'optimizer': optimizer,
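The two branches decay toward the same lr_final floor but with different shapes: cosine holds close to lr_init early and drops steeply late, while linear falls at a constant rate. A standalone sketch of that difference in plain PyTorch, with illustrative values standing in for lr_init, lr_final, and lr_total_step:

    import torch

    lr_init, lr_final, lr_total_step = 1e-3, 1e-5, 8

    for lr_type in ("cosine", "linear"):
        # Dummy parameter and optimizer so the schedulers have something to drive.
        param = torch.nn.Parameter(torch.zeros(1))
        optimizer = torch.optim.Adam([param], lr=lr_init)
        if lr_type == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=lr_total_step, eta_min=lr_final)
        else:
            scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=1.0,
                end_factor=lr_final / lr_init, total_iters=lr_total_step)
        trace = []
        for _ in range(lr_total_step):
            optimizer.step()
            scheduler.step()
            trace.append(scheduler.get_last_lr()[0])
        print(lr_type, ["%.1e" % lr for lr in trace])

Both schedules reach lr_final after lr_total_step steps; only the path there differs.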
@@ -566,7 +579,8 @@ def num_step_per_epoch(self) -> int:
         dataset_size = len(train_dataloader)
 
         num_devices = max(1, self.trainer.num_devices)
-        num_steps = dataset_size // (self.trainer.accumulate_grad_batches * num_devices)
+        num_nodes = max(1, self.trainer.num_nodes)
+        num_steps = dataset_size // (self.trainer.accumulate_grad_batches * num_devices * num_nodes)
         return num_steps
 
     @property
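The bug fix here folds the node count into the steps-per-epoch estimate: under distributed training, the dataloader's batches are sharded across every device on every node, so ignoring num_nodes overestimates the optimizer step count on multi-node runs. A worked example with assumed numbers:

    # Illustrative values only.
    dataset_size = 10_000            # batches yielded by the train dataloader
    accumulate_grad_batches = 4
    num_devices = 8                  # devices (GPUs) per node
    num_nodes = 2

    old_estimate = dataset_size // (accumulate_grad_batches * num_devices)              # 312
    new_estimate = dataset_size // (accumulate_grad_batches * num_devices * num_nodes)  # 156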
