Skip to content

Commit

Permalink
[Feature] Make ProbabilisticTensorDictSequential account for more tha…
Browse files Browse the repository at this point in the history
…n one distribution

ghstack-source-id: b62b81b5cfd49168b5875f7ba9b4f35b51cd2423
Pull Request resolved: #1114
  • Loading branch information
vmoens committed Dec 2, 2024
1 parent e871b7d commit c7bd20c
Show file tree
Hide file tree
Showing 4 changed files with 710 additions and 87 deletions.
26 changes: 7 additions & 19 deletions tensordict/nn/distributions/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,22 +221,6 @@ def from_distributions(
self.inplace = inplace
return self

@property
def aggregate_probabilities(self):
aggregate_probabilities = self._aggregate_probabilities
if aggregate_probabilities is None:
warnings.warn(
"The default value of `aggregate_probabilities` will change from `False` to `True` in v0.7. "
"Please pass this value explicitly to avoid this warning.",
FutureWarning,
)
aggregate_probabilities = self._aggregate_probabilities = False
return aggregate_probabilities

@aggregate_probabilities.setter
def aggregate_probabilities(self, value):
self._aggregate_probabilities = value

def sample(self, shape=None) -> TensorDictBase:
if shape is None:
shape = torch.Size([])
Expand Down Expand Up @@ -337,7 +321,7 @@ def log_prob(
aggregate_probabilities (bool, optional): if provided, overrides the default ``aggregate_probabilities``
from the class.
include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict.
Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default).
Has no effect if ``aggregate_probabilities`` is set to ``True``.
.. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
Expand All @@ -356,6 +340,8 @@ def log_prob(
"""
if aggregate_probabilities is None:
aggregate_probabilities = self.aggregate_probabilities
if aggregate_probabilities is None:
aggregate_probabilities = False
if not aggregate_probabilities:
return self.log_prob_composite(
sample, include_sum=include_sum, inplace=inplace
Expand All @@ -382,7 +368,7 @@ def log_prob_composite(
Keyword Args:
include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict.
Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default).
.. warning:: The default value of ``include_sum`` will switch to ``False`` in v0.9 in the constructor.
Expand Down Expand Up @@ -451,7 +437,7 @@ def entropy(
setting from the class. Determines whether to return a single summed entropy tensor or a TensorDict
with individual entropies. Defaults to ``False`` if not set in the class.
include_sum (bool, optional): Whether to include the summed entropy in the output TensorDict.
Defaults to `self.inplace`, which is set through the class constructor. Has no effect if
Defaults to `self.include_sum`, which is set through the class constructor. Has no effect if
`aggregate_probabilities` is set to `True`.
.. warning:: The default value of `include_sum` will switch to `False` in v0.9 in the constructor.
Expand All @@ -466,6 +452,8 @@ def entropy(
"""
if aggregate_probabilities is None:
aggregate_probabilities = self.aggregate_probabilities
if aggregate_probabilities is None:
aggregate_probabilities = False
if not aggregate_probabilities:
return self.entropy_composite(samples_mc, include_sum=include_sum)
se = 0.0
Expand Down
Loading

1 comment on commit c7bd20c

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: c7bd20c Previous: 978d96c Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 126696.62011603784 iter/sec (stddev: 5.773436748800963e-7) 277296.5756871786 iter/sec (stddev: 4.191384604963448e-7) 2.19
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 127007.47714629647 iter/sec (stddev: 5.46058555435402e-7) 274031.46884222946 iter/sec (stddev: 3.616653353246629e-7) 2.16

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.