[Minor] Fix doc and MARL tests
ghstack-source-id: 9308be3ebc7fac30b5bde321792eb97069d55996
Pull Request resolved: #2759
vmoens committed Feb 5, 2025
1 parent cb37521 commit ad7d2a1
Showing 3 changed files with 68 additions and 39 deletions.
49 changes: 31 additions & 18 deletions docs/source/reference/objectives.rst
@@ -230,57 +230,70 @@ PPO
Using PPO with multi-head action policies
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

+.. note:: The main tools to consider when building multi-head policies are: :class:`~tensordict.nn.CompositeDistribution`,
+   :class:`~tensordict.nn.ProbabilisticTensorDictModule` and :class:`~tensordict.nn.ProbabilisticTensorDictSequential`.
+   When dealing with these, it is recommended to call `tensordict.nn.set_composite_lp_aggregate(False).set()` at the
+   beginning of the script to instruct :class:`~tensordict.nn.CompositeDistribution` that log-probabilities should not
+   be aggregated but rather written as leaves in the tensordict.
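
As an illustration, a minimal sketch of that setup call, assuming a `tensordict` version that exposes `set_composite_lp_aggregate`:

    from tensordict.nn import set_composite_lp_aggregate

    # Keep per-head log-probs as separate leaves rather than a single
    # aggregated log-prob entry.
    set_composite_lp_aggregate(False).set()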

In some cases, we have a single advantage value but more than one action undertaken. Each action has its own
log-probability and shape. For instance, the action space may be structured as follows:

>>> action_td = TensorDict(
-...     action0=Tensor(batch, n_agents, f0),
-...     action1=Tensor(batch, n_agents, f1, f2),
+...     agents=TensorDict(
+...         action0=Tensor(batch, n_agents, f0),
+...         action1=Tensor(batch, n_agents, f1, f2),
+...         batch_size=torch.Size((batch, n_agents))
+...     ),
... batch_size=torch.Size((batch,))
... )

where `f0`, `f1` and `f2` are some arbitrary integers.

-Note that, in TorchRL, the tensordict has the shape of the environment (if the environment is batch-locked, otherwise it
+Note that, in TorchRL, the root tensordict has the shape of the environment (if the environment is batch-locked, otherwise it
has the shape of the number of batched environments being run). If the tensordict is sampled from the buffer, it will
also have the shape of the replay buffer `batch_size`. The `n_agents` dimension, although common to each action, does not
-in general appear in the tensordict's batch-size.
+in general appear in the root tensordict's batch-size (although it appears in the sub-tensordict containing the
+agent-specific data according to the :ref:`MARL API <MARL-environment-API>`).
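
As a concrete sketch of that layout (sizes are hypothetical, and a recent `tensordict` release with keyword construction is assumed), the root tensordict carries the environment's batch size while the `agents` sub-tensordict also carries the agent dimension:

    import torch
    from tensordict import TensorDict

    batch, n_agents, f0 = 4, 3, 6  # hypothetical sizes
    td = TensorDict(
        agents=TensorDict(
            action0=torch.zeros(batch, n_agents, f0),
            batch_size=(batch, n_agents),
        ),
        batch_size=(batch,),
    )
    # The agent dimension appears only in the sub-tensordict's batch size.
    assert td.batch_size == torch.Size((batch,))
    assert td["agents"].batch_size == torch.Size((batch, n_agents))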

There is a legitimate reason why this is the case: the number of agents may condition some but not all the specs of the
environment. For example, some environments have a shared done state among all agents. A more complete tensordict
would in this case look like

>>> action_td = TensorDict(
-...     action0=Tensor(batch, n_agents, f0),
-...     action1=Tensor(batch, n_agents, f1, f2),
+...     agents=TensorDict(
+...         action0=Tensor(batch, n_agents, f0),
+...         action1=Tensor(batch, n_agents, f1, f2),
+...         observation=Tensor(batch, n_agents, f3),
+...         batch_size=torch.Size((batch, n_agents))
+...     ),
... done=Tensor(batch, 1),
-...     observation=Tensor(batch, n_agents, f3),
... [...] # etc
... batch_size=torch.Size((batch,))
... )

Notice that `done` states and `reward` are usually flanked by a rightmost singleton dimension. See this :ref:`part of the doc <reward_done_singleton>`
to learn more about this restriction.
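
Concretely (with hypothetical shapes), the trailing singleton looks like this:

    import torch

    batch, n_agents = 4, 3  # hypothetical sizes
    done = torch.zeros(batch, 1, dtype=torch.bool)  # (batch, 1), not (batch,)
    reward = torch.zeros(batch, n_agents, 1)        # (batch, n_agents, 1), not (batch, n_agents)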

-The main tools to consider when building multi-head policies are: :class:`~tensordict.nn.CompositeDistribution`,
-:class:`~tensordict.nn.ProbabilisticTensorDictModule` and :class:`~tensordict.nn.ProbabilisticTensorDictSequential`.
-When dealing with these, it is recommended to call `tensordict.nn.set_composite_lp_aggregate(False).set()` at the
-beginning of the script to instruct :class:`~tensordict.nn.CompositeDistribution` that log-probabilities should not
-be aggregated but rather written as leaves in the tensordict.

The log-probability of our actions given their respective distributions may look like any of the following

>>> action_td = TensorDict(
-...     action0_log_prob=Tensor(batch, n_agents),
-...     action1_log_prob=Tensor(batch, n_agents, f1),
+...     agents=TensorDict(
+...         action0_log_prob=Tensor(batch, n_agents),
+...         action1_log_prob=Tensor(batch, n_agents, f1),
+...         batch_size=torch.Size((batch, n_agents))
+...     ),
... batch_size=torch.Size((batch,))
... )

or

>>> action_td = TensorDict(
-...     action0_log_prob=Tensor(batch, n_agents),
-...     action1_log_prob=Tensor(batch, n_agents),
+...     agents=TensorDict(
+...         action0_log_prob=Tensor(batch, n_agents),
+...         action1_log_prob=Tensor(batch, n_agents),
+...         batch_size=torch.Size((batch, n_agents))
+...     ),
... batch_size=torch.Size((batch,))
... )
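
Putting the pieces together, here is a minimal end-to-end sketch of such a multi-head policy. The head names (`action0`, `action1`), the parameter shapes and the choice of Dirichlet/Categorical heads are hypothetical; the pattern mirrors the MARL test updated in this commit:

    import functools

    import torch
    from tensordict import TensorDict
    from tensordict.nn import (
        CompositeDistribution,
        ProbabilisticTensorDictModule,
        set_composite_lp_aggregate,
    )

    set_composite_lp_aggregate(False).set()

    batch, n_agents = 4, 3  # hypothetical sizes

    # Distribution parameters for two action heads nested under "agents".
    params = TensorDict(
        agents=TensorDict(
            action0=TensorDict(
                concentration=torch.ones(batch, n_agents, 10),
                batch_size=(batch, n_agents),
            ),
            action1=TensorDict(
                logits=torch.zeros(batch, n_agents, 7),
                batch_size=(batch, n_agents),
            ),
            batch_size=(batch, n_agents),
        ),
        batch_size=(batch,),
    )
    data = TensorDict(params=params, batch_size=(batch,))

    dist_cls = functools.partial(
        CompositeDistribution,
        distribution_map={
            ("agents", "action0"): torch.distributions.Dirichlet,
            ("agents", "action1"): torch.distributions.Categorical,
        },
    )
    policy = ProbabilisticTensorDictModule(
        in_keys=["params"],
        out_keys=[("agents", "action0"), ("agents", "action1")],
        distribution_class=dist_cls,
        return_log_prob=True,
    )
    out = policy(data)
    # With aggregation disabled, one log-prob leaf is written per head,
    # each shaped (batch, n_agents).
    assert out["agents", "action0_log_prob"].shape == (batch, n_agents)
    assert out["agents", "action1_log_prob"].shape == (batch, n_agents)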

@@ -336,7 +349,7 @@ Dreamer
DreamerValueLoss

Multi-agent objectives
------------------------
+----------------------

.. currentmodule:: torchrl.objectives.multiagent

56 changes: 36 additions & 20 deletions test/test_cost.py
@@ -209,45 +209,54 @@ def __init__(self):
self.obs_feat = obs_feat = (5,)

self.full_observation_spec = Composite(
-    observation=Unbounded(batch + n_agents + obs_feat),
-    batch_size=batch,
+    agents=Composite(
+        observation=Unbounded(batch + n_agents + obs_feat),
+        shape=batch + n_agents,
+    ),
+    shape=batch,
)
self.full_done_spec = Composite(
done=Unbounded(batch + (1,), dtype=torch.bool),
terminated=Unbounded(batch + (1,), dtype=torch.bool),
truncated=Unbounded(batch + (1,), dtype=torch.bool),
-    batch_size=batch,
+    shape=batch,
)

-self.act_feat_dirich = act_feat_dirich = (
-    10,
-    2,
-)
+self.act_feat_dirich = act_feat_dirich = (10, 2)
self.act_feat_categ = act_feat_categ = (7,)
self.full_action_spec = Composite(
-    dirich=Unbounded(batch + n_agents + act_feat_dirich),
-    categ=Unbounded(batch + n_agents + act_feat_categ),
-    batch_size=batch,
+    agents=Composite(
+        dirich=Unbounded(batch + n_agents + act_feat_dirich),
+        categ=Unbounded(batch + n_agents + act_feat_categ),
+        shape=batch + n_agents,
+    ),
+    shape=batch,
)

self.full_reward_spec = Composite(
-    reward=Unbounded(batch + n_agents + (1,)), batch_size=batch
+    agents=Composite(
+        reward=Unbounded(batch + n_agents + (1,)), shape=batch + n_agents
+    ),
+    shape=batch,
)

@classmethod
def make_composite_dist(cls):
dist_cstr = functools.partial(
CompositeDistribution,
distribution_map={
"dirich": lambda concentration: torch.distributions.Independent(
(
"agents",
"dirich",
): lambda concentration: torch.distributions.Independent(
torch.distributions.Dirichlet(concentration), 1
),
"categ": torch.distributions.Categorical,
("agents", "categ"): torch.distributions.Categorical,
},
)
return ProbabilisticTensorDictModule(
in_keys=["params"],
out_keys=["dirich", "categ"],
out_keys=[("agents", "dirich"), ("agents", "categ")],
distribution_class=dist_cstr,
return_log_prob=True,
)
@@ -9309,8 +9318,13 @@ def test_ppo_marl_aggregate(self):

def primer(td):
params = TensorDict(
-    dirich=TensorDict(concentration=env.action_spec["dirich"].one()),
-    categ=TensorDict(logits=env.action_spec["categ"].one()),
+    agents=TensorDict(
+        dirich=TensorDict(
+            concentration=env.action_spec["agents", "dirich"].one()
+        ),
+        categ=TensorDict(logits=env.action_spec["agents", "categ"].one()),
+        batch_size=env.action_spec["agents"].shape,
+    ),
batch_size=td.batch_size,
)
td.set("params", params)
@@ -9323,11 +9337,13 @@ def primer(td):
)
output = policy(env.fake_tensordict())
assert output.shape == env.batch_size
assert output["dirich_log_prob"].shape == env.batch_size + env.n_agents
assert output["categ_log_prob"].shape == env.batch_size + env.n_agents
assert (
output["agents", "dirich_log_prob"].shape == env.batch_size + env.n_agents
)
assert output["agents", "categ_log_prob"].shape == env.batch_size + env.n_agents

output["advantage"] = output["next", "reward"].clone()
output["value_target"] = output["next", "reward"].clone()
output["advantage"] = output["next", "agents", "reward"].clone()
output["value_target"] = output["next", "agents", "reward"].clone()
critic = TensorDictModule(
lambda obs: obs.new_zeros((*obs.shape[:-1], 1)),
in_keys=list(env.full_observation_spec.keys(True, True)),
2 changes: 1 addition & 1 deletion torchrl/envs/common.py
@@ -3574,7 +3574,7 @@ def fake_tensordict(self) -> TensorDictBase:
observation_spec = self.observation_spec
action_spec = self.input_spec["full_action_spec"]
# instantiates reward_spec if needed
-_ = self.reward_spec
+_ = self.full_reward_spec
reward_spec = self.output_spec["full_reward_spec"]
full_done_spec = self.output_spec["full_done_spec"]


1 comment on commit ad7d2a1

@github-actions


⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous benchmark result, exceeding the threshold of 2.

| Benchmark suite | Current: ad7d2a1 | Previous: cb37521 | Ratio |
| --- | --- | --- | --- |
| benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 34.99344609498667 iter/sec (stddev: 0.18460706847019923) | 416.806102201657 iter/sec (stddev: 0.0009607491415029777) | 11.91 |

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens
