diff --git a/tensordict/base.py b/tensordict/base.py index 63c1b197e..deaa93ba8 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -2914,11 +2914,12 @@ def map( self, fn: Callable, dim: int = 0, - num_workers: int = None, - chunksize: int = None, - num_chunks: int = None, - pool: mp.Pool = None, - seed: int = None, + num_workers: int | None = None, + chunksize: int | None = None, + num_chunks: int | None = None, + pool: mp.Pool | None = None, + seed: int | None = None, + maxtasksperchild: int | None = None, ): """Maps a function to splits of the tensordict across one dimension. @@ -2975,6 +2976,9 @@ def map( know which worker will pick which job. However, we can make sure that each worker has a different seed and that the pseudo-random operations on each will be uncorrelated. + maxtasksperchild (int, optional): the maximum number of jobs picked + by every child process. Defaults to ``None``, i.e., no restriction + on the number of jobs. Examples: >>> import torch @@ -3009,7 +3013,10 @@ def map( for i in range(num_workers): queue.put(i) with mp.Pool( - num_workers, initializer=_proc_init, initargs=(seed, queue) + num_workers, + initializer=_proc_init, + initargs=(seed, queue), + maxtasksperchild=maxtasksperchild, ) as pool: return self.map( fn, dim=dim, chunksize=chunksize, num_chunks=num_chunks, pool=pool diff --git a/test/test_tensordict.py b/test/test_tensordict.py index cd88880d0..83ca117d5 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -6572,8 +6572,12 @@ def test_map_seed(self): # we use 20 workers to make sure that each worker has one item to work with # Using less could cause undeterministic behaviour depending on the workers' # speed, since we cannot tell who will pick which job. - td_out_0 = td.map(self.get_rand_incr, num_workers=20, seed=0, chunksize=1) - td_out_1 = td.map(self.get_rand_incr, num_workers=20, seed=0, chunksize=1) + td_out_0 = td.map( + self.get_rand_incr, num_workers=20, seed=0, chunksize=1, maxtasksperchild=1 + ) + td_out_1 = td.map( + self.get_rand_incr, num_workers=20, seed=0, chunksize=1, maxtasksperchild=1 + ) # we cannot know which worker picks which job, but since they will all have # a seed from 0 to 4 and produce 1 number each, we can chekc that # those numbers are exactly what we were expecting.