diff --git a/RWKV-v5/src/model.py b/RWKV-v5/src/model.py
index a7f9f794..f71f46c0 100644
--- a/RWKV-v5/src/model.py
+++ b/RWKV-v5/src/model.py
@@ -183,6 +183,9 @@ def __init__(self,
                  lr_final: float = -1.0,
                  lr_period: int = -1,
                  lr_period_type: str = 'epoch',
+                 # Use either "cosine" or "linear"
+                 lr_type: str = 'cosine',
+                 # Dropout rate
                  dropout: float = 0.0,
 
                  # Adam optimizer settings
@@ -271,6 +274,7 @@ def __init__(self,
         self.lr_final = lr_final
         self.lr_period = lr_period
         self.lr_period_type = lr_period_type
+        self.lr_type = lr_type
         self.dropout = dropout
         self.warmup_steps = warmup_steps
         self.beta1 = beta1
@@ -521,12 +525,21 @@ def configure_optimizers(self):
             raise ValueError(f"lr_period_type {self.lr_period_type} not supported.")
 
         # Lets initialize the lr_scheduler
-        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
-            optimizer,
-            start_factor=1.0,
-            end_factor= lr_final / lr_init,
-            total_iters=lr_total_step
-        )
+        if self.lr_type == "cosine":
+            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                optimizer,
+                T_max=lr_total_step,
+                eta_min=lr_final
+            )
+        elif self.lr_type == "linear":
+            lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+                optimizer,
+                start_factor=1.0,
+                end_factor=lr_final / lr_init,
+                total_iters=lr_total_step
+            )
+        else:
+            raise ValueError(f"lr_type {self.lr_type} not supported.")
 
         return {
             'optimizer': optimizer,
@@ -566,7 +579,8 @@ def num_step_per_epoch(self) -> int:
         dataset_size = len(train_dataloader)
 
         num_devices = max(1, self.trainer.num_devices)
-        num_steps = dataset_size // (self.trainer.accumulate_grad_batches * num_devices)
+        num_nodes = max(1, self.trainer.num_nodes)
+        num_steps = dataset_size // (self.trainer.accumulate_grad_batches * num_devices * num_nodes)
         return num_steps
 
     @property
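
To see what the new `lr_type` switch actually selects between, here is a minimal standalone sketch in plain PyTorch, not the module's own code. `lr_init`, `lr_final`, and `lr_total_step` are hypothetical stand-ins for the values `configure_optimizers` derives from the trainer settings.

import torch

# Hypothetical settings standing in for the module's configured values.
lr_init, lr_final, lr_total_step = 1e-3, 1e-5, 10

param = torch.nn.Parameter(torch.zeros(1))
for lr_type in ("cosine", "linear"):
    optimizer = torch.optim.SGD([param], lr=lr_init)
    if lr_type == "cosine":
        # Half-cosine decay from lr_init down to eta_min over T_max steps.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=lr_total_step, eta_min=lr_final)
    else:
        # Linear decay from lr_init to lr_init * end_factor over total_iters.
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1.0,
            end_factor=lr_final / lr_init, total_iters=lr_total_step)
    lrs = []
    for _ in range(lr_total_step):
        optimizer.step()
        scheduler.step()
        lrs.append(optimizer.param_groups[0]["lr"])
    print(lr_type, ["%.2e" % lr for lr in lrs])

Both schedules land on `lr_final` at the last step; cosine stays near `lr_init` longer and does most of its decay mid-run, which is the usual argument for making it the default.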
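
The `num_step_per_epoch` change is easiest to check with toy numbers. The figures below are hypothetical; the point is only that the step count must be divided by the full world size (devices times nodes) rather than by devices alone, otherwise multi-node runs over-count optimizer steps per epoch.

# Hypothetical multi-node setup.
dataset_size = 8192            # batches reported by len(train_dataloader)
accumulate_grad_batches = 4    # micro-batches folded into one optimizer step
num_devices = 8                # GPUs per node
num_nodes = 2                  # nodes in the cluster

# Old formula: ignores num_nodes, so a 2-node run reports twice as many steps.
old_steps = dataset_size // (accumulate_grad_batches * num_devices)              # 256
# Fixed formula: divides by the full world size.
new_steps = dataset_size // (accumulate_grad_batches * num_devices * num_nodes)  # 128
print(old_steps, new_steps)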