Support PettingZoo Parallel API and action mask #305
Changes from 13 commits
@@ -0,0 +1,38 @@
# Action Masking

Action masking is a technique used to restrict the set of actions available to an agent in certain states. This can be particularly useful in environments where some actions are invalid or undesirable in specific situations. See [this paper](https://arxiv.org/abs/2006.14171) for more details.

## Implementing Action Masking

To implement action masking in your environment, you need to add an `action_mask` field to the observation dictionary returned by your environment. Here's how to do it:

1. Define the action mask space in your environment's observation space.
2. Generate and return the action mask in both the `reset()` and `step()` methods.

Here's an example of a custom environment implementing action masking:

```python
import gymnasium as gym
import numpy as np


class CustomEnv(gym.Env):
    def __init__(self, full_env_name, cfg, render_mode=None):
        ...
        self.observation_space = gym.spaces.Dict({
            "obs": gym.spaces.Box(low=0, high=1, shape=(3, 3, 2), dtype=np.int8),
            "action_mask": gym.spaces.Box(low=0, high=1, shape=(9,), dtype=np.int8),
        })
        self.action_space = gym.spaces.Discrete(9)

    def reset(self, **kwargs):
        ...
        # Initial action mask that allows all actions
        action_mask = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1])
        return {"obs": obs, "action_mask": action_mask}, info

    def step(self, action):
        ...
        # Generate a new action mask based on the current state
        action_mask = np.array([1, 0, 0, 1, 1, 1, 0, 1, 1])
        return {"obs": obs, "action_mask": action_mask}, reward, terminated, truncated, info
```

Review comment: Small nit: I wonder if `low=0, high=1` for `"obs"` is intentional; would this mean binary observations? I understand 0/1 in `action_mask`, since that is a binary mask.

Reply: This is intentional, since it's taken from the Tic-Tac-Toe environment of PettingZoo, but I think we can change it if it's confusing.
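Going back to the example environment: as a quick sanity check on the agent side, you can verify that only permitted actions are chosen by sampling directly from the mask with plain NumPy. This snippet is illustrative only and is not part of the Sample Factory API:

```python
import numpy as np

# Hypothetical mask returned by the environment above: 1 = allowed, 0 = forbidden
action_mask = np.array([1, 0, 0, 1, 1, 1, 0, 1, 1])

valid_actions = np.flatnonzero(action_mask)  # indices of allowed actions
action = np.random.choice(valid_actions)     # uniformly pick one allowed action
assert action_mask[action] == 1
```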
@@ -0,0 +1,46 @@
# PettingZoo

Review comment: Love this. Thank you!

[PettingZoo](https://pettingzoo.farama.org/) is a Python library for conducting research in multi-agent reinforcement learning. This guide explains how to use PettingZoo environments with Sample Factory.

## Installation

Install Sample Factory with PettingZoo dependencies from PyPI:

```bash
pip install -e sample-factory[pettingzoo]
```

## Running Experiments

Run PettingZoo experiments with the scripts in `sf_examples`.
The default parameters are not tuned for throughput.

To train a model in the `tictactoe_v3` environment:

```
python -m sf_examples.train_pettingzoo_env --algo=APPO --env=tictactoe_v3 --experiment="Experiment Name"
```

To visualize the training results, use the `enjoy_pettingzoo_env` script:

```
python -m sf_examples.enjoy_pettingzoo_env --env=tictactoe_v3 --experiment="Experiment Name"
```

Currently, the scripts in `sf_examples` are set up for the `tictactoe_v3` environment. To use other PettingZoo environments, you'll need to modify the scripts or add your own as explained below.

### Adding a new PettingZoo environment

To add a new PettingZoo environment, follow the instructions from [Custom environments](../03-customization/custom-environments.md), with the additional step of wrapping your PettingZoo environment with `sample_factory.envs.pettingzoo_envs.PettingZooParallelEnv`.

Here's an example of how to create a factory function for a PettingZoo environment:

```python
from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv
import some_pettingzoo_env  # Import your desired PettingZoo environment


def make_pettingzoo_env(full_env_name, cfg=None, env_config=None, render_mode=None):
    return PettingZooParallelEnv(some_pettingzoo_env.parallel_env(render_mode=render_mode))
```
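The factory function then needs to be registered with Sample Factory, as described in the Custom environments guide linked above. A minimal sketch, assuming the `register_env` helper from `sample_factory.envs.env_utils` and a placeholder environment name:

```python
from sample_factory.envs.env_utils import register_env

# "my_pettingzoo_env" is a placeholder; use the name you pass via --env
register_env("my_pettingzoo_env", make_pettingzoo_env)
```

After registration, the environment can be selected with `--env=my_pettingzoo_env` in the training and evaluation scripts.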

Note: Sample Factory supports only the [Parallel API](https://pettingzoo.farama.org/api/parallel/) of PettingZoo. If your environment uses the AEC API, you can convert it to the Parallel API using `pettingzoo.utils.conversions.aec_to_parallel` or `pettingzoo.utils.conversions.turn_based_aec_to_parallel`. Be aware that these conversions have some limitations. For more details, refer to the [PettingZoo documentation](https://pettingzoo.farama.org/api/wrappers/pz_wrappers/).
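As an illustration of that conversion, here is a sketch that wraps PettingZoo's `tictactoe_v3` (a turn-based AEC environment) for use with Sample Factory; the environment choice is only an example:

```python
from pettingzoo.classic import tictactoe_v3
from pettingzoo.utils.conversions import turn_based_aec_to_parallel

from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv


def make_tictactoe_env(full_env_name, cfg=None, env_config=None, render_mode=None):
    # tictactoe_v3 implements the AEC API, so convert it to the Parallel API first
    aec_env = tictactoe_v3.env(render_mode=render_mode)
    parallel_env = turn_based_aec_to_parallel(aec_env)
    return PettingZooParallelEnv(parallel_env)
```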
@@ -42,7 +42,7 @@ def is_continuous_action_space(action_space: ActionSpace) -> bool:
    return isinstance(action_space, gym.spaces.Box)


def get_action_distribution(action_space, raw_logits):
def get_action_distribution(action_space, raw_logits, action_mask=None):
    """
    Create the distribution object based on provided action space and unprocessed logits.
    :param action_space: Gym action space object

@@ -52,9 +52,9 @@ def get_action_distribution(action_space, raw_logits):
    assert calc_num_action_parameters(action_space) == raw_logits.shape[-1]

    if isinstance(action_space, gym.spaces.Discrete):
        return CategoricalActionDistribution(raw_logits)
        return CategoricalActionDistribution(raw_logits, action_mask)
    elif isinstance(action_space, gym.spaces.Tuple):
        return TupleActionDistribution(action_space, logits_flat=raw_logits)
        return TupleActionDistribution(action_space, logits_flat=raw_logits, action_mask=action_mask)
    elif isinstance(action_space, gym.spaces.Box):
        return ContinuousActionDistribution(params=raw_logits)
    else:
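For illustration, here is a minimal sketch of how a mask could be passed through this helper for a `Discrete(9)` action space. It assumes `get_action_distribution` is imported from this module; the shapes and mask values are made up:

```python
import gymnasium as gym
import torch

# Hypothetical batch of 2 observations, 9 discrete actions
logits = torch.randn(2, 9)
action_mask = torch.tensor(
    [[1, 0, 0, 1, 1, 1, 0, 1, 1],
     [1, 1, 1, 1, 1, 1, 1, 1, 1]],
    dtype=torch.float32,
)

dist = get_action_distribution(gym.spaces.Discrete(9), raw_logits=logits, action_mask=action_mask)
actions = dist.sample()  # masked-out actions receive (near-)zero probability
```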
@@ -81,35 +81,71 @@ def argmax_actions(distribution):
        raise NotImplementedError(f"Action distribution type {type(distribution)} does not support argmax!")


# Retrieved from AllenNLP:
# https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py#L243
def masked_softmax(logits, mask):
    # To limit numerical errors from large vector elements outside the mask, we zero these out.
    result = functional.softmax(logits * mask, dim=-1)
    result = result * mask
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
    return result

Review comment: Can you help me understand this, please? I think logits in general can be negative, or positive but close to 0, in which case multiplying them by zero does not achieve the desired effect. I'd say we should probably use something like this instead? The choice of 1e-9 is arbitrary here, but it could be something like
Reply: It's taken from AllenNLP, including the comment, so I don't fully understand it either, but as far as I investigated, your version seems safer in some cases, even though the results are usually identical in both versions 👍
Reply: fixed


# Retrieved from AllenNLP:
# https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py#L286
def masked_log_softmax(logits, mask):
    # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
    # results in nans when the whole vector is masked. We need a very small value instead of a
    # zero in the mask for these cases.
    logits = logits + (mask + 1e-13).log()
    return functional.log_softmax(logits, dim=-1)

Review comment: This makes more sense to me; it is essentially adding log(1e-13), which is about -30, to non-valid elements. I'm not sure if this is universally correct, but it should most likely work. Why can't we just explicitly add a large negative constant though, like -1e9 or
Reply: It seems you're correct. This version causes a problem in extreme cases, as far as I tested.
Reply: fixed
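As a sketch of the alternative the reviewer describes (masking by adding a large negative constant to the logits of invalid actions instead of multiplying the logits or adding a tiny epsilon to the mask), something along these lines would work. The constant and the helper names are illustrative, and this is not necessarily the code that eventually landed:

```python
import torch
from torch.nn import functional


def masked_softmax_stable(logits, mask, neg_inf=-1e9):
    # Push masked-out logits to a very large negative value before softmax so that
    # their probability is numerically zero regardless of the sign or magnitude of the logits.
    masked_logits = logits + (1.0 - mask) * neg_inf
    return functional.softmax(masked_logits, dim=-1)


def masked_log_softmax_stable(logits, mask, neg_inf=-1e9):
    masked_logits = logits + (1.0 - mask) * neg_inf
    return functional.log_softmax(masked_logits, dim=-1)
```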

# noinspection PyAbstractClass
class CategoricalActionDistribution:
    def __init__(self, raw_logits):
    def __init__(self, raw_logits, action_mask=None):
        """
        Ctor.
        :param raw_logits: unprocessed logits, typically an output of a fully-connected layer
        """

        self.raw_logits = raw_logits
        self.action_mask = action_mask
        self.log_p = self.p = None

    @property
    def probs(self):
        if self.p is None:
            self.p = functional.softmax(self.raw_logits, dim=-1)
            if self.action_mask is not None:
                self.p = masked_softmax(self.raw_logits, self.action_mask)
            else:
                self.p = functional.softmax(self.raw_logits, dim=-1)
        return self.p

    @property
    def log_probs(self):
        if self.log_p is None:
            self.log_p = functional.log_softmax(self.raw_logits, dim=-1)
            if self.action_mask is not None:
                self.log_p = masked_log_softmax(self.raw_logits, self.action_mask)
            else:
                self.log_p = functional.log_softmax(self.raw_logits, dim=-1)
        return self.log_p

    def sample_gumbel(self):
        sample = torch.argmax(self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_(), -1)
        probs = self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_()
        if self.action_mask is not None:
            probs = probs * self.action_mask
        sample = torch.argmax(probs, -1)
        return sample

    def sample(self):
        samples = torch.multinomial(self.probs, 1, True)
        probs = self.probs
        if self.action_mask is not None:
            all_zero = (probs.sum(dim=-1) == 0).unsqueeze(-1)
            epsilons = torch.full_like(probs, 1e-6)
            probs = torch.where(all_zero, epsilons, probs)  # ensure sum of probabilities is non-zero

        samples = torch.multinomial(probs, 1, True)
        return samples

Review comment: Just checking that we don't have to re-normalize the probabilities here so they add up to 1.
Reply: Technically, it seems there is no need for them to add up to 1, according to the docs: https://pytorch.org/docs/stable/generated/torch.multinomial.html. But I'm not so sure honestly (I'm a newbie in RL), so please feel free to fix it if you see something wrong with the code.
Reply: Actually, it seems the values need to be normalized with softmax, as far as I understand, so I implemented that.
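For reference, one way the renormalization mentioned in the last reply could look: a fragment continuing the `sample()` method above, with `probs` as defined there (not necessarily the exact fix that landed later):

```python
        # Hypothetical variant: renormalize each row so it sums to 1 before sampling
        probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-13)
        samples = torch.multinomial(probs, 1, True)
```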

    def log_prob(self, value):

@@ -181,16 +217,18 @@ class TupleActionDistribution:

    """

    def __init__(self, action_space, logits_flat):
    def __init__(self, action_space, logits_flat, action_mask=None):
        self.logit_lengths = [calc_num_action_parameters(s) for s in action_space.spaces]
        self.split_logits = torch.split(logits_flat, self.logit_lengths, dim=1)
        self.action_lengths = [calc_num_actions(s) for s in action_space.spaces]
        self.action_mask = action_mask

        assert len(self.split_logits) == len(action_space.spaces)

        self.distributions = []
        for i, space in enumerate(action_space.spaces):
            self.distributions.append(get_action_distribution(space, self.split_logits[i]))
            action_mask = self.action_mask[i] if self.action_mask is not None else None
            self.distributions.append(get_action_distribution(space, self.split_logits[i], action_mask))

    @staticmethod
    def _flatten_actions(list_of_action_batches):
@@ -0,0 +1,79 @@
"""
Gym env wrappers for PettingZoo -> Gymnasium transition.
"""

import gymnasium as gym


class PettingZooParallelEnv(gym.Env):
    def __init__(self, env):
        if not all_equal([env.observation_space(a) for a in env.possible_agents]):
            raise ValueError("All observation spaces must be equal")

        if not all_equal([env.action_space(a) for a in env.possible_agents]):
            raise ValueError("All action spaces must be equal")

        self.env = env
        self.metadata = env.metadata
        self.render_mode = env.render_mode if hasattr(env, "render_mode") else env.unwrapped.render_mode
        self.observation_space = normalize_observation_space(env.observation_space(env.possible_agents[0]))
        self.action_space = env.action_space(env.possible_agents[0])
        self.num_agents = env.max_num_agents
        self.is_multiagent = True

    def reset(self, **kwargs):
        obs, infos = self.env.reset(**kwargs)
        obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents]
        infos = [infos[a] if a in infos else dict(is_active=False) for a in self.env.possible_agents]
        return obs, infos

    def step(self, actions):
        actions = dict(zip(self.env.possible_agents, actions))
        obs, rewards, terminations, truncations, infos = self.env.step(actions)

        if not self.env.agents:
            obs, infos = self.env.reset()

        obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents]
        rewards = [rewards.get(a) for a in self.env.possible_agents]
        terminations = [terminations.get(a) for a in self.env.possible_agents]
        truncations = [truncations.get(a) for a in self.env.possible_agents]
        infos = [normalize_info(infos[a], a) if a in infos else dict(is_active=False) for a in self.env.possible_agents]
        return obs, rewards, terminations, truncations, infos

    def render(self):
        return self.env.render()

    def close(self):
        self.env.close()


def all_equal(l_) -> bool:
    return all(v == l_[0] for v in l_)


def normalize_observation_space(obs_space):
    """Normalize observation space with the key "obs" that's specially handled as the main value."""
    if isinstance(obs_space, gym.spaces.Dict) and "observation" in obs_space.spaces:
        spaces = dict(obs_space.spaces)
        spaces["obs"] = spaces["observation"]
        del spaces["observation"]
        obs_space = gym.spaces.Dict(spaces)

    return obs_space


def normalize_observation(obs):
    if isinstance(obs, dict) and "observation" in obs:
        obs["obs"] = obs["observation"]
        del obs["observation"]

    return obs


def normalize_info(info, agent):
    """`active_agent` is available when using `turn_based_aec_to_parallel` of PettingZoo."""
    if isinstance(info, dict) and "active_agent" in info:
        info["is_active"] = info["active_agent"] == agent

    return info
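For a quick end-to-end check of the wrapper above, a usage sketch along these lines could be used. The choice of Pistonball is arbitrary; any PettingZoo Parallel environment whose agents share identical observation and action spaces should work:

```python
# Illustrative only; assumes `pip install 'pettingzoo[butterfly]'`
from pettingzoo.butterfly import pistonball_v6

env = PettingZooParallelEnv(pistonball_v6.parallel_env())

# Observations, rewards, terminations, truncations and infos come back as lists
# with one entry per possible agent.
obs, infos = env.reset()
actions = [env.action_space.sample() for _ in range(env.num_agents)]
obs, rewards, terminations, truncations, infos = env.step(actions)
env.close()
```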
Review comment: This documentation is wonderful, thank you!