diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 04a7e79cb0..be5b6ffa7d 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -13,11 +13,10 @@ # limitations under the License. import os -import random import textwrap import warnings from collections import defaultdict -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Sized, Union from unittest.mock import patch import torch @@ -69,21 +68,32 @@ class RepeatRandomSampler(Sampler): """ Sampler that repeats the indices of a dataset N times. + + Args: + data_source (`Sized`): + Dataset to sample from. + repeat_count (`int`): + Number of times to repeat each index. + + Example: + ```python + >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2) + >>> list(sampler) + [2, 2, 0, 0, 3, 3, 1, 1] + ``` """ - def __init__(self, data_source, repeat_count): + def __init__(self, data_source: Sized, repeat_count: int): self.data_source = data_source self.repeat_count = repeat_count self.num_samples = len(data_source) def __iter__(self): - while True: - index = random.randint(0, self.num_samples - 1) # Pick a random index - for _ in range(self.repeat_count): - yield index # Yield the same index N times + indexes = [idx for idx in torch.randperm(self.num_samples).tolist() for _ in range(self.repeat_count)] + return iter(indexes) def __len__(self): - return self.num_samples * self.repeat_count # Theoretically infinite, but define a max length if needed + return self.num_samples * self.repeat_count def broadcast_and_slice_dict(accelerator: Accelerator, tensor_dict: dict[str, Tensor], from_process: int = 0):