Support PettingZoo Parallel API and action mask #305

Merged
4 changes: 2 additions & 2 deletions .github/workflows/test-ci.yml
@@ -31,7 +31,7 @@ jobs:
- name: Install conda env & dependencies
run: |
conda install python=${{ matrix.python-version }}
pip install -e '.[atari, mujoco, envpool]'
pip install -e '.[atari, mujoco, envpool, pettingzoo]'
conda list
- name: Install test dependencies
run: |
@@ -75,7 +75,7 @@ jobs:
- name: Install conda env & dependencies
run: |
conda install python=${{ matrix.python-version }}
pip install -e '.[atari, mujoco]'
pip install -e '.[atari, mujoco, pettingzoo]'
conda list
- name: Install test dependencies
run: |
2 changes: 1 addition & 1 deletion Makefile
@@ -46,7 +46,7 @@ check-codestyle:
.PHONY: test

test:
pytest -s --maxfail=2
pytest -s --maxfail=2 -rA
# ; echo "Tests finished. You might need to type 'reset' and press Enter to fix the terminal window"


38 changes: 38 additions & 0 deletions docs/07-advanced-topics/action-masking.md
@@ -0,0 +1,38 @@
# Action Masking
Owner: This documentation is wonderful, thank you!


Action masking is a technique for restricting the set of actions available to an agent in certain states. It is particularly useful in environments where some actions are invalid or undesirable in specific situations. See [this paper](https://arxiv.org/abs/2006.14171) for more details.

## Implementing Action Masking

To implement action masking, you need to add an `action_mask` field to the observation dictionary returned by your environment. Here's how to do it:

1. Define the action mask space in your environment's observation space
2. Generate and return the action mask in both `reset()` and `step()` methods

Here's an example of a custom environment implementing action masking:

```python
import gymnasium as gym
import numpy as np

class CustomEnv(gym.Env):
    def __init__(self, full_env_name, cfg, render_mode=None):
        ...
        self.observation_space = gym.spaces.Dict({
            "obs": gym.spaces.Box(low=0, high=1, shape=(3, 3, 2), dtype=np.int8),
            "action_mask": gym.spaces.Box(low=0, high=1, shape=(9,), dtype=np.int8),
        })
        self.action_space = gym.spaces.Discrete(9)

    def reset(self, **kwargs):
        ...
        # Initial action mask that allows all actions
        action_mask = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1])
        return {"obs": obs, "action_mask": action_mask}, info

    def step(self, action):
        ...
        # Generate new action mask based on the current state
        action_mask = np.array([1, 0, 0, 1, 1, 1, 0, 1, 1])
        return {"obs": obs, "action_mask": action_mask}, reward, terminated, truncated, info
```

Owner: Small nit: I wonder if `low=0, high=1` in the `"obs"` space is intentional. Would this mean binary observations? I understand 0/1 in `action_mask` since this is a binary mask.

Contributor Author: This is intentional since it's taken from the Tic-Tac-Toe environment of PettingZoo, but I think we can change it if it's confusing.
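Beyond the environment definition above, a quick standalone check of the mask format can be useful (a sketch, not part of this PR). Gymnasium's `Discrete.sample()` accepts an optional `np.int8` mask, which is also how the non-batched sampler in this PR draws random valid actions during experience decorrelation:

```python
import gymnasium as gym
import numpy as np

action_space = gym.spaces.Discrete(9)

# Same layout as the masks above; dtype must be np.int8 for Gymnasium's sample(mask=...)
action_mask = np.array([1, 0, 0, 1, 1, 1, 0, 1, 1], dtype=np.int8)

random_valid_action = action_space.sample(action_mask)
assert action_mask[random_valid_action] == 1  # a masked-out action is never returned
```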
46 changes: 46 additions & 0 deletions docs/09-environment-integrations/pettingzoo.md
@@ -0,0 +1,46 @@
# PettingZoo
Owner: Love this. Thank you!


[PettingZoo](https://pettingzoo.farama.org/) is a Python library for conducting research in multi-agent reinforcement learning. This guide explains how to use PettingZoo environments with Sample Factory.

## Installation

Install Sample Factory with PettingZoo dependencies via pip:

```bash
pip install -e sample-factory[pettingzoo]
```

## Running Experiments

Run PettingZoo experiments with the scripts in `sf_examples`.
The default parameters are not tuned for throughput.

To train a model in the `tictactoe_v3` environment:

```
python -m sf_examples.train_pettingzoo_env --algo=APPO --env=tictactoe_v3 --experiment="Experiment Name"
```

To visualize the training results, use the `enjoy_pettingzoo_env` script:

```
python -m sf_examples.enjoy_pettingzoo_env --env=tictactoe_v3 --experiment="Experiment Name"
```

Currently, the scripts in `sf_examples` are set up for the `tictactoe_v3` environment. To use other PettingZoo environments, you'll need to modify the scripts or add your own as explained below.

### Adding a new PettingZoo environment

To add a new PettingZoo environment, follow the instructions from [Custom environments](../03-customization/custom-environments.md), with the additional step of wrapping your PettingZoo environment with `sample_factory.envs.pettingzoo_envs.PettingZooParallelEnv`.

Here's an example of how to create a factory function for a PettingZoo environment:

```python
from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv
import some_pettingzoo_env # Import your desired PettingZoo environment

def make_pettingzoo_env(full_env_name, cfg=None, env_config=None, render_mode=None):
    return PettingZooParallelEnv(some_pettingzoo_env.parallel_env(render_mode=render_mode))
```

Note: Sample Factory supports only the [Parallel API](https://pettingzoo.farama.org/api/parallel/) of PettingZoo. If your environment uses the AEC API, you can convert it to Parallel API using `pettingzoo.utils.conversions.aec_to_parallel` or `pettingzoo.utils.conversions.turn_based_aec_to_parallel`. Be aware that these conversions have some limitations. For more details, refer to the [PettingZoo documentation](https://pettingzoo.farama.org/api/wrappers/pz_wrappers/).
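For instance, `tictactoe_v3` is a turn-based AEC environment, so its factory function needs to convert it first. The sketch below is illustrative rather than taken from `sf_examples`, and it assumes `register_env` is imported from `sample_factory.envs.env_utils` as described in the custom environments guide:

```python
from pettingzoo.classic import tictactoe_v3
from pettingzoo.utils.conversions import turn_based_aec_to_parallel

from sample_factory.envs.env_utils import register_env  # assumed import path, see the custom environments guide
from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv


def make_tictactoe_env(full_env_name, cfg=None, env_config=None, render_mode=None):
    # tictactoe_v3 is turn-based, so use turn_based_aec_to_parallel rather than aec_to_parallel
    parallel_env = turn_based_aec_to_parallel(tictactoe_v3.env(render_mode=render_mode))
    return PettingZooParallelEnv(parallel_env)


# The registered name must match the --env argument passed to the train and enjoy scripts.
register_env("tictactoe_v3", make_tictactoe_env)
```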
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -156,6 +156,7 @@ nav:
- 07-advanced-topics/passing-info.md
- 07-advanced-topics/observer.md
- 07-advanced-topics/profiling.md
- 07-advanced-topics/action-masking.md
- Miscellaneous:
- 08-miscellaneous/tests.md
- 08-miscellaneous/v1-to-v2.md
@@ -170,6 +171,7 @@ nav:
- 09-environment-integrations/nethack.md
- 09-environment-integrations/brax.md
- 09-environment-integrations/swarm-rl.md
- 09-environment-integrations/pettingzoo.md
- Huggingface Integration:
- 10-huggingface/huggingface.md
- Release Notes:
5 changes: 4 additions & 1 deletion sample_factory/algo/sampling/inference_worker.py
@@ -321,11 +321,14 @@ def _handle_policy_steps(self, timing):
if actor_critic.training:
actor_critic.eval() # need to call this because we can be in serial mode

action_mask = (
ensure_torch_tensor(obs.pop("action_mask")).to(self.device) if "action_mask" in obs else None
)
normalized_obs = prepare_and_normalize_obs(actor_critic, obs)
rnn_states = ensure_torch_tensor(rnn_states).to(self.device).float()

with timing.add_time("forward"):
policy_outputs = actor_critic(normalized_obs, rnn_states)
policy_outputs = actor_critic(normalized_obs, rnn_states, action_mask=action_mask)
policy_outputs["policy_version"] = torch.empty([num_samples]).fill_(self.param_client.policy_version)

with timing.add_time("prepare_outputs"):
2 changes: 1 addition & 1 deletion sample_factory/algo/sampling/non_batched_sampling.py
@@ -435,7 +435,7 @@ def _reset(self):

log.info("Decorrelating experience for %d frames...", decorrelate_steps)
for decorrelate_step in range(decorrelate_steps):
actions = [e.action_space.sample() for _ in range(self.num_agents)]
actions = [e.action_space.sample(obs.get("action_mask")) for obs in observations]
observations, rew, terminated, truncated, info = e.step(actions)

for agent_i, obs in enumerate(observations):
58 changes: 48 additions & 10 deletions sample_factory/algo/utils/action_distributions.py
@@ -42,7 +42,7 @@ def is_continuous_action_space(action_space: ActionSpace) -> bool:
return isinstance(action_space, gym.spaces.Box)


def get_action_distribution(action_space, raw_logits):
def get_action_distribution(action_space, raw_logits, action_mask=None):
"""
Create the distribution object based on provided action space and unprocessed logits.
:param action_space: Gym action space object
@@ -52,9 +52,9 @@ def get_action_distribution(action_space, raw_logits):
assert calc_num_action_parameters(action_space) == raw_logits.shape[-1]

if isinstance(action_space, gym.spaces.Discrete):
return CategoricalActionDistribution(raw_logits)
return CategoricalActionDistribution(raw_logits, action_mask)
elif isinstance(action_space, gym.spaces.Tuple):
return TupleActionDistribution(action_space, logits_flat=raw_logits)
return TupleActionDistribution(action_space, logits_flat=raw_logits, action_mask=action_mask)
elif isinstance(action_space, gym.spaces.Box):
return ContinuousActionDistribution(params=raw_logits)
else:
@@ -81,35 +81,71 @@ def argmax_actions(distribution):
raise NotImplementedError(f"Action distribution type {type(distribution)} does not support argmax!")


# Retrieved from AllenNLP:
# https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py#L243
def masked_softmax(logits, mask):
# To limit numerical errors from large vector elements outside the mask, we zero these out.
result = functional.softmax(logits * mask, dim=-1)
result = result * mask
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
return result

Owner: Can you help me understand this please?

I think logits in general can be negative, or positive but close to 0, in which case multiplying them by zero does not achieve the desired effect.

I'd say we should probably use something like this instead:

def masked_softmax(logits, mask):
    # Mask out the invalid logits by adding a large negative number (-1e9)
    logits = logits + (mask == 0) * -1e9
    result = functional.softmax(logits, dim=-1)
    result = result * mask
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
    return result

The choice of -1e9 is arbitrary here, but it could be something like -max(abs(logits)) * 1e6 to make this universal.

Contributor Author: This was taken from AllenNLP, including the comment, so I don't fully understand it, but as far as I investigated, your version seems safer in some cases, even though the results are usually identical in both versions 👍

Contributor Author: fixed


# Retrieved from AllenNLP:
# https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py#L286
def masked_log_softmax(logits, mask):
# vector + mask.log() is an easy way to zero out masked elements in logspace, but it
# results in nans when the whole vector is masked. We need a very small value instead of a
# zero in the mask for these cases.
logits = logits + (mask + 1e-13).log()
return functional.log_softmax(logits, dim=-1)

Owner: This makes more sense to me; it is essentially adding log(1e-13), which is about -30, to non-valid elements. I'm not sure if this is universally correct, but it most likely should work. Why can't we just explicitly add a large negative constant, though, like -1e9 or -max(abs(logits)) * 1e6 as in the previous example?

Contributor Author: It seems you're correct. This version causes a problem in extreme cases, as far as I tested.

Contributor Author: fixed


# noinspection PyAbstractClass
class CategoricalActionDistribution:
def __init__(self, raw_logits):
def __init__(self, raw_logits, action_mask=None):
"""
Ctor.
:param raw_logits: unprocessed logits, typically an output of a fully-connected layer
"""

self.raw_logits = raw_logits
self.action_mask = action_mask
self.log_p = self.p = None

@property
def probs(self):
if self.p is None:
self.p = functional.softmax(self.raw_logits, dim=-1)
if self.action_mask is not None:
self.p = masked_softmax(self.raw_logits, self.action_mask)
else:
self.p = functional.softmax(self.raw_logits, dim=-1)
return self.p

@property
def log_probs(self):
if self.log_p is None:
self.log_p = functional.log_softmax(self.raw_logits, dim=-1)
if self.action_mask is not None:
self.log_p = masked_log_softmax(self.raw_logits, self.action_mask)
else:
self.log_p = functional.log_softmax(self.raw_logits, dim=-1)
return self.log_p

def sample_gumbel(self):
sample = torch.argmax(self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_(), -1)
probs = self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_()
if self.action_mask is not None:
probs = probs * self.action_mask
sample = torch.argmax(probs, -1)
return sample

def sample(self):
samples = torch.multinomial(self.probs, 1, True)
probs = self.probs
if self.action_mask is not None:
all_zero = (probs.sum(dim=-1) == 0).unsqueeze(-1)
epsilons = torch.full_like(probs, 1e-6)
probs = torch.where(all_zero, epsilons, probs) # ensure sum of probabilities is non-zero

samples = torch.multinomial(probs, 1, True)
return samples

Owner: Just checking that we don't have to re-normalize the probabilities here so they add up to 1. Does torch.multinomial do this internally?

Contributor Author: Technically, it seems there is no need for them to add up to 1, according to the doc: "The rows of input do not need to sum to one (in which case we use the values as weights), ..." (https://pytorch.org/docs/stable/generated/torch.multinomial.html). But I'm not so sure, honestly (I'm a newbie at RL), so please feel free to fix it if you see something wrong with the code.

Contributor Author: Actually, it seems the values do need to be normalized with softmax, as far as I understand, so I implemented that.
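As a standalone check related to the exchange above (not part of the PR), `torch.multinomial` does accept unnormalized non-negative weights, and a zero-weight category is simply never drawn; a row that sums to zero, however, cannot be sampled from, which is what the epsilon guard in `sample()` appears to protect against:

```python
import torch

# Unnormalized, non-negative weights; the row does not need to sum to one.
weights = torch.tensor([[2.0, 0.0, 6.0]])

samples = torch.multinomial(weights, num_samples=1000, replacement=True)

assert (samples != 1).all()  # the zero-weight category (index 1) is never drawn
print(torch.bincount(samples.flatten(), minlength=3) / 1000.0)  # roughly [0.25, 0.0, 0.75]
```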

def log_prob(self, value):
@@ -181,16 +217,18 @@ class TupleActionDistribution:

"""

def __init__(self, action_space, logits_flat):
def __init__(self, action_space, logits_flat, action_mask=None):
self.logit_lengths = [calc_num_action_parameters(s) for s in action_space.spaces]
self.split_logits = torch.split(logits_flat, self.logit_lengths, dim=1)
self.action_lengths = [calc_num_actions(s) for s in action_space.spaces]
self.action_mask = action_mask

assert len(self.split_logits) == len(action_space.spaces)

self.distributions = []
for i, space in enumerate(action_space.spaces):
self.distributions.append(get_action_distribution(space, self.split_logits[i]))
action_mask = self.action_mask[i] if self.action_mask is not None else None
self.distributions.append(get_action_distribution(space, self.split_logits[i], action_mask))

@staticmethod
def _flatten_actions(list_of_action_batches):
4 changes: 4 additions & 0 deletions sample_factory/algo/utils/context.py
@@ -26,6 +26,10 @@ def set_global_context(ctx: SampleFactoryContext):


def reset_global_context():
"""
Most useful in tests, call this after any part of the global context has been modified
by a test in any way.
"""
global GLOBAL_CONTEXT
GLOBAL_CONTEXT = SampleFactoryContext()

4 changes: 3 additions & 1 deletion sample_factory/enjoy.py
@@ -136,6 +136,7 @@ def max_frames_reached(frames):
reward_list = []

obs, infos = env.reset()
action_mask = obs.pop("action_mask").to(device) if "action_mask" in obs else None
rnn_states = torch.zeros([env.num_agents, get_rnn_size(cfg)], dtype=torch.float32, device=device)
episode_reward = None
finished_episode = [False for _ in range(env.num_agents)]
@@ -149,7 +150,7 @@

if not cfg.no_render:
visualize_policy_inputs(normalized_obs)
policy_outputs = actor_critic(normalized_obs, rnn_states)
policy_outputs = actor_critic(normalized_obs, rnn_states, action_mask=action_mask)

# sample actions from the distribution by default
actions = policy_outputs["actions"]
@@ -169,6 +170,7 @@
last_render_start = render_frame(cfg, env, video_frames, num_episodes, last_render_start)

obs, rew, terminated, truncated, infos = env.step(actions)
action_mask = obs.pop("action_mask").to(device) if "action_mask" in obs else None
dones = make_dones(terminated, truncated)
infos = [{} for _ in range(env_info.num_agents)] if infos is None else infos

79 changes: 79 additions & 0 deletions sample_factory/envs/pettingzoo_envs.py
@@ -0,0 +1,79 @@
"""
Gym env wrappers for PettingZoo -> Gymnasium transition.
"""

import gymnasium as gym


class PettingZooParallelEnv(gym.Env):
def __init__(self, env):
if not all_equal([env.observation_space(a) for a in env.possible_agents]):
raise ValueError("All observation spaces must be equal")

if not all_equal([env.action_space(a) for a in env.possible_agents]):
raise ValueError("All action spaces must be equal")

self.env = env
self.metadata = env.metadata
self.render_mode = env.render_mode if hasattr(env, "render_mode") else env.unwrapped.render_mode
self.observation_space = normalize_observation_space(env.observation_space(env.possible_agents[0]))
self.action_space = env.action_space(env.possible_agents[0])
self.num_agents = env.max_num_agents
self.is_multiagent = True

def reset(self, **kwargs):
obs, infos = self.env.reset(**kwargs)
obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents]
infos = [infos[a] if a in infos else dict(is_active=False) for a in self.env.possible_agents]
return obs, infos

def step(self, actions):
actions = dict(zip(self.env.possible_agents, actions))
obs, rewards, terminations, truncations, infos = self.env.step(actions)

if not self.env.agents:
obs, infos = self.env.reset()

obs = [normalize_observation(obs.get(a)) for a in self.env.possible_agents]
rewards = [rewards.get(a) for a in self.env.possible_agents]
terminations = [terminations.get(a) for a in self.env.possible_agents]
truncations = [truncations.get(a) for a in self.env.possible_agents]
infos = [normalize_info(infos[a], a) if a in infos else dict(is_active=False) for a in self.env.possible_agents]
return obs, rewards, terminations, truncations, infos

def render(self):
return self.env.render()

def close(self):
self.env.close()


def all_equal(l_) -> bool:
return all(v == l_[0] for v in l_)


def normalize_observation_space(obs_space):
"""Normalize observation space with the key "obs" that's specially handled as the main value."""
if isinstance(obs_space, gym.spaces.Dict) and "observation" in obs_space.spaces:
spaces = dict(obs_space.spaces)
spaces["obs"] = spaces["observation"]
del spaces["observation"]
obs_space = gym.spaces.Dict(spaces)

return obs_space


def normalize_observation(obs):
if isinstance(obs, dict) and "observation" in obs:
obs["obs"] = obs["observation"]
del obs["observation"]

return obs


def normalize_info(info, agent):
"""active_agent is available when using `turn_based_aec_to_parallel` of PettingZoo."""
if isinstance(info, dict) and "active_agent" in info:
info["is_active"] = info["active_agent"] == agent

return info
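For context, a rough sketch of how the wrapper behaves when driven by hand (assuming a Parallel API environment whose per-agent spaces are identical, e.g. `pistonball_v6` from the PettingZoo butterfly extras): observations, rewards, terminations, truncations, and infos all come back as per-agent lists ordered by `possible_agents`.

```python
from pettingzoo.butterfly import pistonball_v6  # any Parallel API env with identical per-agent spaces

from sample_factory.envs.pettingzoo_envs import PettingZooParallelEnv

env = PettingZooParallelEnv(pistonball_v6.parallel_env())

obs, infos = env.reset()  # lists with one entry per agent in env.possible_agents
actions = [env.action_space.sample() for _ in range(env.num_agents)]
obs, rewards, terminations, truncations, infos = env.step(actions)

print(env.num_agents, len(obs), len(rewards))  # one entry per possible agent
env.close()
```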