Skip to content

Commit

Permalink
maxtasksperchild
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 22, 2023
1 parent b607c0e commit aa42433
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
19 changes: 13 additions & 6 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2914,11 +2914,12 @@ def map(
self,
fn: Callable,
dim: int = 0,
num_workers: int = None,
chunksize: int = None,
num_chunks: int = None,
pool: mp.Pool = None,
seed: int = None,
num_workers: int | None = None,
chunksize: int | None = None,
num_chunks: int | None = None,
pool: mp.Pool | None = None,
seed: int | None = None,
maxtasksperchild: int | None = None,
):
"""Maps a function to splits of the tensordict across one dimension.
Expand Down Expand Up @@ -2975,6 +2976,9 @@ def map(
know which worker will pick which job. However, we can make sure
that each worker has a different seed and that the pseudo-random
operations on each will be uncorrelated.
maxtasksperchild (int, optional): the maximum number of jobs each
child process can complete before it exits and is replaced by a
fresh worker process. Defaults to ``None``, i.e., no restriction
on the number of jobs (workers live as long as the pool).
Examples:
>>> import torch
Expand Down Expand Up @@ -3009,7 +3013,10 @@ def map(
for i in range(num_workers):
queue.put(i)
with mp.Pool(
num_workers, initializer=_proc_init, initargs=(seed, queue)
num_workers,
initializer=_proc_init,
initargs=(seed, queue),
maxtasksperchild=maxtasksperchild,
) as pool:
return self.map(
fn, dim=dim, chunksize=chunksize, num_chunks=num_chunks, pool=pool
Expand Down
8 changes: 6 additions & 2 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6572,8 +6572,12 @@ def test_map_seed(self):
# we use 20 workers to make sure that each worker has one item to work with
# Using fewer could cause nondeterministic behaviour depending on the workers'
# speed, since we cannot tell who will pick which job.
td_out_0 = td.map(self.get_rand_incr, num_workers=20, seed=0, chunksize=1)
td_out_1 = td.map(self.get_rand_incr, num_workers=20, seed=0, chunksize=1)
td_out_0 = td.map(
self.get_rand_incr, num_workers=20, seed=0, chunksize=1, maxtasksperchild=1
)
td_out_1 = td.map(
self.get_rand_incr, num_workers=20, seed=0, chunksize=1, maxtasksperchild=1
)
# we cannot know which worker picks which job, but since they will all have
# a seed from 0 to 4 and produce 1 number each, we can check that
# those numbers are exactly what we were expecting.
Expand Down

0 comments on commit aa42433

Please sign in to comment.