Skip to content

Commit

Permalink
[Feature] COMPOSITE_LP_AGGREGATE env variable
Browse files Browse the repository at this point in the history
ghstack-source-id: 16b07d0eac582cfd419612f87e38e1a7acffcfc0
Pull Request resolved: #1190
  • Loading branch information
vmoens committed Jan 21, 2025
1 parent 790bef6 commit 9733d6e
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions tensordict/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,13 @@ def _generate_next_value_(name, start, count, last_values):
return name.lower()


_composite_lp_aggregate = _ContextManager()
_composite_lp_aggregate = _ContextManager(
default=(
strtobool(os.getenv("COMPOSITE_LP_AGGREGATE"))
if os.getenv("COMPOSITE_LP_AGGREGATE") is not None
else None
)
)


def composite_lp_aggregate(nowarn: bool = False) -> bool | None:
Expand All @@ -467,9 +473,9 @@ def composite_lp_aggregate(nowarn: bool = False) -> bool | None:
if not nowarn:
warnings.warn(
"Composite log-prob aggregation wasn't defined explicitly and ``composite_lp_aggregate()`` will "
"currently return ``True``. However, from v0.9, this behaviour will change and ``composite_lp_aggregate`` will "
"currently return ``True``. However, from v0.9, this behavior will change and ``composite_lp_aggregate`` will "
"return ``False``. Please change your code accordingly by specifying the aggregation strategy via "
"`tensordict.nn.set_composite_lp_aggregate`.",
"`tensordict.nn.set_composite_lp_aggregate` or via the `COMPOSITE_LP_AGGREGATE` environment variable.",
category=DeprecationWarning,
)
return True
Expand All @@ -483,6 +489,8 @@ class set_composite_lp_aggregate(_DecoratorContextManager):
will be summed into a single tensor with the shape of the root tensordict. This behaviour is being deprecated in favor of
non-aggregated log-probs, which offer more flexibility and a somewhat more natural API (tensordict samples, tensordict log-probs, tensordict entropies).
The value of composite_lp_aggregate can also be controlled through the `COMPOSITE_LP_AGGREGATE` environment variable.
Example:
>>> _ = torch.manual_seed(0)
>>> from tensordict import TensorDict
Expand Down

1 comment on commit 9733d6e

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 9733d6e Previous: 790bef6 Ratio
benchmarks/tensorclass/test_torch_functions.py::test_zeros_like 114.52104034927143 iter/sec (stddev: 0.0027453456740762546) 229.27875914727062 iter/sec (stddev: 0.00019811542657909156) 2.00
benchmarks/tensorclass/test_torch_functions.py::test_ones_like 115.0436268022277 iter/sec (stddev: 0.002742486417007542) 230.21639335588992 iter/sec (stddev: 0.00006196716475215159) 2.00

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.