Skip to content

Commit

Permalink
[Doc] Add docstring for MCTSForest.extend
Browse files Browse the repository at this point in the history
ghstack-source-id: dbef5e48ea55db6ba7867e1b24eb4711ad08af61
Pull Request resolved: #2795
  • Loading branch information
kurtamohler committed Feb 19, 2025
1 parent 3e614ff commit 10e2f69
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions torchrl/data/map/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,33 @@ def _make_node_map(self, source, dest):
self.max_size = self.data_map.max_size

def extend(self, rollout, *, return_node: bool = False):
"""Add a rollout to the forest.
Nodes are only added to a tree at points where rollouts diverge from
each other and at the endpoints of rollouts.
If there is no existing tree that matches the first steps of the
rollout, a new tree is added. Only one node is created, for the final
step.
If there is an existing tree that matches, the rollout is added to that
tree. If the rollout diverges from all other rollouts in the tree at
some step, a new node is created before the step where the rollouts
diverge, and a leaf node is created for the final step of the rollout.
If all of the rollout's steps match with a previously added rollout,
nothing changes. If the rollout matches up to a leaf node of a tree but
continues beyond it, that node is extended to the end of the rollout,
and no new nodes are created.
Args:
rollout (TensorDict): The rollout to add to the forest.
return_node (bool, optional): If True, the method returns the added
node. Default is ``False``.
Returns:
Tree: The node that was added to the forest. This is only
returned if ``return_node`` is True.
"""
source, dest = (
rollout.exclude("next").copy(),
rollout.select("next", *self.action_keys).copy(),
Expand Down

0 comments on commit 10e2f69

Please sign in to comment.