[BUG] ProbabilisticTensorDictModule._dist_sample hasattr error #1152

Open
3 tasks done
olliepro opened this issue Dec 19, 2024 · 0 comments
Labels
bug Something isn't working

olliepro commented Dec 19, 2024

Describe the bug

The following code section is logically flawed. The bug was introduced here.

        elif interaction_type is InteractionType.MEAN:
            if hasattr(dist, "mean"):
                try:
                    return dist.mean
                except NotImplementedError:
                    pass
            if dist.has_rsample:
                return dist.rsample((self.n_empirical_estimate,)).mean(0)
            else:
                return dist.sample((self.n_empirical_estimate,)).mean(0)

The hasattr call checks whether dist.mean exists by actually accessing it. In Python 3, hasattr only swallows AttributeError, so if dist.mean raises NotImplementedError, hasattr(dist, "mean") re-raises that NotImplementedError. The subsequent try block is therefore in the wrong spot: the exception escapes from hasattr before the try is ever reached, and the sample-based fallback never runs.
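The hasattr behaviour described above can be checked in plain Python, without torch. In this minimal sketch, NoClosedFormMean and NoMeanAttribute are hypothetical stand-ins (not tensordict or torchrl classes). The first mimics a distribution whose mean property raises NotImplementedError, and the second has no mean attribute at all.

```python
class NoClosedFormMean:
    """Hypothetical stand-in for a distribution without a closed-form mean."""

    @property
    def mean(self):
        raise NotImplementedError("no closed-form mean")


class NoMeanAttribute:
    """Hypothetical stand-in for a distribution that lacks `mean` entirely."""


def probe(obj):
    """Return hasattr(obj, "mean"), or the NotImplementedError that escapes it."""
    try:
        return hasattr(obj, "mean")
    except NotImplementedError as exc:
        return exc


print(probe(NoMeanAttribute()))                   # False: AttributeError is swallowed
print(type(probe(NoClosedFormMean())).__name__)  # NotImplementedError: it propagates
```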

To Reproduce

from tensordict.nn import ProbabilisticTensorDictModule, InteractionType
from torchrl.modules import TanhNormal

prob_module = ProbabilisticTensorDictModule(
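    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": -1,
        "high": 1,
    },
    return_log_prob=True,
)

# TanhNormal.mean raises NotImplementedError; that exception escapes
# hasattr(dist, "mean") instead of triggering the sample-based fallback.
prob_module._dist_sample(
    dist=TanhNormal(low=-1, high=1, loc=0, scale=1),
    interaction_type=InteractionType.MEAN,
)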

Expected behavior

The code should catch the NotImplementedError and then fall back to estimating the mean from samples (rsample if available, otherwise sample).

Suggestion

        elif interaction_type is InteractionType.MEAN:
            try:
                return dist.mean
            except (AttributeError, NotImplementedError):
                if dist.has_rsample:
                    return dist.rsample((self.n_empirical_estimate,)).mean(0)
                else:
                    return dist.sample((self.n_empirical_estimate,)).mean(0)
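The suggested control flow can be sketched without torch as follows. mean_or_estimate, SampleOnlyDist and ClosedFormDist are hypothetical names for illustration only. The sampling method takes a plain integer and returns a list of floats, rather than taking a torch sample_shape and returning a tensor that is reduced with .mean(0).

```python
import random
import statistics


class SampleOnlyDist:
    """Hypothetical distribution: can draw samples but has no closed-form mean."""

    has_rsample = False

    def __init__(self, low, high, seed=0):
        self.low, self.high = low, high
        self._rng = random.Random(seed)  # seeded so the estimate is reproducible

    @property
    def mean(self):
        raise NotImplementedError("no closed-form mean")

    def sample(self, n):
        # Simplified stand-in for dist.sample((n,)): returns n floats.
        return [self._rng.uniform(self.low, self.high) for _ in range(n)]


class ClosedFormDist:
    """Hypothetical distribution exposing an analytic mean."""

    mean = 0.25


def mean_or_estimate(dist, n_empirical_estimate=10_000):
    """Try the analytic mean; on failure, average n_empirical_estimate draws."""
    try:
        return dist.mean
    except (AttributeError, NotImplementedError):
        # Prefer rsample when the distribution supports it, as the issue suggests.
        draw = dist.rsample if getattr(dist, "has_rsample", False) else dist.sample
        return statistics.fmean(draw(n_empirical_estimate))


print(mean_or_estimate(ClosedFormDist()))       # 0.25
print(mean_or_estimate(SampleOnlyDist(-1.0, 1.0)))  # close to 0.0
```

Because AttributeError is caught in the same clause, a distribution that lacks mean entirely also takes the fallback path, so the separate hasattr check is no longer needed.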

System info

Python 3.10
torch==2.5.1+cu121
torchrl==0.6.0
tensordict==0.6.2

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
olliepro added the bug (Something isn't working) label on Dec 19, 2024