Adding tensorclass annotations for loss functions
SandishKumarHN committed Feb 13, 2024
1 parent 899af07 commit cbfb412
Showing 13 changed files with 240 additions and 33 deletions.
26 changes: 23 additions & 3 deletions test/test_cost.py
@@ -491,7 +491,10 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est):
action_spec_type=action_spec_type, device=device
)
loss_fn = DQNLoss(
actor, loss_function="l2", delay_value=delay_value, double_dqn=double_dqn
actor,
loss_function="l2"
delay_value=delay_value,
double_dqn=double_dqn,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
@@ -1490,6 +1493,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est):
loss_function="l2",
delay_actor=delay_actor,
delay_value=delay_value,
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
@@ -2118,6 +2122,7 @@ def test_td3(
noise_clip=noise_clip,
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
@@ -2808,6 +2813,7 @@ def test_sac(
num_qvalue_nets=num_qvalue,
loss_function="l2",
**kwargs,
return_tensorclass=False,
)

if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
@@ -4216,6 +4222,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est):
num_qvalue_nets=num_qvalue,
loss_function="l2",
delay_qvalue=delay_qvalue,
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
@@ -5013,6 +5020,7 @@ def test_cql(
with_lagrange=with_lagrange,
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
return_tensorclass=False,
)

if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
@@ -6648,7 +6656,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
else:
raise NotImplementedError

loss_fn = A2CLoss(actor, value, loss_critic_type="l2", functional=functional)
loss_fn = A2CLoss(
actor,
value,
loss_critic_type="l2",
functional=functional,
return_tensorclass=False,
)

# Check error is raised when actions require grads
td["action"].requires_grad = True
@@ -7113,6 +7127,7 @@ def test_reinforce_value_net(
critic_network=value_net,
delay_value=delay_value,
functional=functional,
return_tensorclass=False,
)

td = TensorDict(
@@ -7705,6 +7720,7 @@ def test_dreamer_world_model(
reco_loss=reco_loss,
delayed_clamp=delayed_clamp,
free_nats=free_nats,
return_tensorclass=False,
)
loss_td, _ = loss_module(tensordict)
for loss_str, lmbda in zip(
@@ -7962,7 +7978,10 @@ def test_odt(self, device):

actor = self._create_mock_actor(device=device)

loss_fn = OnlineDTLoss(actor)
loss_fn = OnlineDTLoss(
actor,
return_tensorclass=False,
)
loss = loss_fn(td)
loss_transformer = sum(
loss[key]
@@ -8525,6 +8544,7 @@ def test_iql(
temperature=temperature,
expectile=expectile,
loss_function="l2",
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
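
The hunks above only thread return_tensorclass=False through the existing tests, so the tensorclass return path itself is not exercised here. A follow-up test along these lines could cover it; this is a hypothetical sketch, not part of the commit, and it assumes the A2CLosses class added in torchrl/objectives/a2c.py below plus the existing _create_mock_* / _create_seq_mock_data_a2c helpers of TestA2C:

# Hypothetical addition to TestA2C in test/test_cost.py (not in this commit).
# Assumes `from torchrl.objectives.a2c import A2CLosses` at module level.
@pytest.mark.parametrize("return_tensorclass", [True, False])
def test_a2c_return_tensorclass(self, return_tensorclass, device="cpu"):
    actor = self._create_mock_actor(device=device)
    value = self._create_mock_value(device=device)
    td = self._create_seq_mock_data_a2c(device=device)
    loss_fn = A2CLoss(
        actor, value, loss_critic_type="l2", return_tensorclass=return_tensorclass
    )
    out = loss_fn(td)
    # The flag toggles the return type between a plain TensorDict and A2CLosses.
    assert isinstance(out, A2CLosses) is return_tensorclass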
20 changes: 18 additions & 2 deletions torchrl/objectives/a2c.py
@@ -2,14 +2,16 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import contextlib
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import tensorclass, TensorDict, TensorDictBase
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.utils import NestedKey
from torch import distributions as d
@@ -30,6 +32,16 @@
VTrace,
)

@tensorclass
class A2CLosses:
loss_objective: torch.Tensor
loss_critic: torch.Tensor | None = None
loss_entropy: torch.Tensor | None = None
entropy: torch.Tensor | None = None

@property
def aggregate_loss(self):
return self.loss_critic + self.loss_objective + self.loss_entropy

class A2CLoss(LossModule):
"""TorchRL implementation of the A2C loss.
@@ -234,6 +246,7 @@ def __init__(
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
return_tensorclass: bool = False,
):
if actor is not None:
actor_network = actor
@@ -290,6 +303,7 @@ def __init__(
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self.loss_critic_type = loss_critic_type
self.return_tensorclass = return_tensorclass

@property
def functional(self):
@@ -445,7 +459,7 @@ def _cached_detach_critic_network_params(self):
return self.critic_network_params.detach()

@dispatch()
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def forward(self, tensordict: TensorDictBase) -> A2CLosses:
tensordict = tensordict.clone(False)
advantage = tensordict.get(self.tensor_keys.advantage, None)
if advantage is None:
@@ -466,6 +480,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.critic_coef:
loss_critic = self.loss_critic(tensordict).mean()
td_out.set("loss_critic", loss_critic.mean())
if self.return_tensorclass:
return A2CLosses._from_tensordict(td_out)
return td_out

def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
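
As a minimal sketch of what the new return type gives the caller (not part of the diff; the tensor values below are arbitrary, and it assumes A2CLosses is importable from torchrl.objectives.a2c as added above), the conversion that forward performs when return_tensorclass=True can be reproduced directly:

import torch
from tensordict import TensorDict
from torchrl.objectives.a2c import A2CLosses

# Stand-in for the td_out assembled in A2CLoss.forward; the values are made up.
td_out = TensorDict(
    {
        "loss_objective": torch.tensor(1.0),
        "loss_critic": torch.tensor(0.5),
        "loss_entropy": torch.tensor(0.1),
        "entropy": torch.tensor(2.3),
    },
    batch_size=[],
)

losses = A2CLosses._from_tensordict(td_out)
print(losses.loss_objective)  # attribute access instead of string-key lookup
print(losses.aggregate_loss)  # loss_objective + loss_critic + loss_entropy = 1.6

Attribute access and the aggregate_loss property are the main ergonomic gain over the plain TensorDict return.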
20 changes: 18 additions & 2 deletions torchrl/objectives/cql.py
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import math
import warnings
from copy import deepcopy
@@ -12,7 +14,7 @@
import numpy as np
import torch
import torch.nn as nn
from tensordict import TensorDict, TensorDictBase
from tensordict import tensorclass, TensorDict, TensorDictBase
from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey, unravel_key
from torch import Tensor
@@ -35,6 +37,16 @@

from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

@tensorclass
class CQLLosses:
loss_objective: torch.Tensor
loss_critic: torch.Tensor | None = None
loss_entropy: torch.Tensor | None = None
entropy: torch.Tensor | None = None

@property
def aggregate_loss(self):
return self.loss_critic + self.loss_objective + self.loss_entropy

class CQLLoss(LossModule):
"""TorchRL implementation of the continuous CQL loss.
@@ -269,6 +281,7 @@ def __init__(
num_random: int = 10,
with_lagrange: bool = False,
lagrange_thresh: float = 0.0,
return_tensorclass: bool = False,
) -> None:
self._out_keys = None
super().__init__()
@@ -354,6 +367,7 @@ def __init__(
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)
self.return_tensorclass = return_tensorclass

@property
def target_entropy(self):
@@ -1171,7 +1185,7 @@ def value_loss(
return loss, metadata

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
def forward(self, tensordict: TensorDictBase) -> CQLLosses:
"""Computes the (DQN) CQL loss given a tensordict sampled from the replay buffer.
This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
@@ -1196,6 +1210,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
source=source,
batch_size=[],
)
if self.return_tensorclass:
return CQLLosses._from_tensordict(td_out)

return td_out

21 changes: 18 additions & 3 deletions torchrl/objectives/ddpg.py
@@ -10,7 +10,7 @@
from typing import Tuple

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import tensorclass, TensorDict, TensorDictBase
from tensordict.nn import dispatch, TensorDictModule

from tensordict.utils import NestedKey, unravel_key
@@ -25,6 +25,16 @@
)
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

@tensorclass
class DDPGLosses:
loss_objective: torch.Tensor
loss_critic: torch.Tensor | None = None
loss_entropy: torch.Tensor | None = None
entropy: torch.Tensor | None = None

@property
def aggregate_loss(self):
return self.loss_critic + self.loss_objective + self.loss_entropy

class DDPGLoss(LossModule):
"""The DDPG Loss class.
@@ -189,6 +199,7 @@ def __init__(
delay_value: bool = True,
gamma: float = None,
separate_losses: bool = False,
return_tensorclass: bool = False,
) -> None:
self._in_keys = None
super().__init__()
@@ -229,6 +240,7 @@ def __init__(
)

self.loss_function = loss_function
self.return_tensorclass = return_tensorclass

if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
@@ -266,7 +278,7 @@ def in_keys(self, values):
self._in_keys = values

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
def forward(self, tensordict: TensorDictBase) -> DDPGLosses:
"""Computes the DDPG losses given a tensordict sampled from the replay buffer.
This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
@@ -283,10 +295,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
loss_value, metadata = self.loss_value(tensordict)
loss_actor, metadata_actor = self.loss_actor(tensordict)
metadata.update(metadata_actor)
return TensorDict(
td_out = TensorDict(
source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
batch_size=[],
)
if self.return_tensorclass:
return DDPGLosses._from_tensordict(td_out)
return td_out

def loss_actor(
self,
22 changes: 19 additions & 3 deletions torchrl/objectives/decision_transformer.py
@@ -2,13 +2,14 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Union

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import tensorclass, TensorDict, TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import NestedKey

@@ -18,6 +19,16 @@
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import distance_loss

@tensorclass
class OnlineDTLosses:
loss_objective: torch.Tensor
loss_critic: torch.Tensor | None = None
loss_entropy: torch.Tensor | None = None
entropy: torch.Tensor | None = None

@property
def aggregate_loss(self):
return self.loss_critic + self.loss_objective + self.loss_entropy

class OnlineDTLoss(LossModule):
r"""TorchRL implementation of the Online Decision Transformer loss.
@@ -78,6 +89,7 @@ def __init__(
fixed_alpha: bool = False,
target_entropy: Union[str, float] = "auto",
samples_mc_entropy: int = 1,
return_tensorclass: bool = False,
) -> None:
self._in_keys = None
self._out_keys = None
@@ -146,6 +158,7 @@ def __init__(
)

self.samples_mc_entropy = samples_mc_entropy
self.return_tensorclass = return_tensorclass
self._set_in_keys()

def _set_in_keys(self):
Expand Down Expand Up @@ -310,7 +323,7 @@ def out_keys(self, values):
self._out_keys = values

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def forward(self, tensordict: TensorDictBase) -> OnlineDTLosses:
"""Compute the loss for the Online Decision Transformer."""
# extract action targets
tensordict = tensordict.clone(False)
@@ -328,4 +341,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
out = {
"loss": loss,
}
return TensorDict(out, [])
td_out = TensorDict(out, [])
if self.return_tensorclass:
return OnlineDTLosses._from_tensordict(td_out)
return td_out
