[Feature] Add EnvBase.all_actions

ghstack-source-id: 7abf9d469f740be5f14daffa2330811f7572dad9
Pull Request resolved: #2780

kurtamohler authored and vmoens committed Feb 14, 2025
1 parent 1ed5d29 commit 67c3e9a
Showing 4 changed files with 123 additions and 18 deletions.
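A minimal usage sketch of the new `EnvBase.all_actions` API, adapted from the test added in this commit (it assumes `ChessEnv` is importable from `torchrl.envs` and is constructed with `mask_actions=True`, so only legal moves are enumerated):

```python
import torch

from torchrl.envs import ChessEnv  # assumed import path

# Stateful chess env with action masking, so `all_actions` yields only legal moves.
env = ChessEnv(include_fen=True, stateful=True, mask_actions=True)
td = env.reset()

# All legal actions for the current position, stacked along the leading dimension.
all_actions = env.all_actions()

# Step the environment with one of them, chosen at random.
idx = torch.randint(0, all_actions.shape[0], ()).item()
td = env.step(td.update(all_actions[idx]))["next"]
```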
57 changes: 57 additions & 0 deletions test/test_env.py
@@ -4087,6 +4087,63 @@ def test_env_reset_with_hash(self, stateful, include_san):
td_check = env.reset(td.select("fen_hash"))
assert (td_check == td).all()

@pytest.mark.parametrize("include_fen", [False, True])
@pytest.mark.parametrize("include_pgn", [False, True])
@pytest.mark.parametrize("stateful", [False, True])
@pytest.mark.parametrize("mask_actions", [False, True])
def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
if not stateful and not include_fen and not include_pgn:
pytest.skip("fen or pgn must be included if not stateful")

env = ChessEnv(
include_fen=include_fen,
include_pgn=include_pgn,
stateful=stateful,
mask_actions=mask_actions,
)
td = env.reset()

if not mask_actions:
with pytest.raises(RuntimeError, match="Cannot generate legal actions"):
env.all_actions()
return

# Choose random actions from the output of `all_actions`
for _ in range(100):
if stateful:
all_actions = env.all_actions()
else:
# Reset to the initial state first, just to make sure
# `all_actions` knows how to get the board state from the input.
env.reset()
all_actions = env.all_actions(td.clone())

# Choose some random actions and make sure they match exactly one of
# the actions from `all_actions`. This part is not tested when
# `mask_actions == False`, because `rand_action` can pick illegal
# actions in that case.
if mask_actions:
# TODO: Something is wrong in `ChessEnv.rand_action` that makes
# it fail to work properly in stateless mode. It doesn't know
# how to correctly reset the board state to what is given in the
# tensordict before picking an action. When this is fixed, we
# can get rid of the two `reset`s below.
if not stateful:
env.reset(td.clone())
td_act = td.clone()
for _ in range(10):
rand_action = env.rand_action(td_act)
assert (rand_action["action"] == all_actions["action"]).sum() == 1
if not stateful:
env.reset()

action_idx = torch.randint(0, all_actions.shape[0], ()).item()
chosen_action = all_actions[action_idx]
td = env.step(td.update(chosen_action))["next"]

if td["done"]:
td = env.reset()


class TestCustomEnvs:
def test_tictactoe_env(self):
48 changes: 31 additions & 17 deletions torchrl/data/tensor_specs.py
@@ -869,12 +869,16 @@ def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
return self.is_in(item)

@abc.abstractmethod
def enumerate(self) -> Any:
def enumerate(self, use_mask: bool = False) -> Any:
"""Returns all the samples that can be obtained from the TensorSpec.
The samples will be stacked along the first dimension.
This method is only implemented for discrete specs.
Args:
use_mask (bool, optional): If ``True`` and the spec has a mask,
samples that are masked are excluded. Default is ``False``.
"""
...

@@ -1315,9 +1319,9 @@ def __eq__(self, other):
return False
return True

def enumerate(self) -> torch.Tensor | TensorDictBase:
def enumerate(self, use_mask: bool = False) -> torch.Tensor | TensorDictBase:
return torch.stack(
[spec.enumerate() for spec in self._specs], dim=self.stack_dim + 1
[spec.enumerate(use_mask) for spec in self._specs], dim=self.stack_dim + 1
)

def __len__(self):
@@ -1810,7 +1814,9 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
return np.array(vals).reshape(tuple(val.shape))
return val

def enumerate(self) -> torch.Tensor:
def enumerate(self, use_mask: bool = False) -> torch.Tensor:
if use_mask:
raise NotImplementedError
return (
torch.eye(self.n, dtype=self.dtype, device=self.device)
.expand(*self.shape, self.n)
@@ -2142,7 +2148,7 @@ def __init__(
domain=domain,
)

def enumerate(self) -> Any:
def enumerate(self, use_mask: bool = False) -> Any:
raise NotImplementedError(
f"enumerate is not implemented for spec of class {type(self).__name__}."
)
@@ -2481,7 +2487,7 @@ def __eq__(self, other):
def cardinality(self) -> Any:
raise RuntimeError("Cannot enumerate a NonTensorSpec.")

def enumerate(self) -> Any:
def enumerate(self, use_mask: bool = False) -> Any:
raise RuntimeError("Cannot enumerate a NonTensorSpec.")

def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
@@ -2779,7 +2785,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
val.shape[: -self.ndim] + self.shape
)

def enumerate(self) -> Any:
def enumerate(self, use_mask: bool = False) -> Any:
raise NotImplementedError("enumerate cannot be called with continuous specs.")

def expand(self, *shape):
@@ -2951,9 +2957,9 @@ def __init__(
def cardinality(self) -> int:
return torch.as_tensor(self.nvec).prod()

def enumerate(self) -> torch.Tensor:
def enumerate(self, use_mask: bool = False) -> torch.Tensor:
nvec = self.nvec
enum_disc = self.to_categorical_spec().enumerate()
enum_disc = self.to_categorical_spec().enumerate(use_mask)
enums = torch.cat(
[
torch.nn.functional.one_hot(enum_unb, nv).to(self.dtype)
@@ -3417,14 +3423,18 @@ def __init__(
def _undefined_n(self):
return self.space.n < 0

def enumerate(self) -> torch.Tensor:
def enumerate(self, use_mask: bool = False) -> torch.Tensor:
dtype = self.dtype
if dtype is torch.bool:
dtype = torch.uint8
arange = torch.arange(self.n, dtype=dtype, device=self.device)
n = self.n
arange = torch.arange(n, dtype=dtype, device=self.device)
if use_mask and self.mask is not None:
arange = arange[self.mask]
n = arange.shape[0]
if self.ndim:
arange = arange.view(-1, *(1,) * self.ndim)
return arange.expand(self.n, *self.shape)
return arange.expand(n, *self.shape)

@property
def n(self):
@@ -4088,7 +4098,9 @@ def __init__(
self.update_mask(mask)
self.remove_singleton = remove_singleton

def enumerate(self) -> torch.Tensor:
def enumerate(self, use_mask: bool = False) -> torch.Tensor:
if use_mask:
raise NotImplementedError()
if self.mask is not None:
raise RuntimeError(
"Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
@@ -5136,13 +5148,15 @@ def cardinality(self) -> int:
n = 0
return n

def enumerate(self) -> TensorDictBase:
def enumerate(self, use_mask: bool = False) -> TensorDictBase:
# We are going to use meshgrid to create samples of all the subspecs in here
# but first let's get rid of the batch size, we'll put it back later
self_without_batch = self
while self_without_batch.ndim:
self_without_batch = self_without_batch[0]
samples = {key: spec.enumerate() for key, spec in self_without_batch.items()}
samples = {
key: spec.enumerate(use_mask) for key, spec in self_without_batch.items()
}
if self.data_cls is not None:
cls = self.data_cls
else:
@@ -5566,10 +5580,10 @@ def update(self, dict) -> None:
self[key] = item
return self

def enumerate(self) -> TensorDictBase:
def enumerate(self, use_mask: bool = False) -> TensorDictBase:
dim = self.stack_dim
return LazyStackedTensorDict.maybe_dense_stack(
[spec.enumerate() for spec in self._specs], dim + 1
[spec.enumerate(use_mask) for spec in self._specs], dim + 1
)

def __eq__(self, other):
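To make the spec-level change above concrete, here is a small sketch of the new `use_mask` flag on a `Categorical` spec (assuming the `Categorical`/`update_mask` API exported from `torchrl.data`; the mask is expected to have shape `(*spec.shape, n)`):

```python
import torch

from torchrl.data import Categorical  # assumed import path

spec = Categorical(4)  # scalar spec over the values {0, 1, 2, 3}
spec.update_mask(torch.tensor([True, False, True, True]))  # forbid value 1

spec.enumerate()               # tensor([0, 1, 2, 3]); the mask is ignored by default
spec.enumerate(use_mask=True)  # tensor([0, 2, 3]); the masked-out value is dropped
```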
21 changes: 21 additions & 0 deletions torchrl/envs/common.py
@@ -2831,6 +2831,27 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
f"got {tensordict.batch_size} and {self.batch_size}"
)

def all_actions(
self, tensordict: Optional[TensorDictBase] = None
) -> TensorDictBase:
"""Generates all possible actions from the action spec.
This only works in environments with fully discrete actions.
Args:
tensordict (TensorDictBase, optional): If given, :meth:`~.reset`
is called with this tensordict.
Returns:
a tensordict object with the "action" entry updated with a batch of
all possible actions. The actions are stacked together in the
leading dimension.
"""
if tensordict is not None:
self.reset(tensordict)

return self.full_action_spec.enumerate(use_mask=True)

def rand_action(self, tensordict: Optional[TensorDictBase] = None):
"""Performs a random action given the action_spec attribute.
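As the diff shows, the new `EnvBase.all_actions` boils down to an optional reset followed by a masked enumeration of the action spec; a rough stand-alone equivalent (sketch only):

```python
def all_actions(env, tensordict=None):
    # Optionally restore the state described by the tensordict first; this is
    # what lets stateless environments rebuild their state before enumerating.
    if tensordict is not None:
        env.reset(tensordict)
    # Enumerate every action allowed by the (possibly masked) full action spec.
    return env.full_action_spec.enumerate(use_mask=True)
```

Subclasses such as `ChessEnv` can layer extra checks on top, as the next file shows.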
15 changes: 14 additions & 1 deletion torchrl/envs/custom/chess.py
@@ -7,7 +7,7 @@
import importlib.util
import io
import pathlib
from typing import Dict
from typing import Dict, Optional

import torch
from tensordict import TensorDict, TensorDictBase
@@ -357,6 +357,19 @@ def __init__(
def _is_done(self, board):
return board.is_game_over() | board.is_fifty_moves()

def all_actions(
self, tensordict: Optional[TensorDictBase] = None
) -> TensorDictBase:
if not self.mask_actions:
raise RuntimeError(
(
"Cannot generate legal actions since 'mask_actions=False' was "
"set. If you really want to generate all actions, not just "
"legal ones, call 'env.full_action_spec.enumerate()'."
)
)
return super().all_actions(tensordict)

def _reset(self, tensordict=None):
fen = None
pgn = None
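A sketch of the resulting guard in `ChessEnv`: with masking disabled, `all_actions` refuses to run, while the raw (unmasked) action space can still be enumerated directly from the spec (import path assumed as in the earlier sketch):

```python
from torchrl.envs import ChessEnv  # assumed import path

env = ChessEnv(include_fen=True, mask_actions=False)
env.reset()

try:
    env.all_actions()
except RuntimeError as err:
    # "Cannot generate legal actions since 'mask_actions=False' was set. ..."
    print(err)

# If every action (legal or not) is really wanted, enumerate the raw spec instead:
every_action = env.full_action_spec.enumerate()
```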
