From 38fee4288b33c25a3167fa3e3175bc2381526542 Mon Sep 17 00:00:00 2001 From: Gert-Jan Both Date: Fri, 26 Apr 2024 13:36:21 -0400 Subject: [PATCH 1/7] Updated MultiSyncDataCollector --- torchrl/collectors/collectors.py | 44 ++++++++++++++++++-------------- torchrl/data/tensor_specs.py | 10 ++++++-- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 2364f70ac0c..947fea59b8a 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1884,28 +1884,35 @@ class MultiSyncDataCollector(_MultiDataCollector): trajectory and the start of the next collection. This class can be safely used with online RL sota-implementations. + note:: Python requires multiprocessed code to be instantiated within a + ```if __name__ == "__main__":``` block. See https://docs.python.org/3/library/multiprocessing.html + for more info. + + Examples: >>> from torchrl.envs.libs.gym import GymEnv - >>> from torchrl.envs import StepCounter >>> from tensordict.nn import TensorDictModule >>> from torch import nn - >>> env_maker = lambda: TransformedEnv(GymEnv("Pendulum-v1", device="cpu"), StepCounter(max_steps=50)) - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = MultiSyncDataCollector( - ... create_env_fn=[env_maker, env_maker], - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... devices="cpu", - ... storing_devices="cpu", - ... ) - >>> for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break + >>> from torchrl.collectors import MultiSyncDataCollector + >>> if __name__ == "__main__": + >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[env_maker, env_maker], + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... cat_results="stack", + ... ) + >>> for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... 
break TensorDict( fields={ action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), @@ -1987,7 +1994,6 @@ def _queue_len(self) -> int: return self.num_workers def iterator(self) -> Iterator[TensorDictBase]: - cat_results = self.cat_results if cat_results is None: cat_results = 0 diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 49604f7024e..6f8443e45b0 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1597,12 +1597,18 @@ def __init__( if high is not None: raise TypeError(self.CONFLICTING_KWARGS.format("high", "maximum")) high = kwargs.pop("maximum") - warnings.warn("Maximum is deprecated since v0.4.0, using high instead.", category=DeprecationWarning) + warnings.warn( + "Maximum is deprecated since v0.4.0, using high instead.", + category=DeprecationWarning, + ) if "minimum" in kwargs: if low is not None: raise TypeError(self.CONFLICTING_KWARGS.format("low", "minimum")) low = kwargs.pop("minimum") - warnings.warn("Minimum is deprecated since v0.4.0, using low instead.", category=DeprecationWarning) + warnings.warn( + "Minimum is deprecated since v0.4.0, using low instead.", + category=DeprecationWarning, + ) domain = kwargs.pop("domain", "continuous") if len(kwargs): raise TypeError(f"Got unrecognised kwargs {tuple(kwargs.keys())}.") From 6cbc7387785650800c69d3b7bdc6405b6e30a04e Mon Sep 17 00:00:00 2001 From: Gert-Jan Both Date: Fri, 26 Apr 2024 14:01:47 -0400 Subject: [PATCH 2/7] Update multiasync docstring. --- torchrl/collectors/collectors.py | 48 +++++++++++++++++++------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 947fea59b8a..3a18df47cb0 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -535,7 +535,10 @@ def __init__( self.reset_when_done = reset_when_done self.n_env = self.env.batch_size.numel() - (self.policy, self.get_weights_fn,) = self._get_policy_and_device( + ( + self.policy, + self.get_weights_fn, + ) = self._get_policy_and_device( policy=policy, observation_spec=self.env.observation_spec, ) @@ -2238,27 +2241,34 @@ class MultiaSyncDataCollector(_MultiDataCollector): the batch of rollouts is collected and the next call to the iterator. This class can be safely used with offline RL sota-implementations. + note:: Python requires multiprocessed code to be instantiated within a + ```if __name__ == "__main__":``` block. See https://docs.python.org/3/library/multiprocessing.html + for more info. + Examples: - >>> from torchrl.envs.libs.gym import GymEnv + >>> from torchrl.envs.libs.gym import GymEnv >>> from tensordict.nn import TensorDictModule >>> from torch import nn - >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = MultiaSyncDataCollector( - ... create_env_fn=[env_maker, env_maker], - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... devices="cpu", - ... storing_devices="cpu", - ... ) - >>> for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... 
break + >>> from torchrl.collectors import MultiaSyncDataCollector + >>> if __name__ == "__main__": + >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = MultiaSyncDataCollector( + ... create_env_fn=[env_maker, env_maker], + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... cat_results="stack", + ... ) + >>> for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... break TensorDict( fields={ action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), From f9960afbc7109231c2182a1d7ab3471738757e05 Mon Sep 17 00:00:00 2001 From: Gert-Jan Both Date: Fri, 26 Apr 2024 15:14:09 -0400 Subject: [PATCH 3/7] Code block to inline. --- torchrl/collectors/collectors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 3a18df47cb0..bc682b86e56 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1888,7 +1888,7 @@ class MultiSyncDataCollector(_MultiDataCollector): This class can be safely used with online RL sota-implementations. note:: Python requires multiprocessed code to be instantiated within a - ```if __name__ == "__main__":``` block. See https://docs.python.org/3/library/multiprocessing.html + `if __name__ == "__main__":` block. See https://docs.python.org/3/library/multiprocessing.html for more info. @@ -2242,7 +2242,7 @@ class MultiaSyncDataCollector(_MultiDataCollector): This class can be safely used with offline RL sota-implementations. note:: Python requires multiprocessed code to be instantiated within a - ```if __name__ == "__main__":``` block. See https://docs.python.org/3/library/multiprocessing.html + `if __name__ == "__main__":` block. See https://docs.python.org/3/library/multiprocessing.html for more info. Examples: From ba10554393c8efdedb99579970a31c8b004cbeb3 Mon Sep 17 00:00:00 2001 From: Gert-Jan Both <32122273+GJBoth@users.noreply.github.com> Date: Mon, 29 Apr 2024 15:21:17 -0400 Subject: [PATCH 4/7] Update torchrl/collectors/collectors.py Co-authored-by: Vincent Moens --- torchrl/collectors/collectors.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index bc682b86e56..df251722dc1 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1887,9 +1887,13 @@ class MultiSyncDataCollector(_MultiDataCollector): trajectory and the start of the next collection. This class can be safely used with online RL sota-implementations. - note:: Python requires multiprocessed code to be instantiated within a - `if __name__ == "__main__":` block. See https://docs.python.org/3/library/multiprocessing.html - for more info. + .. note:: Python requires multiprocessed code to be instantiated within a main guard: + + >>> if __name__ == "__main__": + ... # Create your collector here + + See https://docs.python.org/3/library/multiprocessing.html for more info. 
+ Examples: From 1fae6d2776512a6998f79f86ec78791fc4e7e4f4 Mon Sep 17 00:00:00 2001 From: Gert-Jan Both <32122273+GJBoth@users.noreply.github.com> Date: Mon, 29 Apr 2024 15:21:42 -0400 Subject: [PATCH 5/7] Update torchrl/collectors/collectors.py Co-authored-by: Vincent Moens --- torchrl/collectors/collectors.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index df251722dc1..43841125fda 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1902,21 +1902,21 @@ class MultiSyncDataCollector(_MultiDataCollector): >>> from torch import nn >>> from torchrl.collectors import MultiSyncDataCollector >>> if __name__ == "__main__": - >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = MultiSyncDataCollector( - ... create_env_fn=[env_maker, env_maker], - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... cat_results="stack", - ... ) - >>> for i, data in enumerate(collector): + ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + ... collector = MultiSyncDataCollector( + ... create_env_fn=[env_maker, env_maker], + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... cat_results="stack", + ... ) + ... for i, data in enumerate(collector): ... if i == 2: ... print(data) ... break From 1029f115344328b9e1b6bb914bc065af948c785a Mon Sep 17 00:00:00 2001 From: Gert-Jan Both Date: Mon, 29 Apr 2024 15:25:17 -0400 Subject: [PATCH 6/7] Update sync and indent shutdown. --- torchrl/collectors/collectors.py | 42 ++++++++++++++++---------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 43841125fda..0c3aab73077 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1946,8 +1946,8 @@ class MultiSyncDataCollector(_MultiDataCollector): batch_size=torch.Size([200]), device=cpu, is_shared=False) - >>> collector.shutdown() - >>> del collector + ... collector.shutdown() + ... del collector """ @@ -2249,27 +2249,27 @@ class MultiaSyncDataCollector(_MultiDataCollector): `if __name__ == "__main__":` block. See https://docs.python.org/3/library/multiprocessing.html for more info. - Examples: - >>> from torchrl.envs.libs.gym import GymEnv + Examples: + >>> from torchrl.envs.libs.gym import GymEnv >>> from tensordict.nn import TensorDictModule >>> from torch import nn >>> from torchrl.collectors import MultiaSyncDataCollector >>> if __name__ == "__main__": - >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = MultiaSyncDataCollector( - ... create_env_fn=[env_maker, env_maker], - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... 
storing_device="cpu", - ... cat_results="stack", - ... ) - >>> for i, data in enumerate(collector): + ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + ... collector = MultiaSyncDataCollector( + ... create_env_fn=[env_maker, env_maker], + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... cat_results="stack", + ... ) + ... for i, data in enumerate(collector): ... if i == 2: ... print(data) ... break @@ -2299,8 +2299,8 @@ class MultiaSyncDataCollector(_MultiDataCollector): batch_size=torch.Size([200]), device=cpu, is_shared=False) - >>> collector.shutdown() - >>> del collector + ... collector.shutdown() + ... del collector """ From d425c25e491e252c2e9b981eda270fca0b74bb0a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 29 Apr 2024 21:28:39 +0100 Subject: [PATCH 7/7] amend --- torchrl/collectors/collectors.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 0c3aab73077..b17a0fbe736 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -535,10 +535,7 @@ def __init__( self.reset_when_done = reset_when_done self.n_env = self.env.batch_size.numel() - ( - self.policy, - self.get_weights_fn, - ) = self._get_policy_and_device( + (self.policy, self.get_weights_fn,) = self._get_policy_and_device( policy=policy, observation_spec=self.env.observation_spec, ) @@ -1889,13 +1886,12 @@ class MultiSyncDataCollector(_MultiDataCollector): .. note:: Python requires multiprocessed code to be instantiated within a main guard: + >>> from torchrl.collectors import MultiSyncDataCollector >>> if __name__ == "__main__": ... # Create your collector here See https://docs.python.org/3/library/multiprocessing.html for more info. - - Examples: >>> from torchrl.envs.libs.gym import GymEnv >>> from tensordict.nn import TensorDictModule @@ -1920,6 +1916,8 @@ class MultiSyncDataCollector(_MultiDataCollector): ... if i == 2: ... print(data) ... break + ... collector.shutdown() + ... del collector TensorDict( fields={ action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), @@ -1946,8 +1944,6 @@ class MultiSyncDataCollector(_MultiDataCollector): batch_size=torch.Size([200]), device=cpu, is_shared=False) - ... collector.shutdown() - ... del collector """ @@ -2245,11 +2241,15 @@ class MultiaSyncDataCollector(_MultiDataCollector): the batch of rollouts is collected and the next call to the iterator. This class can be safely used with offline RL sota-implementations. - note:: Python requires multiprocessed code to be instantiated within a - `if __name__ == "__main__":` block. See https://docs.python.org/3/library/multiprocessing.html - for more info. + .. note:: Python requires multiprocessed code to be instantiated within a main guard: - Examples: + >>> from torchrl.collectors import MultiaSyncDataCollector + >>> if __name__ == "__main__": + ... # Create your collector here + + See https://docs.python.org/3/library/multiprocessing.html for more info. + + Examples: >>> from torchrl.envs.libs.gym import GymEnv >>> from tensordict.nn import TensorDictModule >>> from torch import nn @@ -2273,6 +2273,8 @@ class MultiaSyncDataCollector(_MultiDataCollector): ... 
if i == 2: ... print(data) ... break + ... collector.shutdown() + ... del collector TensorDict( fields={ action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), @@ -2299,8 +2301,6 @@ class MultiaSyncDataCollector(_MultiDataCollector): batch_size=torch.Size([200]), device=cpu, is_shared=False) - ... collector.shutdown() - ... del collector """
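
For readers who want to exercise the pattern this series converges on outside of a doctest, the final MultiSyncDataCollector example can be run as a standalone script. The following is a minimal sketch distilled from the docstring above, not an official snippet: it assumes a working gym/gymnasium install with Pendulum-v1, and the frame counts are the docstring's illustrative values rather than required settings. MultiaSyncDataCollector takes the same arguments and can be swapped in directly.

from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.envs.libs.gym import GymEnv
from torchrl.collectors import MultiSyncDataCollector

if __name__ == "__main__":  # main guard required by multiprocessing
    # Each worker process builds its own environment from this factory.
    env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
    # Pendulum-v1: 3-dimensional observation -> 1-dimensional action.
    policy = TensorDictModule(
        nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
    )
    collector = MultiSyncDataCollector(
        create_env_fn=[env_maker, env_maker],
        policy=policy,
        total_frames=2000,
        max_frames_per_traj=50,
        frames_per_batch=200,
        init_random_frames=-1,
        reset_at_each_iter=False,
        device="cpu",
        storing_device="cpu",
        cat_results="stack",  # how the workers' outputs are joined
    )
    for i, data in enumerate(collector):
        if i == 2:
            print(data)
            break
    collector.shutdown()
    del collector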
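
Patch 1 also reflows two deprecation warnings in torchrl/data/tensor_specs.py. The surrounding logic, visible in the hunk context, remaps the legacy minimum/maximum keyword arguments onto low/high and raises a TypeError when both spellings are given. A short sketch of that behavior from the caller's side; note that the enclosing class name sits outside the hunk shown, so BoundedTensorSpec is an assumption here:

import warnings
import torch
from torchrl.data import BoundedTensorSpec  # assumed enclosing class

# Preferred spelling: low/high.
spec = BoundedTensorSpec(low=-1.0, high=1.0, shape=(1,), dtype=torch.float32)

# Legacy spelling: minimum/maximum still works but emits a DeprecationWarning
# (the warnings.warn calls reformatted in patch 1).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_spec = BoundedTensorSpec(
        minimum=-1.0, maximum=1.0, shape=(1,), dtype=torch.float32
    )
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Mixing spellings for the same bound (e.g. low= together with minimum=)
# raises a TypeError via CONFLICTING_KWARGS, per the hunk context.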