diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 7e7567ff974..283bd99bd52 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -1003,6 +1003,33 @@ def _make_node_map(self, source, dest): self.max_size = self.data_map.max_size def extend(self, rollout, *, return_node: bool = False): + """Add a rollout to the forest. + + Nodes are only added to a tree at points where rollouts diverge from + each other and at the endpoints of rollouts. + + If there is no existing tree that matches the first steps of the + rollout, a new tree is added. Only one node is created, for the final + step. + + If there is an existing tree that matches, the rollout is added to that + tree. If the rollout diverges from all other rollouts in the tree at + some step, a new node is created before the step where the rollouts + diverge, and a leaf node is created for the final step of the rollout. + If all of the rollout's steps match with a previously added rollout, + nothing changes. If the rollout matches up to a leaf node of a tree but + continues beyond it, that node is extended to the end of the rollout, + and no new nodes are created. + + Args: + rollout (TensorDict): The rollout to add to the forest. + return_node (bool, optional): If True, the method returns the added + node. Default is ``False``. + + Returns: + Tree: The node that was added to the forest. This is only + returned if ``return_node`` is True. + """ source, dest = ( rollout.exclude("next").copy(), rollout.select("next", *self.action_keys).copy(),