Skip to content

Commit

Permalink
fix and document RepeatRandomSampler
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Feb 5, 2025
1 parent 0b131b1 commit a38231f
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# limitations under the License.

import os
import random
import textwrap
import warnings
from collections import defaultdict
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Sized, Union
from unittest.mock import patch

import torch
Expand Down Expand Up @@ -69,21 +68,32 @@
class RepeatRandomSampler(Sampler):
    """
    Sampler that repeats the indices of a dataset N times.

    On each pass, every index of the dataset is drawn exactly once, in a random order, and is yielded
    `repeat_count` times in a row before the next index is emitted.

    Args:
        data_source (`Sized`):
            Dataset to sample from.
        repeat_count (`int`):
            Number of times to repeat each index.
        seed (`int` or `None`, *optional*, defaults to `None`):
            Seed for a private random generator, making the sampling order reproducible. When `None`, the
            global torch random state is used.

    Example:
    ```python
    >>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
    >>> list(sampler)  # the order of the groups is random
    [2, 2, 0, 0, 3, 3, 1, 1]
    ```
    """

    def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
        self.data_source = data_source
        self.repeat_count = repeat_count
        self.num_samples = len(data_source)
        self.seed = seed
        # Only build a dedicated generator when a seed is requested: a fresh `torch.Generator()` always starts
        # from the same default state, which would silently make the unseeded case deterministic.
        self.generator = None
        if seed is not None:
            self.generator = torch.Generator()
            self.generator.manual_seed(seed)

    def __iter__(self):
        # A permutation (rather than independent random draws) guarantees that each index is visited exactly
        # once per pass and that the iterator is finite, in agreement with `__len__`.
        permutation = torch.randperm(self.num_samples, generator=self.generator).tolist()
        return iter([idx for idx in permutation for _ in range(self.repeat_count)])

    def __len__(self):
        return self.num_samples * self.repeat_count


def broadcast_and_slice_dict(accelerator: Accelerator, tensor_dict: dict[str, Tensor], from_process: int = 0):
Expand Down

0 comments on commit a38231f

Please sign in to comment.