diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index cac1ec21..9671f935 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -17,6 +17,10 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y xvfb libglu1-mesa-dev python3-opengl - name: Upgrade pip run: | python -m pip install --upgrade pip setuptools wheel @@ -27,7 +31,7 @@ jobs: - name: do_unittest timeout-minutes: 40 run: | - python3 -m pytest tests --cov=openrl --cov-report=xml -m unittest --cov-report=term-missing --durations=0 -v --color=yes + xvfb-run -s "-screen 0 1400x900x24" python3 -m pytest tests --cov=openrl --cov-report=xml -m unittest --cov-report=term-missing --durations=0 -v --color=yes -s - name: Upload coverage reports to Codecov with GitHub Action uses: codecov/codecov-action@v3 with: diff --git a/.gitignore b/.gitignore index c92a6657..469ab373 100644 --- a/.gitignore +++ b/.gitignore @@ -153,10 +153,11 @@ run_results/ api_docs .vscode *.pkl -api_docs *.json opponent_pool !/examples/selfplay/opponent_templates/tictactoe_opponent/info.json +!/examples/nlp/ds_config.json +!/examples/nlp/eval_ds_config.json wandb_run examples/dmc/new.gif /examples/snake/submissions/rl/actor_2000.pth diff --git a/Project.md b/Project.md index d7c455f5..38a8d1c7 100644 --- a/Project.md +++ b/Project.md @@ -18,7 +18,7 @@ However, in many practical applications, it is important to develop reasonable a In this paper, we propose an on-policy framework for discovering multiple strategies for the same task. Experimental results show that our method efficiently finds diverse strategies in a wide variety of reinforcement learning tasks. -- Paper: [DGPO: Discovering Multiple Strategies with Diversity-Guided Policy Optimization](https://arxiv.org/abs/2207.05631)(AAMAS Extended Abstract 2023) -- Authors: Wenze Chen, Shiyu Huang, Yuan Chiang, Ting Chen, Jun Zhu +- Paper: [DGPO: Discovering Multiple Strategies with Diversity-Guided Policy Optimization](https://arxiv.org/abs/2207.05631)(AAAAI 2024) +- Authors: Wenze Chen, Shiyu Huang, Yuan Chiang, Tim Pearce, Wei-Wei Tu, Ting Chen, Jun Zhu diff --git a/README.md b/README.md index 741fdef2..c76e3691 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@
- +
--- @@ -25,10 +25,10 @@ [![Contributors](https://img.shields.io/github/contributors/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/graphs/contributors) [![GitHub license](https://img.shields.io/github/license/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/blob/master/LICENSE) -[![Embark](https://img.shields.io/badge/discord-OpenRL-%237289da.svg?logo=discord)](https://discord.gg/guvAS2up) +[![Embark](https://img.shields.io/badge/discord-OpenRL-%237289da.svg?logo=discord)](https://discord.gg/qMbVT2qBhr) [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) -OpenRL-v0.1.7 is updated on Sep 21, 2023 +OpenRL-v0.2.0 is updated on Dec 20, 2023 The main branch is the latest version of OpenRL, which is under active development. If you just want to have a try with OpenRL, you can switch to the stable branch. @@ -58,6 +58,8 @@ Currently, the features supported by OpenRL include: - Reinforcement learning training support for natural language tasks (such as dialogue) +- Support [DeepSpeed](https://github.com/microsoft/DeepSpeed) + - Support [Arena](https://openrl-docs.readthedocs.io/en/latest/arena/index.html) , which allows convenient evaluation of various agents (even submissions for [JiDi](https://openrl-docs.readthedocs.io/en/latest/arena/index.html#performing-local-evaluation-of-agents-submitted-to-the-jidi-platform-using-openrl)) in a competitive environment. @@ -160,19 +162,19 @@ Here we provide a table for the comparison of OpenRL and existing popular RL lib OpenRL employs a modular design and high-level abstraction, allowing users to accomplish training for various tasks through a unified and user-friendly interface. -| Library | NLP/RLHF | Multi-agent | Self-Play Training | Offline RL | Bilingual Document | -|:------------------------------------------------------------------:|:------------------:|:--------------------:|:--------------------:|:------------------:|:------------------:| -| **[OpenRL](https://github.com/OpenRL-Lab/openrl)** | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | -| [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) | :x: | :x: | :x: | :x: | :x: | -| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | -| [DI-engine](https://github.com/opendilab/DI-engine/) | :x: | :heavy_check_mark: | not fullly supported | :heavy_check_mark: | :heavy_check_mark: | -| [Tianshou](https://github.com/thu-ml/tianshou) | :x: | not fullly supported | not fullly supported | :heavy_check_mark: | :heavy_check_mark: | -| [MARLlib](https://github.com/Replicable-MARL/MARLlib) | :x: | :heavy_check_mark: | not fullly supported | :x: | :x: | -| [MAPPO Benchmark](https://github.com/marlbenchmark/on-policy) | :x: | :heavy_check_mark: | :x: | :x: | :x: | -| [RL4LMs](https://github.com/allenai/RL4LMs) | :heavy_check_mark: | :x: | :x: | :x: | :x: | -| [trlx](https://github.com/CarperAI/trlx) | :heavy_check_mark: | :x: | :x: | :x: | :x: | -| [trl](https://github.com/huggingface/trl) | :heavy_check_mark: | :x: | :x: | :x: | :x: | -| [TimeChamber](https://github.com/inspirai/TimeChamber) | :x: | :x: | :heavy_check_mark: | :x: | :x: | +| Library | NLP/RLHF | Multi-agent | Self-Play Training | Offline RL | [DeepSpeed](https://github.com/microsoft/DeepSpeed) | +|:------------------------------------------------------------------:|:------------------:|:--------------------:|:--------------------:|:------------------:|:--------------------:| +| **[OpenRL](https://github.com/OpenRL-Lab/openrl)** | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) | :x: | :x: | :x: | :x: | :x: | +| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | +| [DI-engine](https://github.com/opendilab/DI-engine/) | :x: | :heavy_check_mark: | not fullly supported | :heavy_check_mark: | :x: | +| [Tianshou](https://github.com/thu-ml/tianshou) | :x: | not fullly supported | not fullly supported | :heavy_check_mark: | :x: | +| [MARLlib](https://github.com/Replicable-MARL/MARLlib) | :x: | :heavy_check_mark: | not fullly supported | :x: | :x: | +| [MAPPO Benchmark](https://github.com/marlbenchmark/on-policy) | :x: | :heavy_check_mark: | :x: | :x: | :x: | +| [RL4LMs](https://github.com/allenai/RL4LMs) | :heavy_check_mark: | :x: | :x: | :x: | :x: | +| [trlx](https://github.com/CarperAI/trlx) | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | +| [trl](https://github.com/huggingface/trl) | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | +| [TimeChamber](https://github.com/inspirai/TimeChamber) | :x: | :x: | :heavy_check_mark: | :x: | :x: | ## Installation @@ -333,7 +335,7 @@ If you are using OpenRL in your research project, you are also welcome to join t - Join the [slack](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) group to discuss OpenRL usage and development with us. -- Join the [Discord](https://discord.gg/guvAS2up) group to discuss OpenRL usage and development with us. +- Join the [Discord](https://discord.gg/qMbVT2qBhr) group to discuss OpenRL usage and development with us. - Send an E-mail to: [huangshiyu@4paradigm.com](huangshiyu@4paradigm.com) - Join the [GitHub Discussion](https://github.com/orgs/OpenRL-Lab/discussions). diff --git a/README_zh.md b/README_zh.md index c8fb4619..e75c76af 100644 --- a/README_zh.md +++ b/README_zh.md @@ -1,5 +1,5 @@
- +
@@ -26,10 +26,10 @@ [![Contributors](https://img.shields.io/github/contributors/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/graphs/contributors) [![GitHub license](https://img.shields.io/github/license/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/blob/master/LICENSE) -[![Embark](https://img.shields.io/badge/discord-OpenRL-%237289da.svg?logo=discord)](https://discord.gg/guvAS2up) +[![Embark](https://img.shields.io/badge/discord-OpenRL-%237289da.svg?logo=discord)](https://discord.gg/qMbVT2qBhr) [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) -OpenRL-v0.1.7 is updated on Sep 21, 2023 +OpenRL-v0.1.10 is updated on Oct 27, 2023 The main branch is the latest version of OpenRL, which is under active development. If you just want to have a try with OpenRL, you can switch to the stable branch. @@ -51,6 +51,7 @@ OpenRL基于PyTorch进行开发,目标是为强化学习研究社区提供一 - 支持通过专家数据进行离线强化学习训练 - 支持自博弈训练 - 支持自然语言任务(如对话任务)的强化学习训练 +- 支持[DeepSpeed](https://github.com/microsoft/DeepSpeed) - 支持[竞技场](https://openrl-docs.readthedocs.io/zh/latest/arena/index.html)功能,可以在多智能体对抗性环境中方便地对各种智能体(甚至是[及第平台](https://openrl-docs.readthedocs.io/zh/latest/arena/index.html#openrl)上提交的智能体)进行评测。 - 支持从[Hugging Face](https://huggingface.co/)上导入模型和数据。支持加载Hugging Face上[Stable-baselines3的模型](https://openrl-docs.readthedocs.io/zh/latest/sb3/index.html)来进行测试和训练。 - 提供用户自有环境接入OpenRL的[详细教程](https://openrl-docs.readthedocs.io/zh/latest/custom_env/index.html). @@ -128,18 +129,18 @@ OpenRL-Lab将持续维护和更新OpenRL,欢迎大家加入我们的[开源社 这里我们提供了一个表格,比较了OpenRL和其他常用的强化学习库。 OpenRL采用模块化设计和高层次的抽象,使得用户可以通过统一的简单易用的接口完成各种任务的训练。 -| 强化学习库 | 自然语言任务/RLHF | 多智能体训练 | 自博弈训练 | 离线强化学习 | 双语文档 | +| 强化学习库 | 自然语言任务/RLHF | 多智能体训练 | 自博弈训练 | 离线强化学习 | [DeepSpeed](https://github.com/microsoft/DeepSpeed) | |:------------------------------------------------------------------:|:------------------:|:--------------------:|:--------------------:|:------------------:|:------------------:| | **[OpenRL](https://github.com/OpenRL-Lab/openrl)** | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) | :x: | :x: | :x: | :x: | :x: | | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | -| [DI-engine](https://github.com/opendilab/DI-engine/) | :x: | :heavy_check_mark: | not fullly supported | :heavy_check_mark: | :heavy_check_mark: | -| [Tianshou](https://github.com/thu-ml/tianshou) | :x: | not fullly supported | not fullly supported | :heavy_check_mark: | :heavy_check_mark: | +| [DI-engine](https://github.com/opendilab/DI-engine/) | :x: | :heavy_check_mark: | not fullly supported | :heavy_check_mark: | :x: | +| [Tianshou](https://github.com/thu-ml/tianshou) | :x: | not fullly supported | not fullly supported | :heavy_check_mark: | :x: | | [MARLlib](https://github.com/Replicable-MARL/MARLlib) | :x: | :heavy_check_mark: | not fullly supported | :x: | :x: | | [MAPPO Benchmark](https://github.com/marlbenchmark/on-policy) | :x: | :heavy_check_mark: | :x: | :x: | :x: | | [RL4LMs](https://github.com/allenai/RL4LMs) | :heavy_check_mark: | :x: | :x: | :x: | :x: | -| [trlx](https://github.com/CarperAI/trlx) | :heavy_check_mark: | :x: | :x: | :x: | :x: | -| [trl](https://github.com/huggingface/trl) | :heavy_check_mark: | :x: | :x: | :x: | :x: | +| [trlx](https://github.com/CarperAI/trlx) | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | +| [trl](https://github.com/huggingface/trl) | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | [TimeChamber](https://github.com/inspirai/TimeChamber) | :x: | :x: | :heavy_check_mark: | :x: | :x: | ## 安装 @@ -293,7 +294,7 @@ openrl --mode train --env CartPole-v1 - 加入 [slack](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) 群组,与我们一起讨论OpenRL的使用和开发。 -- 加入 [Discord](https://discord.gg/guvAS2up) 群组,与我们一起讨论OpenRL的使用和开发。 +- 加入 [Discord](https://discord.gg/qMbVT2qBhr) 群组,与我们一起讨论OpenRL的使用和开发。 - 发送邮件到: [huangshiyu@4paradigm.com](huangshiyu@4paradigm.com) - 加入 [GitHub Discussion](https://github.com/orgs/OpenRL-Lab/discussions) diff --git a/examples/arena/README.md b/examples/arena/README.md index e9d59b91..940bea33 100644 --- a/examples/arena/README.md +++ b/examples/arena/README.md @@ -3,6 +3,7 @@ ```bash pip install "openrl[selfplay]" +pip install "pettingzoo[mpe]","pettingzoo[butterfly]" ``` ### Usage @@ -15,3 +16,11 @@ python run_arena.py ### Evaluate Google Research Football submissions for JiDi locally If you want to evaluate your Google Research Football submissions for JiDi locally, please try to use tizero as illustrated [here](foothttps://github.com/OpenRL-Lab/TiZero#evaluate-jidi-submissions-locally). + +### Evaluate more environments + +We also provide a script to evaluate more environments, including MPE, Go, Texas Holdem, Butterfly. You can run the script as follows: + +```shell +python evaluate_more_envs.py +``` \ No newline at end of file diff --git a/examples/arena/evaluate_more_envs.py b/examples/arena/evaluate_more_envs.py new file mode 100644 index 00000000..3b7bfe07 --- /dev/null +++ b/examples/arena/evaluate_more_envs.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +from pettingzoo.butterfly import cooperative_pong_v5 +from pettingzoo.classic import connect_four_v3, go_v5, rps_v2, texas_holdem_no_limit_v6 +from pettingzoo.mpe import simple_push_v3 + +from openrl.arena import make_arena +from openrl.arena.agents.local_agent import LocalAgent +from openrl.arena.agents.random_agent import RandomAgent +from openrl.envs.PettingZoo.registration import register +from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner + + +def ConnectFourEnv(render_mode, **kwargs): + return connect_four_v3.env(render_mode) + + +def RockPaperScissorsEnv(render_mode, **kwargs): + return rps_v2.env(num_actions=3, max_cycles=15) + + +def GoEnv(render_mode, **kwargs): + return go_v5.env(render_mode=render_mode, board_size=5, komi=7.5) + + +def TexasHoldemEnv(render_mode, **kwargs): + return texas_holdem_no_limit_v6.env(render_mode=render_mode) + + +# MPE +def SimplePushEnv(render_mode, **kwargs): + return simple_push_v3.env(render_mode=render_mode) + + +def CooperativePongEnv(render_mode, **kwargs): + return cooperative_pong_v5.env(render_mode=render_mode) + + +def register_new_envs(): + new_env_dict = { + "connect_four_v3": ConnectFourEnv, + "RockPaperScissors": RockPaperScissorsEnv, + "go_v5": GoEnv, + "texas_holdem_no_limit_v6": TexasHoldemEnv, + "simple_push_v3": SimplePushEnv, + "cooperative_pong_v5": CooperativePongEnv, + } + + for env_id, env in new_env_dict.items(): + register(env_id, env) + return new_env_dict.keys() + + +def run_arena( + env_id: str, + parallel: bool = True, + seed=0, + total_games: int = 10, + max_game_onetime: int = 5, +): + env_wrappers = [RecordWinner] + + arena = make_arena(env_id, env_wrappers=env_wrappers, use_tqdm=False) + + agent1 = LocalAgent("../selfplay/opponent_templates/random_opponent") + agent2 = RandomAgent() + + arena.reset( + agents={"agent1": agent1, "agent2": agent2}, + total_games=total_games, + max_game_onetime=max_game_onetime, + seed=seed, + ) + result = arena.run(parallel=parallel) + arena.close() + print(result) + return result + + +def test_new_envs(): + env_ids = register_new_envs() + seed = 0 + for env_id in env_ids: + run_arena(env_id=env_id, seed=seed, parallel=False, total_games=1) + + +if __name__ == "__main__": + test_new_envs() diff --git a/examples/arena/run_arena.py b/examples/arena/run_arena.py index e880884c..fdc0776a 100644 --- a/examples/arena/run_arena.py +++ b/examples/arena/run_arena.py @@ -17,6 +17,7 @@ """""" from openrl.arena import make_arena from openrl.arena.agents.local_agent import LocalAgent +from openrl.arena.agents.random_agent import RandomAgent from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner @@ -37,7 +38,7 @@ def run_arena( arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=use_tqdm) agent1 = LocalAgent("../selfplay/opponent_templates/random_opponent") - agent2 = LocalAgent("../selfplay/opponent_templates/random_opponent") + agent2 = RandomAgent() arena.reset( agents={"agent1": agent1, "agent2": agent2}, @@ -52,5 +53,12 @@ def run_arena( if __name__ == "__main__": - run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10) - # run_arena(render=False, parallel=False, seed=1, total_games=1, max_game_onetime=1,use_tqdm=False) + # run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10) + run_arena( + render=False, + parallel=False, + seed=1, + total_games=300, + max_game_onetime=1, + use_tqdm=False, + ) diff --git a/examples/atari/train_ppo.py b/examples/atari/train_ppo.py index 4f122c40..5920e819 100644 --- a/examples/atari/train_ppo.py +++ b/examples/atari/train_ppo.py @@ -59,7 +59,6 @@ def train(): agent = Agent(net, use_wandb=True) # start training, set total number of training steps to 20000 - # agent.train(total_time_steps=1000) agent.train(total_time_steps=5000000) env.close() agent.save("./ppo_agent/") diff --git a/examples/behavior_cloning/test_env.py b/examples/behavior_cloning/test_env.py index 60b272c6..fe0fa1b7 100644 --- a/examples/behavior_cloning/test_env.py +++ b/examples/behavior_cloning/test_env.py @@ -1,4 +1,5 @@ """""" + import numpy as np from openrl.configs.config import create_config_parser diff --git a/examples/behavior_cloning/train_bc.py b/examples/behavior_cloning/train_bc.py index 0d562dee..16d2ef2f 100644 --- a/examples/behavior_cloning/train_bc.py +++ b/examples/behavior_cloning/train_bc.py @@ -1,4 +1,5 @@ """""" + import numpy as np from openrl.configs.config import create_config_parser diff --git a/examples/cartpole/train_a2c.py b/examples/cartpole/train_a2c.py index 415f0bba..35ca95a9 100644 --- a/examples/cartpole/train_a2c.py +++ b/examples/cartpole/train_a2c.py @@ -1,4 +1,5 @@ """""" + import numpy as np import torch diff --git a/examples/cartpole/train_dqn_beta.py b/examples/cartpole/train_dqn_beta.py index 2dffaa81..3e32ec28 100644 --- a/examples/cartpole/train_dqn_beta.py +++ b/examples/cartpole/train_dqn_beta.py @@ -1,4 +1,5 @@ """""" + import numpy as np from openrl.configs.config import create_config_parser diff --git a/examples/cartpole/train_ppo.py b/examples/cartpole/train_ppo.py index ee11f871..77e41008 100644 --- a/examples/cartpole/train_ppo.py +++ b/examples/cartpole/train_ppo.py @@ -1,4 +1,5 @@ """""" + import numpy as np from openrl.configs.config import create_config_parser diff --git a/examples/custom_env/pettingzoo_env.py b/examples/custom_env/pettingzoo_env.py index 8b173449..d5644b7b 100644 --- a/examples/custom_env/pettingzoo_env.py +++ b/examples/custom_env/pettingzoo_env.py @@ -25,6 +25,7 @@ from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper register("RockPaperScissors", RockPaperScissors) + env = make( "RockPaperScissors", env_num=10, diff --git a/examples/custom_env/rock_paper_scissors.py b/examples/custom_env/rock_paper_scissors.py index 2811a1ff..f18e1841 100644 --- a/examples/custom_env/rock_paper_scissors.py +++ b/examples/custom_env/rock_paper_scissors.py @@ -18,6 +18,7 @@ import functools +import time import gymnasium import numpy as np @@ -54,7 +55,7 @@ class RockPaperScissors(AECEnv): metadata = {"render_modes": ["human"], "name": "rps_v2"} - def __init__(self, render_mode=None): + def __init__(self, id, render_mode=None): """ The init method takes in environment arguments and should define the following attributes: @@ -122,8 +123,8 @@ def observe(self, agent): """ # observation of one agent is the previous state of the other # return np.array(self.observations[agent]) - obs = np.zeros(4, dtype=np.int64) - obs[self.observations[agent]] = 1 + obs = np.zeros([1, 4], dtype=np.int64) + obs[0, self.observations[agent]] = 1 return obs def close(self): @@ -182,6 +183,7 @@ def step(self, action): # handles stepping an agent which is already dead # accepts a None action for the one agent, and moves the agent_selection to # the next dead agent, or if there are no more dead agents, to the next live agent + action = None self._was_dead_step(action) return diff --git a/examples/ddpg/train_ddpg_beta.py b/examples/ddpg/train_ddpg_beta.py index 2a19f557..7ba61ee0 100644 --- a/examples/ddpg/train_ddpg_beta.py +++ b/examples/ddpg/train_ddpg_beta.py @@ -1,4 +1,5 @@ """""" + import numpy as np from openrl.configs.config import create_config_parser diff --git a/examples/envpool/README.md b/examples/envpool/README.md new file mode 100644 index 00000000..e9a16389 --- /dev/null +++ b/examples/envpool/README.md @@ -0,0 +1,20 @@ +## Installation + + +Install envpool with: + +``` shell +pip install envpool +``` + +Note 1: envpool only supports Linux operating system. + +## Usage + +You can use `OpenRL` to train Cartpole (envpool) via: + +``` shell +PYTHON_PATH train_ppo.py +``` + +You can also add custom wrappers in `envpool_wrapper.py`. Currently we have `VecAdapter` and `VecMonitor` wrappers. \ No newline at end of file diff --git a/examples/envpool/envpool_wrappers.py b/examples/envpool/envpool_wrappers.py new file mode 100644 index 00000000..bf975166 --- /dev/null +++ b/examples/envpool/envpool_wrappers.py @@ -0,0 +1,181 @@ +import time +import warnings +from typing import Optional + +import gym +import gymnasium +import numpy as np +from envpool.python.protocol import EnvPool +from packaging import version +from stable_baselines3.common.vec_env import VecEnvWrapper as BaseWrapper +from stable_baselines3.common.vec_env import VecMonitor +from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs, VecEnvStepReturn + +is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0") + + +class VecEnvWrapper(BaseWrapper): + @property + def agent_num(self): + if self.is_original_envpool_env(): + return 1 + else: + return self.env.agent_num + + def is_original_envpool_env(self): + return not hasattr(self.venv, "agent_num`") + + +class VecAdapter(VecEnvWrapper): + """ + Convert EnvPool object to a Stable-Baselines3 (SB3) VecEnv. + + :param venv: The envpool object. + """ + + def __init__(self, venv: EnvPool): + venv.num_envs = venv.spec.config.num_envs + observation_space = venv.observation_space + new_observation_space = gymnasium.spaces.Box( + low=observation_space.low, + high=observation_space.high, + dtype=observation_space.dtype, + ) + action_space = venv.action_space + if isinstance(action_space, gym.spaces.Discrete): + new_action_space = gymnasium.spaces.Discrete(action_space.n) + elif isinstance(action_space, gym.spaces.MultiDiscrete): + new_action_space = gymnasium.spaces.MultiDiscrete(action_space.nvec) + elif isinstance(action_space, gym.spaces.MultiBinary): + new_action_space = gymnasium.spaces.MultiBinary(action_space.n) + elif isinstance(action_space, gym.spaces.Box): + new_action_space = gymnasium.spaces.Box( + low=action_space.low, + high=action_space.high, + dtype=action_space.dtype, + ) + else: + raise NotImplementedError(f"Action space {action_space} is not supported") + super().__init__( + venv=venv, + observation_space=new_observation_space, + action_space=new_action_space, + ) + + def step_async(self, actions: np.ndarray) -> None: + self.actions = actions + + def reset(self) -> VecEnvObs: + if is_legacy_gym: + return self.venv.reset(), {} + else: + return self.venv.reset() + + def step_wait(self) -> VecEnvStepReturn: + if is_legacy_gym: + obs, rewards, dones, info_dict = self.venv.step(self.actions) + else: + obs, rewards, terms, truncs, info_dict = self.venv.step(self.actions) + dones = terms + truncs + rewards = rewards + infos = [] + for i in range(self.num_envs): + infos.append( + { + key: info_dict[key][i] + for key in info_dict.keys() + if isinstance(info_dict[key], np.ndarray) + } + ) + if dones[i]: + infos[i]["terminal_observation"] = obs[i] + if is_legacy_gym: + obs[i] = self.venv.reset(np.array([i])) + else: + obs[i] = self.venv.reset(np.array([i]))[0] + return obs, rewards, dones, infos + + +class VecMonitor(VecEnvWrapper): + def __init__( + self, + venv, + filename: Optional[str] = None, + info_keywords=(), + ): + # Avoid circular import + from stable_baselines3.common.monitor import Monitor, ResultsWriter + + try: + is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0] + except AttributeError: + is_wrapped_with_monitor = False + + if is_wrapped_with_monitor: + warnings.warn( + "The environment is already wrapped with a `Monitor` wrapperbut you are" + " wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics" + " will beoverwritten by the `VecMonitor` ones.", + UserWarning, + ) + + VecEnvWrapper.__init__(self, venv) + self.episode_count = 0 + self.t_start = time.time() + + env_id = None + if hasattr(venv, "spec") and venv.spec is not None: + env_id = venv.spec.id + + self.results_writer: Optional[ResultsWriter] = None + if filename: + self.results_writer = ResultsWriter( + filename, + header={"t_start": self.t_start, "env_id": str(env_id)}, + extra_keys=info_keywords, + ) + + self.info_keywords = info_keywords + self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + + def reset(self, **kwargs) -> VecEnvObs: + obs, info = self.venv.reset() + self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + return obs, info + + def step_wait(self) -> VecEnvStepReturn: + obs, rewards, dones, infos = self.venv.step_wait() + self.episode_returns += rewards + self.episode_lengths += 1 + new_infos = list(infos[:]) + for i in range(len(dones)): + if dones[i]: + info = infos[i].copy() + episode_return = self.episode_returns[i] + episode_length = self.episode_lengths[i] + episode_info = { + "r": episode_return, + "l": episode_length, + "t": round(time.time() - self.t_start, 6), + } + for key in self.info_keywords: + episode_info[key] = info[key] + info["episode"] = episode_info + self.episode_count += 1 + self.episode_returns[i] = 0 + self.episode_lengths[i] = 0 + if self.results_writer: + self.results_writer.write_row(episode_info) + new_infos[i] = info + rewards = np.expand_dims(rewards, 1) + return obs, rewards, dones, new_infos + + def close(self) -> None: + if self.results_writer: + self.results_writer.close() + return self.venv.close() + + +__all__ = ["VecAdapter", "VecMonitor"] diff --git a/examples/envpool/make_env.py b/examples/envpool/make_env.py new file mode 100644 index 00000000..669ca67a --- /dev/null +++ b/examples/envpool/make_env.py @@ -0,0 +1,131 @@ +import copy +import inspect +from typing import Callable, Iterable, List, Optional, Union + +import envpool +from gymnasium import Env + +from openrl.envs.vec_env import ( + AsyncVectorEnv, + RewardWrapper, + SyncVectorEnv, + VecMonitorWrapper, +) +from openrl.envs.vec_env.vec_info import VecInfoFactory +from openrl.envs.wrappers.base_wrapper import BaseWrapper +from openrl.rewards import RewardFactory + + +def build_envs( + make, + id: str, + env_num: int = 1, + wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None, + need_env_id: bool = False, + **kwargs, +) -> List[Callable[[], Env]]: + cfg = kwargs.get("cfg", None) + + def create_env(env_id: int, env_num: int, need_env_id: bool) -> Callable[[], Env]: + def _make_env() -> Env: + new_kwargs = copy.deepcopy(kwargs) + if need_env_id: + new_kwargs["env_id"] = env_id + new_kwargs["env_num"] = env_num + if "envpool" in new_kwargs: + # for now envpool doesnt support any render mode + # envpool also doesnt stores the id anywhere + new_kwargs.pop("envpool") + env = make( + id, + **new_kwargs, + ) + env.unwrapped.spec.id = id + + if wrappers is not None: + if callable(wrappers): + if issubclass(wrappers, BaseWrapper): + env = wrappers(env, cfg=cfg) + else: + env = wrappers(env) + elif isinstance(wrappers, Iterable) and all( + [callable(w) for w in wrappers] + ): + for wrapper in wrappers: + if ( + issubclass(wrapper, BaseWrapper) + and "cfg" in inspect.signature(wrapper.__init__).parameters + ): + env = wrapper(env, cfg=cfg) + else: + env = wrapper(env) + else: + raise NotImplementedError + + return env + + return _make_env + + env_fns = [create_env(env_id, env_num, need_env_id) for env_id in range(env_num)] + return env_fns + + +def make_envpool_envs( + id: str, + env_num: int = 1, + **kwargs, +): + assert "env_type" in kwargs + assert kwargs.get("env_type") in ["gym", "dm", "gymnasium"] + kwargs["envpool"] = True + + if "env_wrappers" in kwargs: + env_wrappers = kwargs.pop("env_wrappers") + else: + env_wrappers = [] + env_fns = build_envs( + make=envpool.make, + id=id, + env_num=env_num, + wrappers=env_wrappers, + **kwargs, + ) + return env_fns + + +def make( + id: str, + env_num: int = 1, + asynchronous: bool = False, + add_monitor: bool = True, + render_mode: Optional[str] = None, + auto_reset: bool = True, + **kwargs, +): + cfg = kwargs.get("cfg", None) + if id in envpool.registration.list_all_envs(): + env_fns = make_envpool_envs( + id=id.split(":")[-1], + env_num=env_num, + **kwargs, + ) + if asynchronous: + env = AsyncVectorEnv( + env_fns, render_mode=render_mode, auto_reset=auto_reset + ) + else: + env = SyncVectorEnv(env_fns, render_mode=render_mode, auto_reset=auto_reset) + + reward_class = cfg.reward_class if cfg else None + reward_class = RewardFactory.get_reward_class(reward_class, env) + + env = RewardWrapper(env, reward_class) + + if add_monitor: + vec_info_class = cfg.vec_info_class if cfg else None + vec_info_class = VecInfoFactory.get_vec_info_class(vec_info_class, env) + env = VecMonitorWrapper(vec_info_class, env) + + return env + else: + raise NotImplementedError(f"env {id} is not supported") diff --git a/examples/envpool/train_ppo.py b/examples/envpool/train_ppo.py new file mode 100644 index 00000000..b6550b96 --- /dev/null +++ b/examples/envpool/train_ppo.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import numpy as np +from make_env import make + +from examples.envpool.envpool_wrappers import VecAdapter, VecMonitor +from openrl.configs.config import create_config_parser +from openrl.modules.common import PPONet as Net +from openrl.modules.common.ppo_net import PPONet as Net +from openrl.runners.common import PPOAgent as Agent + + +def train(): + # create the neural network + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args() + + # create environment, set environment parallelism to 9 + env = make( + "CartPole-v1", + render_mode=None, + env_num=9, + asynchronous=False, + env_wrappers=[VecAdapter, VecMonitor], + env_type="gym", + ) + + net = Net( + env, + cfg=cfg, + ) + # initialize the trainer + agent = Agent(net, use_wandb=False, project_name="CartPole-v1") + # start training, set total number of training steps to 20000 + agent.train(total_time_steps=20000) + + env.close() + return agent + + +def evaluation(agent): + # begin to test + # Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human. + render_mode = "group_human" + render_mode = None + env = make( + "CartPole-v1", + env_wrappers=[VecAdapter, VecMonitor], + render_mode=render_mode, + env_num=9, + asynchronous=True, + env_type="gym", + ) + # The trained agent sets up the interactive environment it needs. + agent.set_env(env) + # Initialize the environment and get initial observations and environmental information. + obs, info = env.reset() + done = False + step = 0 + total_step, total_reward = 0, 0 + while not np.any(done): + # Based on environmental observation input, predict next action. + action, _ = agent.act(obs, deterministic=True) + obs, r, done, info = env.step(action) + step += 1 + total_step += 1 + total_reward += np.mean(r) + if step % 50 == 0: + print(f"{step}: reward:{np.mean(r)}") + env.close() + print("total step:", total_step) + print("total reward:", total_reward) + + +if __name__ == "__main__": + agent = train() + evaluation(agent) diff --git a/examples/gail/train_gail.py b/examples/gail/train_gail.py index abe73039..4e227be9 100644 --- a/examples/gail/train_gail.py +++ b/examples/gail/train_gail.py @@ -1,4 +1,5 @@ """""" + import numpy as np from openrl.configs.config import create_config_parser diff --git a/examples/gridworld/train_dqn.py b/examples/gridworld/train_dqn.py index 2b859784..900a1287 100644 --- a/examples/gridworld/train_dqn.py +++ b/examples/gridworld/train_dqn.py @@ -1,4 +1,5 @@ """""" + import numpy as np from openrl.configs.config import create_config_parser diff --git a/examples/gridworld/train_ppo.py b/examples/gridworld/train_ppo.py index 683e9579..71f59bcb 100644 --- a/examples/gridworld/train_ppo.py +++ b/examples/gridworld/train_ppo.py @@ -1,4 +1,5 @@ """""" + import numpy as np from openrl.configs.config import create_config_parser diff --git a/examples/nlp/README.md b/examples/nlp/README.md index 6bcbb7c0..2fb61de9 100644 --- a/examples/nlp/README.md +++ b/examples/nlp/README.md @@ -6,6 +6,14 @@ Users can train the dialog task via: python train_ppo.py --config nlp_ppo.yaml ``` +Users can train the dialog task with deepspeed via: + +```shell +deepspeed train_ppo.py --config nlp_ppo_ds.yaml + + +``` + After the training, users can chat with the agent via: ```shell diff --git a/examples/nlp/ds_config.json b/examples/nlp/ds_config.json new file mode 100644 index 00000000..3de0eb2d --- /dev/null +++ b/examples/nlp/ds_config.json @@ -0,0 +1,9 @@ +{ + "train_batch_size": 32, + "train_micro_batch_size_per_gpu": 16, + "steps_per_print": 10, + "zero_optimization": { + "stage": 2 + }, + "fp16": {"enabled": false, "loss_scale_window": 100} +} \ No newline at end of file diff --git a/examples/nlp/eval_ds_config.json b/examples/nlp/eval_ds_config.json new file mode 100644 index 00000000..58c08252 --- /dev/null +++ b/examples/nlp/eval_ds_config.json @@ -0,0 +1,10 @@ +{ + "train_batch_size": 32, + "train_micro_batch_size_per_gpu": 16, + "steps_per_print": 10, + "zero_optimization": { + "stage": 0, + "offload_param": {"device": "cpu"} +}, + "fp16": {"enabled": false} +} \ No newline at end of file diff --git a/examples/nlp/nlp_ppo.yaml b/examples/nlp/nlp_ppo.yaml index 47da5280..918a75b8 100644 --- a/examples/nlp/nlp_ppo.yaml +++ b/examples/nlp/nlp_ppo.yaml @@ -1,19 +1,16 @@ seed: 0 -lr: 1e-6 -critic_lr: 1e-6 +lr: 1e-7 +critic_lr: 1e-7 run_dir: ./run_results/ log_interval: 1 -use_recurrent_policy: true use_valuenorm: true use_adv_normalize: true wandb_entity: "openrl-lab" ppo_epoch: 5 episode_length: 128 num_mini_batch: 20 -use_share_model: true -use_amp: true + hidden_size: 1 -data_chunk_length: 1 model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog env: @@ -25,8 +22,8 @@ vec_info_class: id: "NLPVecInfo" reward_class: id: "NLPReward" - args: { - "intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier", + args: { "ref_model": "rajkumarrrk/gpt2-fine-tuned-on-daily-dialog", + "intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier", } \ No newline at end of file diff --git a/examples/nlp/nlp_ppo_ds.yaml b/examples/nlp/nlp_ppo_ds.yaml new file mode 100644 index 00000000..88dac18c --- /dev/null +++ b/examples/nlp/nlp_ppo_ds.yaml @@ -0,0 +1,37 @@ +seed: 0 +lr: 1e-7 +critic_lr: 1e-7 +run_dir: ./run_results/ +log_interval: 1 +use_valuenorm: true +use_adv_normalize: true +wandb_entity: "openrl-lab" +ppo_epoch: 5 +episode_length: 128 +num_mini_batch: 20 + +hidden_size: 1 + +use_deepspeed: true +use_fp16: false +use_offload: false +deepspeed_config: ds_config.json + +model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog +env: + args: { + 'tokenizer_path': 'gpt2', + 'data_path': 'daily_dialog', + } +vec_info_class: + id: "NLPVecInfo" +reward_class: + id: "NLPReward" + args: { + "use_deepspeed": true, + "ref_ds_config": "eval_ds_config.json", + "ref_model": "rajkumarrrk/gpt2-fine-tuned-on-daily-dialog", + "intent_ds_config": "eval_ds_config.json", + "intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier", + } + \ No newline at end of file diff --git a/examples/nlp/train_ppo.py b/examples/nlp/train_ppo.py index f549c122..4fefcf52 100644 --- a/examples/nlp/train_ppo.py +++ b/examples/nlp/train_ppo.py @@ -1,19 +1,25 @@ """""" + from openrl.configs.config import create_config_parser from openrl.envs.common import make from openrl.modules.common import PPONet as Net -from openrl.modules.networks.policy_value_network_gpt import ( - PolicyValueNetworkGPT as PolicyValueNetwork, -) +from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork +from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork from openrl.runners.common import PPOAgent as Agent def train(): # create environment cfg_parser = create_config_parser() + try: + import deepspeed + + cfg_parser = deepspeed.add_config_arguments(cfg_parser) + except: + print("choose not to use deepspeed in the nlp task") cfg = cfg_parser.parse_args() - env_num = 10 + env_num = 5 env = make( "daily_dialog", env_num=env_num, @@ -22,7 +28,7 @@ def train(): ) # create the neural network - model_dict = {"model": PolicyValueNetwork} + model_dict = {"policy": PolicyNetwork, "critic": ValueNetwork} net = Net(env, device="cuda", cfg=cfg, model_dict=model_dict) # initialize the trainer diff --git a/examples/retro/train_retro.py b/examples/retro/train_retro.py index ad13749a..0668b620 100644 --- a/examples/retro/train_retro.py +++ b/examples/retro/train_retro.py @@ -1,4 +1,5 @@ """""" + import numpy as np from custom_registration import make diff --git a/examples/sac/train_ddpg.py b/examples/sac/train_ddpg.py index 484a1f6d..5bc2bab8 100644 --- a/examples/sac/train_ddpg.py +++ b/examples/sac/train_ddpg.py @@ -1,4 +1,5 @@ """""" + import numpy as np from openrl.configs.config import create_config_parser diff --git a/examples/sac/train_sac_beta.py b/examples/sac/train_sac_beta.py index 9fa905a8..bc40c1dc 100644 --- a/examples/sac/train_sac_beta.py +++ b/examples/sac/train_sac_beta.py @@ -1,4 +1,5 @@ """""" + import numpy as np from openrl.configs.config import create_config_parser diff --git a/examples/selfplay/selfplay.yaml b/examples/selfplay/selfplay.yaml index 7a7c1bbe..8a05611d 100644 --- a/examples/selfplay/selfplay.yaml +++ b/examples/selfplay/selfplay.yaml @@ -1,6 +1,6 @@ globals: selfplay_api_host: 127.0.0.1 - selfplay_api_port: 10086 + selfplay_api_port: 13486 seed: 0 selfplay_api: diff --git a/examples/smac/custom_vecinfo.py b/examples/smac/custom_vecinfo.py index 52a2b5b2..ba39f6e1 100644 --- a/examples/smac/custom_vecinfo.py +++ b/examples/smac/custom_vecinfo.py @@ -41,10 +41,10 @@ def statistics(self, buffer: Any) -> Dict[str, Any]: assert ( "game_state" in singe_env_info["final_info"].keys() ), "game_state must be in info" - assert singe_env_info["final_info"]["game_state"] in [ - "win", - "lose", - ], "game_state in the final_info must be win or lose" + # assert singe_env_info["final_info"]["game_state"] in [ + # "win", + # "lose", + # ], "game_state in the final_info must be win or lose" self.win_history.append( singe_env_info["final_info"]["game_state"] == "win" ) diff --git a/examples/smac/train_ppo.py b/examples/smac/train_ppo.py index 32f8acff..4c03d295 100644 --- a/examples/smac/train_ppo.py +++ b/examples/smac/train_ppo.py @@ -25,7 +25,8 @@ def train(): # create environment env_num = 8 env = make( - "2s_vs_1sc", + "3m", + # "2s_vs_1sc", env_num=env_num, asynchronous=True, cfg=cfg, diff --git a/examples/smacv2/custom_vecinfo.py b/examples/smacv2/custom_vecinfo.py index 6dd90d00..48fc210d 100644 --- a/examples/smacv2/custom_vecinfo.py +++ b/examples/smacv2/custom_vecinfo.py @@ -33,21 +33,21 @@ def __init__(self, *args, **kwargs): def statistics(self, buffer: Any) -> Dict[str, Any]: info_dict = super().statistics(buffer) - """for step_info in self.infos: + for step_info in self.infos: for singe_env_info in step_info: assert isinstance(singe_env_info, dict), "singe_env_info must be dict" if "final_info" in singe_env_info.keys(): assert ( "game_state" in singe_env_info["final_info"].keys() - ), "game_state must be in info" - assert singe_env_info["final_info"]["game_state"] in [ - "win", - "lose", - ], "game_state in the final_info must be win or lose" + ), "win_state must be in info" + # assert singe_env_info["final_info"]["game_state"] in [ + # "win", + # "lose", + # ], "win_state in the final_info must be win or lose" self.win_history.append( singe_env_info["final_info"]["game_state"] == "win" - )""" + ) if len(self.win_history) > 0: info_dict["win_rate"] = np.mean(self.win_history) diff --git a/examples/toy_env/train_ppo.py b/examples/toy_env/train_ppo.py index 49cb0c9f..6410b52a 100644 --- a/examples/toy_env/train_ppo.py +++ b/examples/toy_env/train_ppo.py @@ -1,4 +1,5 @@ """""" + from train_and_eval import evaluation, train from openrl.modules.common import PPONet as Net diff --git a/openrl/__init__.py b/openrl/__init__.py index 00bcaacf..2ea67943 100644 --- a/openrl/__init__.py +++ b/openrl/__init__.py @@ -1,5 +1,5 @@ __TITLE__ = "openrl" -__VERSION__ = "v0.1.7" +__VERSION__ = "v0.2.0" __DESCRIPTION__ = "Distributed Deep RL Framework" __AUTHOR__ = "OpenRL Contributors" __EMAIL__ = "huangshiyu@4paradigm.com" diff --git a/openrl/algorithms/dqn.py b/openrl/algorithms/dqn.py index bbca547b..ebd8d727 100644 --- a/openrl/algorithms/dqn.py +++ b/openrl/algorithms/dqn.py @@ -167,7 +167,9 @@ def prepare_loss( ) q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch - q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # 均方误差损失函数 + q_loss = torch.mean( + F.mse_loss(q_values, q_targets.detach()) + ) # 均方误差损失函数 loss_list.append(q_loss) diff --git a/openrl/algorithms/ppo.py b/openrl/algorithms/ppo.py index 1c226645..18b5f2c0 100644 --- a/openrl/algorithms/ppo.py +++ b/openrl/algorithms/ppo.py @@ -41,10 +41,12 @@ def __init__( self.use_joint_action_loss = cfg.use_joint_action_loss super(PPOAlgorithm, self).__init__(cfg, init_module, agent_num, device) self.train_list = [self.train_ppo] + self.use_deepspeed = cfg.use_deepspeed def ppo_update(self, sample, turn_on=True): for optimizer in self.algo_module.optimizers.values(): - optimizer.zero_grad() + if not self.use_deepspeed: + optimizer.zero_grad() ( critic_obs_batch, @@ -108,8 +110,18 @@ def ppo_update(self, sample, turn_on=True): active_masks_batch, turn_on, ) - for loss in loss_list: - loss.backward() + if self.use_deepspeed: + if self._use_share_model: + for loss in loss_list: + self.algo_module.models["model"].backward(loss) + else: + actor_loss = loss_list[0] + critic_loss = loss_list[1] + self.algo_module.models["policy"].backward(actor_loss) + self.algo_module.models["critic"].backward(critic_loss) + else: + for loss in loss_list: + loss.backward() # else: if self._use_share_model: @@ -141,8 +153,15 @@ def ppo_update(self, sample, turn_on=True): self.algo_module.scaler.update() else: - for optimizer in self.algo_module.optimizers.values(): - optimizer.step() + if self.use_deepspeed: + if self._use_share_model: + self.algo_module.optimizers["model"].step() + else: + self.algo_module.optimizers["policy"].step() + self.algo_module.optimizers["critic"].step() + else: + for optimizer in self.algo_module.optimizers.values(): + optimizer.step() if self.world_size > 1: torch.cuda.synchronize() @@ -168,7 +187,7 @@ def cal_value_loss( -self.clip_param, self.clip_param ) - if self._use_popart or self._use_valuenorm: + if (self._use_popart or self._use_valuenorm) and value_normalizer is not None: value_normalizer.update(return_batch) error_clipped = ( value_normalizer.normalize(return_batch) - value_pred_clipped @@ -371,9 +390,12 @@ def train_ppo(self, buffer, turn_on): ].module.value_normalizer else: value_normalizer = self.algo_module.get_critic_value_normalizer() - advantages = buffer.returns[:-1] - value_normalizer.denormalize( - buffer.value_preds[:-1] - ) + if value_normalizer is not None: + advantages = buffer.returns[:-1] - value_normalizer.denormalize( + buffer.value_preds[:-1] + ) + else: + advantages = buffer.returns[:-1] - buffer.value_preds[:-1] else: advantages = buffer.returns[:-1] - buffer.value_preds[:-1] diff --git a/openrl/algorithms/vdn.py b/openrl/algorithms/vdn.py index f1215c03..83bdb5ed 100644 --- a/openrl/algorithms/vdn.py +++ b/openrl/algorithms/vdn.py @@ -211,7 +211,9 @@ def prepare_loss( rewards_batch = rewards_batch.reshape(-1, self.n_agent, 1) rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1) q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch - q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # 均方误差损失函数 + q_loss = torch.mean( + F.mse_loss(q_values, q_targets.detach()) + ) # 均方误差损失函数 loss_list.append(q_loss) return loss_list diff --git a/openrl/arena/__init__.py b/openrl/arena/__init__.py index 4bea924d..cb154a9f 100644 --- a/openrl/arena/__init__.py +++ b/openrl/arena/__init__.py @@ -30,9 +30,11 @@ def make_arena( **kwargs, ): if custom_build_env is None: + from openrl.envs import PettingZoo + if ( env_id in pettingzoo_all_envs - or env_id in openrl.envs.PettingZoo.registration.pettingzoo_env_dict.keys() + or env_id in PettingZoo.registration.pettingzoo_env_dict.keys() ): from openrl.envs.PettingZoo import make_PettingZoo_env diff --git a/openrl/arena/agents/random_agent.py b/openrl/arena/agents/random_agent.py new file mode 100644 index 00000000..d09e5e15 --- /dev/null +++ b/openrl/arena/agents/random_agent.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +from openrl.arena.agents.base_agent import BaseAgent +from openrl.selfplay.opponents.base_opponent import BaseOpponent +from openrl.selfplay.opponents.random_opponent import RandomOpponent +from openrl.selfplay.opponents.utils import load_opponent_from_path + + +class RandomAgent(BaseAgent): + def __init__(self): + super().__init__() + + def _new_agent(self) -> BaseOpponent: + return RandomOpponent() diff --git a/openrl/arena/games/two_player_game.py b/openrl/arena/games/two_player_game.py index 7a1b4e0e..40585393 100644 --- a/openrl/arena/games/two_player_game.py +++ b/openrl/arena/games/two_player_game.py @@ -31,9 +31,10 @@ def default_dispatch_func( players: List[str], agent_names: List[str], ) -> Dict[str, str]: - assert len(players) == len( - agent_names - ), "The number of players must be equal to the number of agents." + assert len(players) == len(agent_names), ( + f"The number of players {len(players)} must be equal to the number of" + f" agents: {len(agent_names)}." + ) assert len(players) == 2, "The number of players must be equal to 2." np_random.shuffle(agent_names) return dict(zip(players, agent_names)) @@ -49,20 +50,21 @@ def _run(self, env_fn: Callable, agents: List[BaseAgent]): for player, agent in player2agent.items(): agent.reset(env, player) result = {} + truncation_dict = {} while True: termination = False info = {} for player_name in env.agent_iter(): observation, reward, termination, truncation, info = env.last() - - if termination: + truncation_dict[player_name] = truncation + if termination or all(truncation_dict.values()): break action = player2agent[player_name].act( player_name, observation, reward, termination, truncation, info ) env.step(action) - if termination: + if termination or all(truncation_dict.values()): assert "winners" in info, "The game is terminated but no winners." assert "losers" in info, "The game is terminated but no losers." diff --git a/openrl/buffers/offpolicy_replay_data.py b/openrl/buffers/offpolicy_replay_data.py index 4d62d53f..31e52e85 100644 --- a/openrl/buffers/offpolicy_replay_data.py +++ b/openrl/buffers/offpolicy_replay_data.py @@ -97,52 +97,52 @@ def __init__( ) self.first_insert_flag = True - def dict_insert(self, data): - if self._mixed_obs: - for key in self.critic_obs.keys(): - self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy() - for key in self.policy_obs.keys(): - self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy() - for key in self.next_policy_obs.keys(): - self.next_policy_obs[key][self.step + 1] = data["next_policy_obs"][ - key - ].copy() - for key in self.next_critic_obs.keys(): - self.next_critic_obs[key][self.step + 1] = data["next_critic_obs"][ - key - ].copy() - else: - self.critic_obs[self.step + 1] = data["critic_obs"].copy() - self.policy_obs[self.step + 1] = data["policy_obs"].copy() - self.next_policy_obs[self.step + 1] = data["next_policy_obs"].copy() - self.next_critic_obs[self.step + 1] = data["next_critic_obs"].copy() - - if "rnn_states" in data: - self.rnn_states[self.step + 1] = data["rnn_states"].copy() - if "rnn_states_critic" in data: - self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy() - if "actions" in data: - self.actions[self.step + 1] = data["actions"].copy() - if "action_log_probs" in data: - self.action_log_probs[self.step] = data["action_log_probs"].copy() - - if "value_preds" in data: - self.value_preds[self.step] = data["value_preds"].copy() - if "rewards" in data: - self.rewards[self.step + 1] = data["rewards"].copy() - if "masks" in data: - self.masks[self.step + 1] = data["masks"].copy() - - if "bad_masks" in data: - self.bad_masks[self.step + 1] = data["bad_masks"].copy() - if "active_masks" in data: - self.active_masks[self.step + 1] = data["active_masks"].copy() - if "action_masks" in data: - self.action_masks[self.step + 1] = data["action_masks"].copy() - - if (self.step + 1) % self.episode_length != 0: - self.first_insert_flag = False - self.step = (self.step + 1) % self.episode_length + # def dict_insert(self, data): + # if self._mixed_obs: + # for key in self.critic_obs.keys(): + # self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy() + # for key in self.policy_obs.keys(): + # self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy() + # for key in self.next_policy_obs.keys(): + # self.next_policy_obs[key][self.step + 1] = data["next_policy_obs"][ + # key + # ].copy() + # for key in self.next_critic_obs.keys(): + # self.next_critic_obs[key][self.step + 1] = data["next_critic_obs"][ + # key + # ].copy() + # else: + # self.critic_obs[self.step + 1] = data["critic_obs"].copy() + # self.policy_obs[self.step + 1] = data["policy_obs"].copy() + # self.next_policy_obs[self.step + 1] = data["next_policy_obs"].copy() + # self.next_critic_obs[self.step + 1] = data["next_critic_obs"].copy() + # + # if "rnn_states" in data: + # self.rnn_states[self.step + 1] = data["rnn_states"].copy() + # if "rnn_states_critic" in data: + # self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy() + # if "actions" in data: + # self.actions[self.step + 1] = data["actions"].copy() + # if "action_log_probs" in data: + # self.action_log_probs[self.step] = data["action_log_probs"].copy() + # + # if "value_preds" in data: + # self.value_preds[self.step] = data["value_preds"].copy() + # if "rewards" in data: + # self.rewards[self.step + 1] = data["rewards"].copy() + # if "masks" in data: + # self.masks[self.step + 1] = data["masks"].copy() + # + # if "bad_masks" in data: + # self.bad_masks[self.step + 1] = data["bad_masks"].copy() + # if "active_masks" in data: + # self.active_masks[self.step + 1] = data["active_masks"].copy() + # if "action_masks" in data: + # self.action_masks[self.step + 1] = data["action_masks"].copy() + # + # if (self.step + 1) % self.episode_length != 0: + # self.first_insert_flag = False + # self.step = (self.step + 1) % self.episode_length def init_buffer(self, raw_obs, action_masks=None): critic_obs = get_critic_obs(raw_obs) diff --git a/openrl/buffers/replay_data.py b/openrl/buffers/replay_data.py index 40a4b383..a8f4c1b7 100644 --- a/openrl/buffers/replay_data.py +++ b/openrl/buffers/replay_data.py @@ -198,49 +198,49 @@ def get_batch_data( else: return np.concatenate(data[step]) - def all_batch_data(self, data_name: str, min=None, max=None): - assert hasattr(self, data_name) - data = getattr(self, data_name) - - if isinstance(data, ObsData): - return data.all_batch(min, max) - else: - return data[min:max].reshape((-1, *data.shape[3:])) - - def dict_insert(self, data): - if self._mixed_obs: - for key in self.critic_obs.keys(): - self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy() - for key in self.policy_obs.keys(): - self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy() - else: - self.critic_obs[self.step + 1] = data["critic_obs"].copy() - self.policy_obs[self.step + 1] = data["policy_obs"].copy() - - if "rnn_states" in data: - self.rnn_states[self.step + 1] = data["rnn_states"].copy() - if "rnn_states_critic" in data: - self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy() - if "actions" in data: - self.actions[self.step] = data["actions"].copy() - if "action_log_probs" in data: - self.action_log_probs[self.step] = data["action_log_probs"].copy() - - if "value_preds" in data: - self.value_preds[self.step] = data["value_preds"].copy() - if "rewards" in data: - self.rewards[self.step] = data["rewards"].copy() - if "masks" in data: - self.masks[self.step + 1] = data["masks"].copy() - - if "bad_masks" in data: - self.bad_masks[self.step + 1] = data["bad_masks"].copy() - if "active_masks" in data: - self.active_masks[self.step + 1] = data["active_masks"].copy() - if "action_masks" in data: - self.action_masks[self.step + 1] = data["action_masks"].copy() - - self.step = (self.step + 1) % self.episode_length + # def all_batch_data(self, data_name: str, min=None, max=None): + # assert hasattr(self, data_name) + # data = getattr(self, data_name) + # + # if isinstance(data, ObsData): + # return data.all_batch(min, max) + # else: + # return data[min:max].reshape((-1, *data.shape[3:])) + + # def dict_insert(self, data): + # if self._mixed_obs: + # for key in self.critic_obs.keys(): + # self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy() + # for key in self.policy_obs.keys(): + # self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy() + # else: + # self.critic_obs[self.step + 1] = data["critic_obs"].copy() + # self.policy_obs[self.step + 1] = data["policy_obs"].copy() + # + # if "rnn_states" in data: + # self.rnn_states[self.step + 1] = data["rnn_states"].copy() + # if "rnn_states_critic" in data: + # self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy() + # if "actions" in data: + # self.actions[self.step] = data["actions"].copy() + # if "action_log_probs" in data: + # self.action_log_probs[self.step] = data["action_log_probs"].copy() + # + # if "value_preds" in data: + # self.value_preds[self.step] = data["value_preds"].copy() + # if "rewards" in data: + # self.rewards[self.step] = data["rewards"].copy() + # if "masks" in data: + # self.masks[self.step + 1] = data["masks"].copy() + # + # if "bad_masks" in data: + # self.bad_masks[self.step + 1] = data["bad_masks"].copy() + # if "active_masks" in data: + # self.active_masks[self.step + 1] = data["active_masks"].copy() + # if "action_masks" in data: + # self.action_masks[self.step + 1] = data["action_masks"].copy() + # + # self.step = (self.step + 1) % self.episode_length def insert( self, @@ -323,7 +323,9 @@ def compute_returns(self, next_value, value_normalizer=None): self.value_preds[-1] = next_value gae = 0 for step in reversed(range(self.rewards.shape[0])): - if self._use_popart or self._use_valuenorm: + if ( + self._use_popart or self._use_valuenorm + ) and value_normalizer is not None: # step + 1 delta = ( self.rewards[step] @@ -357,7 +359,9 @@ def compute_returns(self, next_value, value_normalizer=None): else: self.returns[-1] = next_value for step in reversed(range(self.rewards.shape[0])): - if self._use_popart or self._use_valuenorm: + if ( + self._use_popart or self._use_valuenorm + ) and value_normalizer is not None: self.returns[step] = ( self.returns[step + 1] * self.gamma * self.masks[step + 1] + self.rewards[step] @@ -947,119 +951,119 @@ def naive_recurrent_generator(self, advantages, num_mini_batch): yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch - def recurrent_generator_v2( - self, advantages, num_mini_batch=None, mini_batch_size=None - ): - """ - Yield training data for MLP policies. - :param advantages: (np.ndarray) advantage estimates. - :param num_mini_batch: (int) number of minibatches to split the batch into. - :param mini_batch_size: (int) number of samples in each minibatch. - """ - episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] - batch_size = n_rollout_threads * episode_length - - if mini_batch_size is None: - assert ( - batch_size >= num_mini_batch - ), ( - "PPO requires the number of processes ({}) " - "* number of steps ({}) = {} " - "to be greater than or equal to the number of PPO mini batches ({})." - "".format( - n_rollout_threads, - episode_length, - n_rollout_threads * episode_length, - num_mini_batch, - ) - ) - mini_batch_size = batch_size // num_mini_batch - - rand = torch.randperm(batch_size).numpy() - sampler = [ - rand[i * mini_batch_size : (i + 1) * mini_batch_size] - for i in range(num_mini_batch) - ] - - # keep (num_agent, dim) - critic_obs = self.critic_obs[:-1].reshape(-1, *self.critic_obs.shape[2:]) - - policy_obs = self.policy_obs[:-1].reshape(-1, *self.policy_obs.shape[2:]) - - rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[2:]) - - rnn_states_critic = self.rnn_states_critic[:-1].reshape( - -1, *self.rnn_states_critic.shape[2:] - ) - - actions = self.actions.reshape(-1, *self.actions.shape[2:]) - - if self.action_masks is not None: - action_masks = self.action_masks[:-1].reshape( - -1, *self.action_masks.shape[2:] - ) - - value_preds = self.value_preds[:-1].reshape(-1, *self.value_preds.shape[2:]) - - returns = self.returns[:-1].reshape(-1, *self.returns.shape[2:]) - - masks = self.masks[:-1].reshape(-1, *self.masks.shape[2:]) - - active_masks = self.active_masks[:-1].reshape(-1, *self.active_masks.shape[2:]) - - action_log_probs = self.action_log_probs.reshape( - -1, *self.action_log_probs.shape[2:] - ) - - advantages = advantages.reshape(-1, *advantages.shape[2:]) - - shuffle = False - if shuffle: - rows, cols = _shuffle_agent_grid(batch_size, num_agents) - - if self.action_masks is not None: - action_masks = action_masks[rows, cols] - critic_obs = critic_obs[rows, cols] - policy_obs = policy_obs[rows, cols] - rnn_states = rnn_states[rows, cols] - rnn_states_critic = rnn_states_critic[rows, cols] - actions = actions[rows, cols] - value_preds = value_preds[rows, cols] - returns = returns[rows, cols] - masks = masks[rows, cols] - active_masks = active_masks[rows, cols] - action_log_probs = action_log_probs[rows, cols] - advantages = advantages[rows, cols] - - for indices in sampler: - # [L,T,N,Dim]-->[L*T,N,Dim]-->[index,N,Dim]-->[index*N, Dim] - critic_obs_batch = critic_obs[indices].reshape(-1, *critic_obs.shape[2:]) - policy_obs_batch = policy_obs[indices].reshape(-1, *policy_obs.shape[2:]) - rnn_states_batch = rnn_states[indices].reshape(-1, *rnn_states.shape[2:]) - rnn_states_critic_batch = rnn_states_critic[indices].reshape( - -1, *rnn_states_critic.shape[2:] - ) - actions_batch = actions[indices].reshape(-1, *actions.shape[2:]) - if self.action_masks is not None: - action_masks_batch = action_masks[indices].reshape( - -1, *action_masks.shape[2:] - ) - else: - action_masks_batch = None - value_preds_batch = value_preds[indices].reshape(-1, *value_preds.shape[2:]) - return_batch = returns[indices].reshape(-1, *returns.shape[2:]) - masks_batch = masks[indices].reshape(-1, *masks.shape[2:]) - active_masks_batch = active_masks[indices].reshape( - -1, *active_masks.shape[2:] - ) - old_action_log_probs_batch = action_log_probs[indices].reshape( - -1, *action_log_probs.shape[2:] - ) - if advantages is None: - adv_targ = None - else: - adv_targ = advantages[indices].reshape(-1, *advantages.shape[2:]) - yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch + # def recurrent_generator_v2( + # self, advantages, num_mini_batch=None, mini_batch_size=None + # ): + # """ + # Yield training data for MLP policies. + # :param advantages: (np.ndarray) advantage estimates. + # :param num_mini_batch: (int) number of minibatches to split the batch into. + # :param mini_batch_size: (int) number of samples in each minibatch. + # """ + # episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] + # batch_size = n_rollout_threads * episode_length + # + # if mini_batch_size is None: + # assert ( + # batch_size >= num_mini_batch + # ), ( + # "PPO requires the number of processes ({}) " + # "* number of steps ({}) = {} " + # "to be greater than or equal to the number of PPO mini batches ({})." + # "".format( + # n_rollout_threads, + # episode_length, + # n_rollout_threads * episode_length, + # num_mini_batch, + # ) + # ) + # mini_batch_size = batch_size // num_mini_batch + # + # rand = torch.randperm(batch_size).numpy() + # sampler = [ + # rand[i * mini_batch_size : (i + 1) * mini_batch_size] + # for i in range(num_mini_batch) + # ] + # + # # keep (num_agent, dim) + # critic_obs = self.critic_obs[:-1].reshape(-1, *self.critic_obs.shape[2:]) + # + # policy_obs = self.policy_obs[:-1].reshape(-1, *self.policy_obs.shape[2:]) + # + # rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[2:]) + # + # rnn_states_critic = self.rnn_states_critic[:-1].reshape( + # -1, *self.rnn_states_critic.shape[2:] + # ) + # + # actions = self.actions.reshape(-1, *self.actions.shape[2:]) + # + # if self.action_masks is not None: + # action_masks = self.action_masks[:-1].reshape( + # -1, *self.action_masks.shape[2:] + # ) + # + # value_preds = self.value_preds[:-1].reshape(-1, *self.value_preds.shape[2:]) + # + # returns = self.returns[:-1].reshape(-1, *self.returns.shape[2:]) + # + # masks = self.masks[:-1].reshape(-1, *self.masks.shape[2:]) + # + # active_masks = self.active_masks[:-1].reshape(-1, *self.active_masks.shape[2:]) + # + # action_log_probs = self.action_log_probs.reshape( + # -1, *self.action_log_probs.shape[2:] + # ) + # + # advantages = advantages.reshape(-1, *advantages.shape[2:]) + # + # shuffle = False + # if shuffle: + # rows, cols = _shuffle_agent_grid(batch_size, num_agents) + # + # if self.action_masks is not None: + # action_masks = action_masks[rows, cols] + # critic_obs = critic_obs[rows, cols] + # policy_obs = policy_obs[rows, cols] + # rnn_states = rnn_states[rows, cols] + # rnn_states_critic = rnn_states_critic[rows, cols] + # actions = actions[rows, cols] + # value_preds = value_preds[rows, cols] + # returns = returns[rows, cols] + # masks = masks[rows, cols] + # active_masks = active_masks[rows, cols] + # action_log_probs = action_log_probs[rows, cols] + # advantages = advantages[rows, cols] + # + # for indices in sampler: + # # [L,T,N,Dim]-->[L*T,N,Dim]-->[index,N,Dim]-->[index*N, Dim] + # critic_obs_batch = critic_obs[indices].reshape(-1, *critic_obs.shape[2:]) + # policy_obs_batch = policy_obs[indices].reshape(-1, *policy_obs.shape[2:]) + # rnn_states_batch = rnn_states[indices].reshape(-1, *rnn_states.shape[2:]) + # rnn_states_critic_batch = rnn_states_critic[indices].reshape( + # -1, *rnn_states_critic.shape[2:] + # ) + # actions_batch = actions[indices].reshape(-1, *actions.shape[2:]) + # if self.action_masks is not None: + # action_masks_batch = action_masks[indices].reshape( + # -1, *action_masks.shape[2:] + # ) + # else: + # action_masks_batch = None + # value_preds_batch = value_preds[indices].reshape(-1, *value_preds.shape[2:]) + # return_batch = returns[indices].reshape(-1, *returns.shape[2:]) + # masks_batch = masks[indices].reshape(-1, *masks.shape[2:]) + # active_masks_batch = active_masks[indices].reshape( + # -1, *active_masks.shape[2:] + # ) + # old_action_log_probs_batch = action_log_probs[indices].reshape( + # -1, *action_log_probs.shape[2:] + # ) + # if advantages is None: + # adv_targ = None + # else: + # adv_targ = advantages[indices].reshape(-1, *advantages.shape[2:]) + # yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch def recurrent_generator(self, advantages, num_mini_batch, data_chunk_length): episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] diff --git a/openrl/configs/config.py b/openrl/configs/config.py index 2a616fe6..49bdae79 100644 --- a/openrl/configs/config.py +++ b/openrl/configs/config.py @@ -498,13 +498,14 @@ def create_config_parser(): ) parser.add_argument( "--use_popart", - action="store_true", default=False, + type=bool, help="by default False, use PopArt to normalize rewards.", ) parser.add_argument( "--dual_clip_ppo", default=False, + type=bool, help="by default False, use dual-clip ppo.", ) parser.add_argument( @@ -618,7 +619,7 @@ def create_config_parser(): ) parser.add_argument( "--use_average_pool", - action="store_false", + type=bool, default=True, help="by default True, use average pooling for attn model.", ) @@ -730,8 +731,8 @@ def create_config_parser(): ) parser.add_argument( "--use_gae", - action="store_false", default=True, + type=bool, help="use generalized advantage estimation", ) parser.add_argument( @@ -748,8 +749,8 @@ def create_config_parser(): ) parser.add_argument( "--use_proper_time_limits", - action="store_true", default=False, + type=bool, help="compute returns taking into account time limits", ) parser.add_argument( @@ -1234,5 +1235,29 @@ def create_config_parser(): type=float, help="newest_weight", ) + parser.add_argument( + "--use_deepspeed", + default=False, + type=bool, + help="whether to use deepspeed", + ) + parser.add_argument( + "--local_rank", + default=-1, + type=int, + help="local_rank", + ) + parser.add_argument( + "--use_offload", + default=False, + type=bool, + help="whether to use offload (deepspeed)", + ) + parser.add_argument( + "--use_fp16", + default=False, + type=bool, + help="whether to use fp16 (deepspeed)", + ) return parser diff --git a/openrl/configs/utils.py b/openrl/configs/utils.py index 53e1f4d2..a8420767 100644 --- a/openrl/configs/utils.py +++ b/openrl/configs/utils.py @@ -16,7 +16,7 @@ """""" - +import os import re import tempfile @@ -83,9 +83,19 @@ def __call__(self, parser, cfg, values, option_string=None): # Load the rendered content as a dictionary data = yaml.safe_load(rendered_content) - # Write the result to a temporary file - with tempfile.NamedTemporaryFile("w", delete=True, suffix=".yaml") as temp_file: + # Write the result to a temporary file. Not work on Windows. + # with tempfile.NamedTemporaryFile("w", delete=True, suffix=".yaml") as temp_file: + # yaml.dump(data, temp_file) + # temp_file.seek(0) # Move to the beginning of the file + # # Use the default behavior of ActionConfigFile to handle the temporary file + # super().__call__(parser, cfg, temp_file.name, option_string) + + # Write the result to a temporary file. This works on all platforms. + temp_fd, temp_filename = tempfile.mkstemp(suffix=".yaml") + with os.fdopen(temp_fd, "w") as temp_file: yaml.dump(data, temp_file) - temp_file.seek(0) # Move to the beginning of the file + try: # Use the default behavior of ActionConfigFile to handle the temporary file - super().__call__(parser, cfg, temp_file.name, option_string) + super().__call__(parser, cfg, temp_filename, option_string) + finally: + os.remove(temp_filename) diff --git a/openrl/envs/__init__.py b/openrl/envs/__init__.py index a2eb835f..d12c493a 100644 --- a/openrl/envs/__init__.py +++ b/openrl/envs/__init__.py @@ -16,12 +16,9 @@ toy_all_envs = [ "BitFlippingEnv", - "FakeImageEnv", "IdentityEnv", "IdentityEnvcontinuous", "IdentityEnvBox", - "IdentityEnvMultiBinary", - "IdentityEnvMultiDiscrete", "SimpleMultiObsEnv", "SimpleMultiObsEnv", ] diff --git a/openrl/envs/common/build_envs.py b/openrl/envs/common/build_envs.py index 94c34019..76f4b35b 100644 --- a/openrl/envs/common/build_envs.py +++ b/openrl/envs/common/build_envs.py @@ -2,6 +2,7 @@ import inspect from typing import Callable, Iterable, List, Optional, Union +import gymnasium as gym from gymnasium import Env from openrl.envs.wrappers.base_wrapper import BaseWrapper @@ -33,6 +34,8 @@ def _make_env() -> Env: if need_env_id: new_kwargs["env_id"] = env_id new_kwargs["env_num"] = env_num + if id.startswith("ALE/") or id in gym.envs.registry.keys(): + new_kwargs.pop("cfg", None) env = make( id, diff --git a/openrl/envs/common/registration.py b/openrl/envs/common/registration.py index 099f2b39..5d1ed645 100644 --- a/openrl/envs/common/registration.py +++ b/openrl/envs/common/registration.py @@ -20,6 +20,7 @@ import gymnasium as gym import openrl +from openrl.envs.PettingZoo.registration import pettingzoo_env_dict from openrl.envs.vec_env import ( AsyncVectorEnv, BaseVecEnv, @@ -107,7 +108,6 @@ def make( id=id, env_num=env_num, render_mode=convert_render_mode, - cfg=cfg, **kwargs, ) elif id in openrl.envs.toy_all_envs: @@ -149,10 +149,7 @@ def make( render_mode=convert_render_mode, **kwargs, ) - elif ( - id in openrl.envs.pettingzoo_all_envs - or id in openrl.envs.PettingZoo.registration.pettingzoo_env_dict.keys() - ): + elif id in openrl.envs.pettingzoo_all_envs or id in pettingzoo_env_dict.keys(): from openrl.envs.PettingZoo import make_PettingZoo_envs env_fns = make_PettingZoo_envs( diff --git a/openrl/envs/mpe/rendering.py b/openrl/envs/mpe/rendering.py index ab1a47db..6dae5d66 100644 --- a/openrl/envs/mpe/rendering.py +++ b/openrl/envs/mpe/rendering.py @@ -1,6 +1,7 @@ """ 2D rendering framework """ + from __future__ import division import os @@ -26,15 +27,14 @@ try: from pyglet.gl import * + except ImportError: print( "Error occured while running `from pyglet.gl import *`", - ( - "HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get" - " install python-opengl'. If you're running on a server, you may need a" - " virtual frame buffer; something like this should work: 'xvfb-run -s" - ' "-screen 0 1400x900x24" python \'' - ), + "HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get" + " install python-opengl'. If you're running on a server, you may need a" + " virtual frame buffer; something like this should work: 'xvfb-run -s" + ' "-screen 0 1400x900x24" python \'', ) import math @@ -320,28 +320,6 @@ def make_polyline(v): return PolyLine(v, False) -def make_capsule(length, width): - left, r, t, b = 0, length, width / 2, -width / 2 - box = make_polygon([(left, b), (left, t), (r, t), (r, b)]) - circ0 = make_circle(width / 2) - circ1 = make_circle(width / 2) - circ1.add_attr(Transform(translation=(length, 0))) - geom = Compound([box, circ0, circ1]) - return geom - - -class Compound(Geom): - def __init__(self, gs): - Geom.__init__(self) - self.gs = gs - for g in self.gs: - g.attrs = [a for a in g.attrs if not isinstance(a, Color)] - - def render1(self): - for g in self.gs: - g.render() - - class PolyLine(Geom): def __init__(self, v, close): Geom.__init__(self) @@ -373,59 +351,3 @@ def render1(self): glVertex2f(*self.start) glVertex2f(*self.end) glEnd() - - -class Image(Geom): - def __init__(self, fname, width, height): - Geom.__init__(self) - self.width = width - self.height = height - img = pyglet.image.load(fname) - self.img = img - self.flip = False - - def render1(self): - self.img.blit( - -self.width / 2, -self.height / 2, width=self.width, height=self.height - ) - - -# ================================================================ - - -class SimpleImageViewer(object): - def __init__(self, display=None): - self.window = None - self.isopen = False - self.display = display - - def imshow(self, arr): - if self.window is None: - height, width, channels = arr.shape - self.window = pyglet.window.Window( - width=width, height=height, display=self.display - ) - self.width = width - self.height = height - self.isopen = True - assert arr.shape == ( - self.height, - self.width, - 3, - ), "You passed in an image with the wrong number shape" - image = pyglet.image.ImageData( - self.width, self.height, "RGB", arr.tobytes(), pitch=self.width * -3 - ) - self.window.clear() - self.window.switch_to() - self.window.dispatch_events() - image.blit(0, 0) - self.window.flip() - - def close(self): - if self.isopen: - self.window.close() - self.isopen = False - - def __del__(self): - self.close() diff --git a/openrl/envs/nlp/daily_dialog_env.py b/openrl/envs/nlp/daily_dialog_env.py index 0c7a6ff7..d197a232 100644 --- a/openrl/envs/nlp/daily_dialog_env.py +++ b/openrl/envs/nlp/daily_dialog_env.py @@ -36,11 +36,24 @@ def __init__( prompt_truncation_side (str): truncation side for prompt text (Defaults to "left") """ - self.debug = cfg.env.args["data_path"] is None + self.debug = ( + cfg.env.args["data_path"] is None or cfg.env.args["data_path"] == "None" + ) self.env_name = "daily_dialog" tokenizer_name = cfg.env.args["tokenizer_path"] - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True) + if tokenizer_name == "builtin_BPE": + from tokenizers import Tokenizer, models + + self.tokenizer = Tokenizer(models.BPE()) + + self.tokenizer.pad_token = "" + self.tokenizer.eos_token = "" + self.tokenizer.vocab_size = 2 + self.tokenizer.name_or_path = "builtin_BPE" + + else: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.padding_side = "left" @@ -100,7 +113,7 @@ def __init__( self.__time_step = None self.reward_function = None - def set_reward(self, reward_fn): + def set_reward(self, reward_fn=None): self.reward_function = reward_fn def step_word(self, word: str) -> Tuple[Dict[str, torch.tensor], int, bool, dict]: diff --git a/openrl/envs/nlp/rewards/intent.py b/openrl/envs/nlp/rewards/intent.py index 397ea810..bc4da36c 100644 --- a/openrl/envs/nlp/rewards/intent.py +++ b/openrl/envs/nlp/rewards/intent.py @@ -9,25 +9,93 @@ from openrl.supports.opengpu.manager import LocalGPUManager +def get_default_ds_config(offload=True, stage=0, fp16=True): + device = "cpu" if offload else "none" + zero_opt_dict = { + "stage": stage, + "offload_param": {"device": device}, + } + return { + "train_batch_size": 16, + "train_micro_batch_size_per_gpu": 16, + "steps_per_print": 10, + "zero_optimization": zero_opt_dict, + "fp16": {"enabled": fp16}, + } + + class Intent: - def __init__(self, intent_model: str, intent_coeff: float = 1.0) -> None: + def __init__( + self, + intent_model: str, + intent_coeff: float = 1.0, + use_deepspeed: bool = True, + ds_config: str = "default", + ) -> None: super().__init__() self._intent_coeff = intent_coeff + self.use_deepspeed = use_deepspeed + self.use_half = False + self.use_data_parallel = not use_deepspeed # default to use data parallel + self.use_model_parallel = False - model_path = data_abs_path(intent_model) - self._tokenizer = AutoTokenizer.from_pretrained(intent_model) - self._model = AutoModelForSequenceClassification.from_pretrained(model_path) + if intent_model == "builtin_intent": + self._device = "cpu" + self.use_data_parallel = False + + from transformers import GPT2Config, GPT2LMHeadModel + + class TestTokenizer: + def __call__( + self, + input_texts, + return_tensors="pt", + truncation=True, + padding=True, + max_length=None, + ): + class EncodedOutput: + def __init__(self, input_ids, attention_mask): + self.input_ids = input_ids + self.attention_mask = attention_mask + + input_ids = torch.zeros((32), dtype=torch.long) + attention_masks = torch.zeros((32), dtype=torch.long) + return EncodedOutput(input_ids, attention_masks) + + self._tokenizer = TestTokenizer() + config = GPT2Config() + self._model = GPT2LMHeadModel(config) - if torch.cuda.is_available(): - manager = LocalGPUManager() - manager.log_info() - self._device = f"cuda:{manager.get_gpu()}" else: - self._device = "cpu" - print("Intent Model choose to use device:{}".format(self._device)) + self._device = "cuda" + model_path = data_abs_path(intent_model) + self._tokenizer = AutoTokenizer.from_pretrained(intent_model) + self._model = AutoModelForSequenceClassification.from_pretrained(model_path) + + if self.use_deepspeed: + import deepspeed - self._model = self._model.to(self._device) + if ds_config == "default": + ds_config = get_default_ds_config() + else: + import json + + with open(ds_config) as file: + ds_config = json.load(file) + + self._model = self._model.to(self._device) + self._model, *_ = deepspeed.initialize(model=self._model, config=ds_config) + self.use_fp16 = ds_config["fp16"]["enabled"] + else: + if self.use_model_parallel: + self._model.parallelize() + elif self.use_data_parallel: + if self.use_half: + self._model = self._model.half() + self._model = torch.nn.DataParallel(self._model) + self._model = self._model.to(self._device) def __call__( self, @@ -58,11 +126,19 @@ def get_input_for_classifier(prompt, generated_text): input_texts, return_tensors="pt", truncation=True, padding=True ) + if self.use_half: + encoded.input_ids = encoded.input_ids.int() + encoded.attention_mask = encoded.attention_mask.int() + else: + encoded.input_ids = encoded.input_ids.long() + encoded.attention_mask = encoded.attention_mask.long() + with torch.no_grad(): outputs = self._model( input_ids=encoded.input_ids.to(self._device), attention_mask=encoded.attention_mask.to(self._device), ) + pred_labels = torch.argmax(outputs.logits, dim=1).tolist() score = (np.array(pred_labels) == np.array(target_intents)) * 1.0 diff --git a/openrl/envs/nlp/rewards/kl_penalty.py b/openrl/envs/nlp/rewards/kl_penalty.py index 039d82d5..c98c6bfb 100644 --- a/openrl/envs/nlp/rewards/kl_penalty.py +++ b/openrl/envs/nlp/rewards/kl_penalty.py @@ -10,24 +10,79 @@ from openrl.envs.nlp.utils.distribution import CategoricalDistribution +def get_default_ds_config(offload=True, stage=0, fp16=True): + device = "cpu" if offload else "none" + zero_opt_dict = { + "stage": stage, + "offload_param": {"device": device}, + } + return { + "train_batch_size": 16, + "train_micro_batch_size_per_gpu": 16, + "steps_per_print": 10, + "zero_optimization": zero_opt_dict, + "fp16": {"enabled": fp16}, + } + + class KLPenalty(nn.Module): def __init__( self, action_space: gym.Space, ref_model: str, apply_model_parallel: bool = True, + use_deepspeed: bool = True, + ds_config: str = "default", ): super().__init__() + self.device = "cuda" + self.use_deepspeed = use_deepspeed + self.use_half = False + self.use_data_parallel = not use_deepspeed + self.use_model_parallel = False + assert not (self.use_deepspeed and self.use_data_parallel) + assert not (self.use_deepspeed and self.use_model_parallel) + assert not (self.use_data_parallel and self.use_model_parallel) + # reference model - self._apply_model_parallel = apply_model_parallel - self._ref_net = AutoModelForCausalLM.from_pretrained(ref_model) + if ref_model == "builtin_ref": + self.device = "cpu" + self.use_data_parallel = False + + from transformers import GPT2Config, GPT2LMHeadModel + + config = GPT2Config() + self._ref_net = GPT2LMHeadModel(config) + else: + self._ref_net = AutoModelForCausalLM.from_pretrained(ref_model) self._ref_net = self._ref_net.eval() - if torch.cuda.is_available(): - if self._apply_model_parallel and self._ref_net.is_parallelizable: + if self.use_deepspeed: + import deepspeed + + if ds_config == "default": + self.use_fp16 = True + ds_config = get_default_ds_config() + else: + import json + + with open(ds_config) as file: + ds_config = json.load(file) + if "fp16" in ds_config: + self.use_fp16 = ds_config["fp16"]["enabled"] + else: + self.use_fp16 = False + + self._ref_engine, *_ = deepspeed.initialize(model=self, config=ds_config) + else: + if self.use_model_parallel: self._ref_net.parallelize() - else: # else defaults to data parallel - self._ref_net = torch.nn.DataParallel(self._ref_net) + elif self.use_data_parallel: # else defaults to data parallel + if self.use_half: + self._ref_net = self._ref_net.half() + else: + self._ref_net = torch.nn.DataParallel(self._ref_net) + self._ref_net = self._ref_net.to(self.device) # alpha adjustment self._alpha = 0.2 @@ -61,20 +116,31 @@ def __call__( past_model_kwargs = { "attention_mask": attention_mask, } - model_inputs = self._prepare_inputs_for_model( self._ref_net, input_ids, past_model_kwargs ) + if self.use_half: + for key in ["input_ids", "position_ids", "attention_mask"]: + if key in model_inputs: + model_inputs[key] = model_inputs[key].int() + else: + for key in ["input_ids", "position_ids", "attention_mask"]: + if key in model_inputs: + model_inputs[key] = model_inputs[key].long() + with torch.no_grad(): output = self._ref_net(output_hidden_states=True, **model_inputs) output["past_key_values"] = None next_token_logits = output.logits[:, -1, :] + if self.use_deepspeed and self.use_fp16: + next_token_logits = next_token_logits.double() dist = self._action_dist.proba_distribution(action_logits=next_token_logits) action_input = actions.to(next_token_logits.device) ref_log_prob = dist.log_prob(action_input) ref_log_prob = ref_log_prob.reshape(action_log_probs.shape) + kl_div = action_log_probs.copy() - ref_log_prob.detach().cpu().numpy() rew = -self._alpha * kl_div infos = [] @@ -97,7 +163,7 @@ def _prepare_inputs_for_model( input_ids, **model_kwargs ) - if self._apply_model_parallel and unwrap_model(model).is_parallelizable: + if self.use_model_parallel: # if model is in parallel mode, move the tensors to the first device model_inputs = { key: ( @@ -108,4 +174,15 @@ def _prepare_inputs_for_model( ) for key, value in model_inputs.items() } + elif self.use_data_parallel: + model_inputs = { + key: value.to(self.device) if isinstance(value, torch.Tensor) else value + for key, value in model_inputs.items() + } + elif self.use_deepspeed: + model_inputs = { + key: value.to("cuda") if isinstance(value, torch.Tensor) else value + for key, value in model_inputs.items() + } + return model_inputs diff --git a/openrl/envs/nlp/rewards/meteor.py b/openrl/envs/nlp/rewards/meteor.py index c9acd16f..5bd169ad 100644 --- a/openrl/envs/nlp/rewards/meteor.py +++ b/openrl/envs/nlp/rewards/meteor.py @@ -6,13 +6,21 @@ import openrl.envs.nlp as nlp +class VirtualMetric: + def compute(self, predictions: Any, references: Any) -> Dict[str, float]: + return {"meteor": 0.0} + + class Meteor: - def __init__(self, meteor_coeff: int) -> None: + def __init__(self, meteor_coeff: int, test: bool = False) -> None: super().__init__() self._meteor_coeff = meteor_coeff - self._metric = evaluate.load( - str(Path(nlp.__file__).parent / "utils/metrics/meteor.py") - ) + if test: + self._metric = VirtualMetric() + else: + self._metric = evaluate.load( + str(Path(nlp.__file__).parent / "utils/metrics/meteor.py") + ) def __call__( self, diff --git a/openrl/envs/snake/common.py b/openrl/envs/snake/common.py deleted file mode 100644 index 6a67a0a3..00000000 --- a/openrl/envs/snake/common.py +++ /dev/null @@ -1,227 +0,0 @@ -import os -import sys - -import numpy as np - - -class HiddenPrints: - def __enter__(self): - self._original_stdout = sys.stdout - sys.stdout = open(os.devnull, "w") - - def __exit__(self, exc_type, exc_val, exc_tb): - sys.stdout.close() - sys.stdout = self._original_stdout - - -class Board: - def __init__(self, board_height, board_width, snakes, beans_positions, teams): - # print('create board, beans_position: ', beans_positions) - self.height = board_height - self.width = board_width - self.snakes = snakes - self.snakes_count = len(snakes) - self.beans_positions = beans_positions - self.blank_sign = -self.snakes_count - self.bean_sign = -self.snakes_count + 1 - self.board = np.zeros((board_height, board_width), dtype=int) + self.blank_sign - self.open = dict() - for key, snake in self.snakes.items(): - self.open[key] = [snake.head] # state 0 open list, heads, ready to spread - # see [A* Pathfinding (E01: algorithm explanation)](https://www.youtube.com/watch?v=-L-WgKMFuhE) - for x, y in snake.pos: - self.board[x][y] = key # obstacles, e.g. 0, 1, 2, 3, 4, 5 - # for x, y in beans_positions: - # self.board[x][y] = self.bean_sign # beans - - self.state = 0 - self.controversy = dict() - self.teams = teams - - # print('initial board') - # print(self.board) - - def step(self): # delay: prevent rear-end collision - new_open = {key: [] for key in self.snakes.keys()} - self.state += 1 # update state - # if self.state > delay: - # for key, snake in self.snakes.items(): # drop tail - # if snake.len >= self.state: - # self.board[snake.pos[-(self.state - delay)][0]][snake.pos[-(self.state - delay)][1]] \ - # = self.blank_sign - for key, snake in self.snakes.items(): - if snake.len >= self.state: - self.board[snake.pos[-self.state][0]][ - snake.pos[-self.state][1] - ] = self.blank_sign # drop tail - for key, value in self.open.items(): # value: e.g. [[8, 3], [6, 3], [7, 4]] - others_tail_pos = [ - ( - self.snakes[_].pos[-self.state] - if self.snakes[_].len >= self.state - else [] - ) - for _ in set(range(self.snakes_count)) - {key} - ] - for x, y in value: - # print('start to spread snake {} on grid ({}, {})'.format(key, x, y)) - for x_, y_ in [ - ((x + 1) % self.height, y), # down - ((x - 1) % self.height, y), # up - (x, (y + 1) % self.width), # right - (x, (y - 1) % self.width), - ]: # left - sign = self.board[x_][y_] - idx = ( - sign % self.snakes_count - ) # which snake, e.g. 0, 1, 2, 3, 4, 5 / number of claims - state = ( - sign // self.snakes_count - ) # manhattan distance to snake who claim the point or its negative - if sign == self.blank_sign: # grid in initial state - if [x_, y_] in others_tail_pos: - # print('do not spread other snakes tail, in case of rear-end collision') - continue # do not spread other snakes' tail, in case of rear-end collision - self.board[x_][y_] = self.state * self.snakes_count + key - self.snakes[key].claimed_count += 1 - new_open[key].append([x_, y_]) - - elif key != idx and self.state == state: - # second claim, init controversy, change grid value from + to - - # print( - # '\tgird ({}, {}) in the same state claimed by different snakes ' - # 'with sign {}, idx {} and state {}'.format( - # x_, y_, sign, idx, state)) - if ( - self.snakes[idx].len > self.snakes[key].len - ): # shorter snake claim the controversial grid - # print('\t\tsnake {} is shorter than snake {}'.format(key, idx)) - self.snakes[idx].claimed_count -= 1 - new_open[idx].remove([x_, y_]) - self.board[x_][y_] = self.state * self.snakes_count + key - self.snakes[key].claimed_count += 1 - new_open[key].append([x_, y_]) - elif ( - self.snakes[idx].len == self.snakes[key].len - ): # controversial claim - # print( - # '\t\tcontroversy! first claimed by snake {}, then claimed by snake {}'.format(idx, key)) - self.controversy[(x_, y_)] = { - "state": self.state, - "length": self.snakes[idx].len, - "indexes": [idx, key], - } - # first claim by snake idx, then claim by snake key - self.board[x_][y_] = -self.state * self.snakes_count + 1 - # if + 2, not enough for all snakes claim one grid!! - self.snakes[ - idx - ].claimed_count -= ( - 1 # controversy, no snake claim this grid!! - ) - new_open[key].append([x_, y_]) - else: # (self.snakes[idx].len < self.snakes[key].len) - pass # longer snake do not claim the controversial grid - - elif ( - (x_, y_) in self.controversy - and key not in self.controversy[(x_, y_)]["indexes"] - and self.state + state == 0 - ): # third claim or more - # print('snake {} meets third or more claim in grid ({}, {})'.format(key, x_, y_)) - controversy = self.controversy[(x_, y_)] - # pprint.pprint(controversy) - if ( - controversy["length"] > self.snakes[key].len - ): # shortest snake claim grid, do 4 things - # print('\t\tsnake {} is shortest'.format(key)) - indexes_count = len(controversy["indexes"]) - for i in controversy["indexes"]: - self.snakes[i].claimed_count -= ( - 1 / indexes_count - ) # update claimed_count ! - new_open[i].remove([x_, y_]) - del self.controversy[(x_, y_)] - self.board[x_][y_] = self.state * self.snakes_count + key - self.snakes[key].claimed_count += 1 - new_open[key].append([x_, y_]) - elif ( - controversy["length"] == self.snakes[key].len - ): # controversial claim - # print('\t\tcontroversy! multi claimed by snake {}'.format(key)) - self.controversy[(x_, y_)]["indexes"].append(key) - self.board[x_][y_] += 1 - new_open[key].append([x_, y_]) - else: # (controversy['length'] < self.snakes[key].len) - pass # longer snake do not claim the controversial grid - else: - pass # do nothing with lower state grids - - self.open = new_open # update open - # update controversial snakes' claimed_count (in fraction) in the end - for _, d in self.controversy.items(): - controversial_snake_count = len( - d["indexes"] - ) # number of controversial snakes - for idx in d["indexes"]: - self.snakes[idx].claimed_count += 1 / controversial_snake_count - - -class SnakePos: - def __init__(self, snake_positions, board_height, board_width, beans_positions): - self.pos = snake_positions # [[2, 9], [2, 8], [2, 7]] - self.len = len(snake_positions) # >= 3 - self.head = snake_positions[0] - self.beans_positions = beans_positions - self.claimed_count = 0 - - displace = [ - (self.head[0] - snake_positions[1][0]) % board_height, - (self.head[1] - snake_positions[1][1]) % board_width, - ] - # print('creat snake, pos: ', self.pos, 'displace:', displace) - if displace == [ - board_height - 1, - 0, - ]: # all action are ordered by left, up, right, relative to the body - self.dir = 0 # up - self.legal_action = [2, 0, 3] - elif displace == [1, 0]: - self.dir = 1 # down - self.legal_action = [3, 1, 2] - elif displace == [0, board_width - 1]: - self.dir = 2 # left - self.legal_action = [1, 2, 0] - elif displace == [0, 1]: - self.dir = 3 # right - self.legal_action = [0, 3, 1] - else: - assert False, "snake positions error" - positions = [ - [(self.head[0] - 1) % board_height, self.head[1]], - [(self.head[0] + 1) % board_height, self.head[1]], - [self.head[0], (self.head[1] - 1) % board_width], - [self.head[0], (self.head[1] + 1) % board_width], - ] - self.legal_position = [positions[_] for _ in self.legal_action] - - def get_action(self, position): - if position not in self.legal_position: - assert False, "the start and end points do not match" - idx = self.legal_position.index(position) - return self.legal_action[idx] # 0, 1, 2, 3: up, down, left, right - - def step(self, legal_input): - if legal_input in self.legal_position: - position = legal_input - elif legal_input in self.legal_action: - idx = self.legal_action.index(legal_input) - position = self.legal_position[idx] - else: - assert False, "illegal snake move" - self.head = position - self.pos.insert(0, position) - if position in self.beans_positions: # eat a bean - self.len += 1 - else: # do not eat a bean - self.pos.pop() diff --git a/openrl/envs/snake/snake.py b/openrl/envs/snake/snake.py index 73e81229..4a5be6a5 100644 --- a/openrl/envs/snake/snake.py +++ b/openrl/envs/snake/snake.py @@ -674,7 +674,9 @@ class Snake: def __init__(self, player_id, board_width, board_height, init_len): self.actions = [-2, 2, -1, 1] self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"} - self.direction = random.choice(self.actions) # 方向[-2,2,-1,1]分别表示[上,下,左,右] + self.direction = random.choice( + self.actions + ) # 方向[-2,2,-1,1]分别表示[上,下,左,右] self.board_width = board_width self.board_height = board_height x = random.randrange(0, board_height) diff --git a/openrl/envs/snake/snake_3v3.py b/openrl/envs/snake/snake_3v3.py deleted file mode 100644 index 78d787ef..00000000 --- a/openrl/envs/snake/snake_3v3.py +++ /dev/null @@ -1,854 +0,0 @@ -# -*- coding:utf-8 -*- -# 作者:zruizhi -# 创建时间: 2020/7/30 17:24 下午 -# 描述: -import copy -import itertools -import random -import time -from itertools import count - -import numpy as np -from gym import Env, spaces -from PIL import Image, ImageDraw, ImageFont - -from .common import Board, HiddenPrints, SnakePos # TODO: Snake类的重名问题 -from .discrete import Discrete -from .gridgame import GridGame -from .observation import * - - -class SnakeEatBeans(GridGame, GridObservation, DictObservation): - def __init__(self, all_args, env_id): - self.all_args = all_args - conf = { - "class_literal": "SnakeEatBeans", - "n_player": 6, - "board_width": 20, - "board_height": 10, - "channels": 15, - "cell_range": 8, - "n_beans": 5, - "max_step": 200, - "game_name": "snakes", - "is_obs_continuous": False, - "is_act_continuous": False, - "agent_nums": [3, 3], - "obs_type": ["dict", "dict"], - "save_interval": 100, - "save_path": "../../replay/snake_3v3/replay_{}.gif", - } - self.terminate_flg = False - colors = conf.get("colors", [(255, 255, 255), (255, 140, 0)]) - super(SnakeEatBeans, self).__init__(conf, colors) - # 0: 没有 1:食物 2-n_player+1:各玩家蛇身 - self.n_cell_type = self.n_player + 2 - self.step_cnt = 1 - self.n_beans = int(conf["n_beans"]) - # 方向[-2,2,-1,1]分别表示[上,下,左,右] - self.actions = [-2, 2, -1, 1] - self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"} - self.snakes_position = {} - self.players = [] - self.cur_bean_num = 0 - self.beans_position = [] - # 1<= init_len <= 3 - self.init_len = 3 - self.current_state = self.init_state() - self.all_observes = self.get_all_observes() - if self.n_player * self.init_len > self.board_height * self.board_width: - raise Exception( - "玩家数量过多:%d,超出board范围:%d,%d" - % (self.n_player, self.board_width, self.board_height) - ) - - self.input_dimension = self.board_width * self.board_height - self.action_dim = self.get_action_dim() - self.channels = conf["channels"] - - self.num_agents = conf["agent_nums"][0] - self.num_enemys = conf["agent_nums"][1] - - self.observation_space = [ - spaces.Box( - low=-np.inf, - high=-np.inf, - shape=(self.channels, self.board_width, self.board_height), - dtype=np.float32, - ) - ] - self.share_observation_space = [] - self.share_observation_space = [ - spaces.Box( - low=-np.inf, - high=+np.inf, - shape=(self.channels, self.board_width, self.board_height), - dtype=np.float32, - ) - ] - self.action_space = [Discrete(4) for _ in range(self.n_player)] - self.save_interval = conf["save_interval"] - self.save_path = conf["save_path"] - self.episode = 0 - self.render = all_args.save_replay - self.img_list = [] - self.env_id = env_id - - def seed(self, seed=None): - if seed is None: - np.random.seed(1) - else: - np.random.seed(seed) - - def check_win(self): - flg = self.won.index(max(self.won)) + 2 - return flg - - def get_grid_observation(self, current_state, player_id, info_before): - return current_state - - def get_dict_observation(self, current_state, player_id, info_before): - key_info = {1: self.beans_position} - for i in range(self.n_player): - snake = self.players[i] - key_info[snake.player_id] = snake.segments - # key_info['state_map'] = current_state - key_info["board_width"] = self.board_width - key_info["board_height"] = self.board_height - key_info["last_direction"] = ( - info_before.get("directions") if isinstance(info_before, dict) else None - ) - key_info["controlled_snake_index"] = player_id - - return key_info - - def set_action_space(self): - action_space = [[Discrete(4)] for _ in range(self.n_player)] - return action_space - - def reset(self): - self.step_cnt = 1 - self.snakes_position = ( - {} - ) # 格式类似于{1: [[3, 1], [4, 3], [1, 2], [0, 6], [3, 3]], 2: [[3, 0], [3, 7], [3, 6]], 3: [[2, 7], [1, 7], [0, 7]]} - self.players = [] - self.cur_bean_num = 0 - self.beans_position = [] - self.current_state = self.init_state() - self.all_observes = self.get_all_observes() - self.terminate_flg = False - self.img_list = [] - self.episode += 1 - - # available actions - left_avail_actions = np.ones([self.num_agents, self.action_dim]) - right_avail_actions = np.ones([self.num_enemys, self.action_dim]) - avail_actions = np.concatenate([left_avail_actions, right_avail_actions], 0) - # process obs - board = [] - for i in range(self.n_player): - board.append([self.get_board(self.all_observes[i])]) - - board_ = np.concatenate(board) - obs = [] - for raw_obs in self.all_observes: - obs.append([self.raw2vec(raw_obs)]) - obs_ = np.concatenate(obs) - obs_ = np.concatenate((obs_, board_), axis=1) - - share_obs = np.repeat(np.expand_dims(obs_[0], axis=0), 6, 0) - - return obs_, share_obs, avail_actions # obs:(n_player, 288) - - # return self.all_observes - - def step(self, joint_action): - info_before = self.step_before_info() - joint_action = np.expand_dims(joint_action, 1) - all_observes, info_after = self.get_next_state(joint_action) - done = self.is_terminal() - reward = self.get_reward(joint_action) - left_avail_actions = np.ones([self.num_agents, self.action_dim]) - right_avail_actions = np.ones([self.num_enemys, self.action_dim]) - avail_actions = np.concatenate([left_avail_actions, right_avail_actions], 0) - - board = [] - for i in range(self.n_player): - board.append([self.get_board(all_observes[i])]) - - board_ = np.concatenate(board) - - obs = [] - - for raw_obs in all_observes: - obs.append([self.raw2vec(raw_obs)]) # obs:[[(14, 20, 10)], [], ..., []] - - obs_ = np.concatenate(obs) # (n_player, channels, width, height) - obs_ = np.concatenate((obs_, board_), axis=1) - - share_obs = np.repeat(np.expand_dims(obs_[0], axis=0), 6, 0) - - if done: - reward = self.get_final_reward(reward) - - rewards = np.expand_dims(np.array(reward), axis=1) - - dones = [done] * self.n_player - infos = [info_after] * self.n_player - - if self.render and self.episode % self.save_interval == 0 and self.env_id == 0: - img = self.render_board() - img_pil = Image.fromarray(img) - self.img_list.append(img_pil) - - if done: - self.img_list[0].save( - self.save_path.format(self.episode), - save_all=True, - append_images=self.img_list[1:], - duration=400, - ) - print("save replay gif to" + self.save_path.format(self.episode)) - - return obs_, share_obs, rewards, dones, infos, avail_actions - # return all_observes, reward, done, info_before, info_after - - # obs: 0 空白 1 豆子 2 我方蛇头 3 我方蛇身 4-5 友方蛇头 6-7 友方蛇身 8-10 敌方蛇头 11-13 敌方蛇身 - def raw2vec(self, raw_obs): - control_index = raw_obs["controlled_snake_index"] - width = raw_obs["board_width"] - height = raw_obs["board_height"] - beans = raw_obs[1] - pos = raw_obs[control_index] - - obs = np.zeros(width * height, dtype=int) - head_h, head_w = pos[0] - obs[head_h * width + head_w] = 2 - - for bean in beans: - h, w = bean - obs[h * width + w] = 1 - - for p in pos[1:]: - h, w = p - obs[h * width + w] = 3 - - if control_index == 2: - h1, w1 = raw_obs[3][0] - h2, w2 = raw_obs[4][0] - obs[h1 * width + w1] = 4 - obs[h2 * width + w2] = 5 - for p in raw_obs[3][1:]: - h, w = p - obs[h * width + w] = 6 - for p in raw_obs[4][1:]: - h, w = p - obs[h * width + w] = 7 - for i in range(self.num_agents + 2, self.n_player + 2): - h, w = raw_obs[i][0] - obs[h * width + w] = i + 3 - for p in raw_obs[i][1:]: - h, w = p - obs[h * width + w] = i + 6 - elif control_index == 3: - h1, w1 = raw_obs[2][0] - h2, w2 = raw_obs[4][0] - obs[h1 * width + w1] = 4 - obs[h2 * width + w2] = 5 - for p in raw_obs[2][1:]: - h, w = p - obs[h * width + w] = 6 - for p in raw_obs[4][1:]: - h, w = p - obs[h * width + w] = 7 - for i in range(self.num_agents + 2, self.n_player + 2): - h, w = raw_obs[i][0] - obs[h * width + w] = i + 3 - for p in raw_obs[i][1:]: - h, w = p - obs[h * width + w] = i + 6 - elif control_index == 4: - h1, w1 = raw_obs[2][0] - h2, w2 = raw_obs[3][0] - obs[h1 * width + w1] = 4 - obs[h2 * width + w2] = 5 - for p in raw_obs[2][1:]: - h, w = p - obs[h * width + w] = 6 - for p in raw_obs[3][1:]: - h, w = p - obs[h * width + w] = 7 - for i in range(self.num_agents + 2, self.n_player + 2): - h, w = raw_obs[i][0] - obs[h * width + w] = i + 3 - for p in raw_obs[i][1:]: - h, w = p - obs[h * width + w] = i + 6 - elif control_index == 5: - h1, w1 = raw_obs[6][0] - h2, w2 = raw_obs[7][0] - obs[h1 * width + w1] = 4 - obs[h2 * width + w2] = 5 - for p in raw_obs[6][1:]: - h, w = p - obs[h * width + w] = 6 - for p in raw_obs[7][1:]: - h, w = p - obs[h * width + w] = 7 - for i in range(2, self.num_agents + 2): - h, w = raw_obs[i][0] - obs[h * width + w] = i + 6 - for p in raw_obs[i][1:]: - h, w = p - obs[h * width + w] = i + 9 - elif control_index == 6: - h1, w1 = raw_obs[5][0] - h2, w2 = raw_obs[7][0] - obs[h1 * width + w1] = 4 - obs[h2 * width + w2] = 5 - for p in raw_obs[5][1:]: - h, w = p - obs[h * width + w] = 6 - for p in raw_obs[7][1:]: - h, w = p - obs[h * width + w] = 7 - for i in range(2, self.num_agents + 2): - h, w = raw_obs[i][0] - obs[h * width + w] = i + 6 - for p in raw_obs[i][1:]: - h, w = p - obs[h * width + w] = i + 9 - else: - h1, w1 = raw_obs[5][0] - h2, w2 = raw_obs[6][0] - obs[h1 * width + w1] = 4 - obs[h2 * width + w2] = 5 - for p in raw_obs[5][1:]: - h, w = p - obs[h * width + w] = 6 - for p in raw_obs[6][1:]: - h, w = p - obs[h * width + w] = 7 - for i in range(2, self.num_agents + 2): - h, w = raw_obs[i][0] - obs[h * width + w] = i + 6 - for p in raw_obs[i][1:]: - h, w = p - obs[h * width + w] = i + 9 - - obs_ = np.zeros(width * height * (self.channels - 1), dtype=int) - for i in range(width * height): - obs_[i * (self.channels - 1) + obs[i]] = ( - 1 # channels的最后一维是territory matrix, 此处不生成, 要减去 - ) - obs_ = obs_.reshape( - height, width, (self.channels - 1) - ) # (height, width, channels-1 ) - obs_ = obs_.transpose((2, 1, 0)) - - return obs_ - - def get_board(self, observation_list): - observation_len = len(observation_list.keys()) - teams = None - teams = [[0, 1, 2], [3, 4, 5]] # 3v3 - teams_count = len(teams) - snakes_count = sum([len(_) for _ in teams]) - - # read observation - obs = observation_list.copy() - board_height = obs["board_height"] # 10 - board_width = obs["board_width"] # 20 - # print("obs['controlled_snake_index'] is ", obs['controlled_snake_index']) - ctrl_agent_index = obs["controlled_snake_index"] - 2 # 0, 1, 2, 3, 4, 5 - # last_directions = obs['last_direction'] # ['up', 'left', 'down', 'left', 'left', 'left'] - beans_positions = obs[1] # e.g.[[7, 15], [4, 14], [5, 12], [4, 12], [5, 7]] - snakes = { - key - 2: SnakePos(obs[key], board_height, board_width, beans_positions) - for key in obs.keys() & {_ + 2 for _ in range(snakes_count)} - } # &: intersection - team_indexes = [_ for _ in teams if ctrl_agent_index in _][0] - - init_board = Board(board_height, board_width, snakes, beans_positions, teams) - bd = copy.deepcopy(init_board) - - with HiddenPrints(): - while not all( - _ == [] for _ in bd.open.values() - ): # loop until all values in open are empty list - bd.step() - - board = np.array(bd.board).transpose() - board = np.expand_dims(board, axis=0) - return board - - def init_state(self): - for i in range(self.n_player): - s = Snake(i + 2, self.board_width, self.board_height, self.init_len) - s_len = 1 - while s_len < self.init_len: - if s_len == 1 and i > 0: - origin_hit = self.is_hit(s.headPos, self.snakes_position) - else: - origin_hit = 0 - cur_head = s.move_and_add(self.snakes_position) - cur_hit = self.is_hit(cur_head, self.snakes_position) or self.is_hit( - cur_head, {i: s.segments[1:]} - ) - if origin_hit or cur_hit: - x = random.randrange(0, self.board_height) - y = random.randrange(0, self.board_width) - s.headPos = [x, y] - s.segments = [s.headPos] - s.direction = random.choice(self.actions) - s_len = 1 - else: - s_len += 1 - self.snakes_position[s.player_id] = s.segments - self.players.append(s) - - self.generate_beans() - self.init_info = { - "snakes_position": [ - list(v) - for k, v in sorted( - self.snakes_position.items(), key=lambda item: item[0] - ) - ], - "beans_position": list(self.beans_position), - } - directs = [] - for i in range(len(self.players)): - s = self.players[i] - directs.append(self.actions_name[s.direction]) - self.init_info["directions"] = directs - - return self.update_state() - - def update_state(self): - next_state = [ - [[0] * self.cell_dim for _ in range(self.board_width)] - for _ in range(self.board_height) - ] - for i in range(self.n_player): - snake = self.players[i] - for pos in snake.segments: - next_state[pos[0]][pos[1]][0] = i + 2 - - for pos in self.beans_position: - next_state[pos[0]][pos[1]][0] = 1 - - return next_state - - def step_before_info(self, info=""): - directs = [] - for i in range(len(self.players)): - s = self.players[i] - directs.append(self.actions_name[s.direction]) - info = {"directions": directs} - - return info - - def is_hit(self, cur_head, snakes_position): - is_hit = False - for k, v in snakes_position.items(): - for pos in v: - if cur_head == pos: - is_hit = True - # print("hit:", cur_head, snakes_position) - break - if is_hit: - break - - return is_hit - - def generate_beans(self): - all_valid_positions = set( - itertools.product(range(0, self.board_height), range(0, self.board_width)) - ) - all_valid_positions = all_valid_positions - set(map(tuple, self.beans_position)) - for positions in self.snakes_position.values(): - all_valid_positions = all_valid_positions - set(map(tuple, positions)) - - left_bean_num = self.n_beans - self.cur_bean_num - all_valid_positions = np.array(list(all_valid_positions)) - left_valid_positions = len(all_valid_positions) - - new_bean_num = ( - left_bean_num - if left_valid_positions > left_bean_num - else left_valid_positions - ) - - if left_valid_positions > 0: - new_bean_positions_idx = np.random.choice( - left_valid_positions, size=new_bean_num, replace=False - ) - new_bean_positions = all_valid_positions[new_bean_positions_idx] - else: - new_bean_positions = [] - - for new_bean_pos in new_bean_positions: - self.beans_position.append(list(new_bean_pos)) - self.cur_bean_num += 1 - - def get_all_observes(self, before_info=""): - self.all_observes = [] - for i in range(self.n_player): - each_obs = self.get_dict_observation(self.current_state, i + 2, before_info) - self.all_observes.append(each_obs) - - return self.all_observes - - def get_next_state(self, all_action): - before_info = self.step_before_info() - not_valid = self.is_not_valid_action(all_action) - if not not_valid: - # 各玩家行动 - # print("current_state", self.current_state) - eat_snakes = [0] * self.n_player - ally_reward = 0 - enemy_reward = 0 - for i in range(self.n_player): # 判断是否吃到了豆子 - snake = self.players[i] - act = self.actions[np.argmax(all_action[i][0])] - # print(snake.player_id, "此轮的动作为:", self.actions_name[act]) - snake.change_direction(act) - snake.move_and_add(self.snakes_position) # 更新snake.segment - if self.be_eaten(snake.headPos): # @yanxue - snake.snake_reward = 1 - eat_snakes[i] = 1 - else: - snake.snake_reward = 0 - snake.pop() - # print(snake.player_id, snake.segments) # @yanxue - snake_position = [[-1] * self.board_width for _ in range(self.board_height)] - re_generatelist = [0] * self.n_player - for i in range(self.n_player): # 判断是否相撞 - snake = self.players[i] - segment = snake.segments - for j in range(len(segment)): - x = segment[j][0] - y = segment[j][1] - if snake_position[x][y] != -1: - if j == 0: # 撞头 - re_generatelist[i] = 1 - compare_snake = self.players[snake_position[x][y]] - if [x, y] == compare_snake.segments[0]: # 两头相撞won - re_generatelist[snake_position[x][y]] = 1 - else: - snake_position[x][y] = i - for i in range(self.n_player): - snake = self.players[i] - if re_generatelist[i] == 1: - if eat_snakes[i] == 1: - snake.snake_reward = ( - self.init_len - len(snake.segments) + 1 - ) # 身体越长,惩罚越大 - else: - snake.snake_reward = self.init_len - len(snake.segments) - snake.segments = [] - - for i in range(self.num_agents): - ally_reward += self.players[i].snake_reward - for i in range(self.num_enemys): - enemy_reward += self.players[i + self.num_agents].snake_reward - alpha = 0.8 - for i in range(self.num_agents): - self.players[i].snake_reward = ( - self.players[i].snake_reward - enemy_reward / 3 - ) * alpha + ally_reward / 3 * (1 - alpha) - for i in range(self.num_agents, self.n_player): - self.players[i].snake_reward = ( - self.players[i].snake_reward - ally_reward / 3 - ) * alpha + enemy_reward / 3 * (1 - alpha) - - for i in range(self.n_player): - snake = self.players[i] - if re_generatelist[i] == 1: - snake = self.clear_or_regenerate(snake) - self.snakes_position[snake.player_id] = snake.segments - snake.score = snake.get_score() - # yanxue add - # 更新状态 - self.generate_beans() - - next_state = self.update_state() - self.current_state = next_state - self.step_cnt += 1 - - self.won = [0] * self.n_player - - for i in range(self.n_player): - s = self.players[i] - self.won[i] = s.score - info_after = {} - info_after["snakes_position"] = [ - list(v) - for k, v in sorted( - self.snakes_position.items(), key=lambda item: item[0] - ) - ] - info_after["beans_position"] = list(self.beans_position) - info_after["hit"] = re_generatelist - info_after["score"] = self.won - self.all_observes = self.get_all_observes(before_info) - - return self.all_observes, info_after - - def clear_or_regenerate(self, snake): - direct_x = [0, 1, -1, 0] - direct_y = [1, 0, 0, -1] - snake.segments = [] - snake.score = 0 - grid = self.get_render_data(self.update_state()) - - def can_regenerate(): - for x in range(self.board_height): - for y in range(self.board_width): - if grid[x][y] == 0: - q = [] - q.append([x, y]) - seg = [] - while q: - cur = q.pop(0) - if cur not in seg: - seg.append(cur) - for i in range(4): - nx = (direct_x[i] + cur[0]) % self.board_height - ny = (direct_y[i] + cur[1]) % self.board_width - # if nx < 0 or nx >= self.board_height or ny < 0 or ny >= self.board_width: - # continue - if grid[nx][ny] == 0 and [nx, ny] not in q: - grid[nx][ny] = 1 - q.append([nx, ny]) - if len(seg) == self.init_len: - # print("regenerate") - if len(seg) < 3: - snake.direction = random.choice(self.actions) - elif len(seg) == 3: - mid = ( - [seg[1][0], seg[2][1]], - [seg[2][0], seg[1][1]], - ) - if seg[0] in mid: - seg[0], seg[1] = seg[1], seg[0] - snake.segments = seg - snake.headPos = seg[0] - if seg[0][0] == seg[1][0]: - # 右 - if seg[0][1] > seg[1][1]: - snake.direction = 1 - # 左 - else: - snake.direction = -1 - elif seg[0][1] == seg[1][1]: - # 下 - if seg[0][0] > seg[1][0]: - snake.direction = 2 - # 上 - else: - snake.direction = -2 - # print("re head", snake.headPos) # 输出重新生成的蛇 - # print("re snakes segments", snake.segments) - return True - # print("clear") - return False - - flg = can_regenerate() - if not flg: - self.terminate_flg = True - # print(self.terminate_flg) - return snake - - def is_not_valid_action(self, all_action): - not_valid = 0 - if len(all_action) != self.n_player: - raise Exception("all action 维度不正确!", len(all_action)) - - for i in range(self.n_player): - if len(all_action[i][0]) != 4: - raise Exception("玩家%d joint action维度不正确!" % i, all_action[i]) - return not_valid - - def get_reward(self, all_action): - r = [0] * self.n_player - for i in range(self.n_player): - r[i] = self.players[i].snake_reward - self.n_return[i] += r[i] - # print("score:", self.won) - return r - - def get_final_reward(self, reward): - ally_reward = reward[0] + reward[1] + reward[2] - enemy_reward = reward[3] + reward[4] + reward[5] - if ally_reward > enemy_reward: - reward[0] += 10 - reward[1] += 10 - reward[2] += 10 - reward[3] -= 10 - reward[4] -= 10 - reward[5] -= 10 - elif ally_reward < enemy_reward: - reward[3] += 10 - reward[4] += 10 - reward[5] += 10 - reward[0] -= 10 - reward[1] -= 10 - reward[2] -= 10 - return reward - - def is_terminal(self): - all_member = self.n_beans - # all_member = len(self.beans_position) - for s in self.players: - all_member += len(s.segments) - is_done = ( - self.step_cnt > self.max_step - or all_member > self.board_height * self.board_width - ) - - return is_done or self.terminate_flg - - def encode(self, actions): - joint_action = self.init_action_space() - if len(actions) != self.n_player: - raise Exception("action输入维度不正确!", len(actions)) - for i in range(self.n_player): - joint_action[i][0][int(actions[i])] = 1 - return joint_action - - def get_terminal_actions(self): - print("请输入%d个玩家的动作方向[0-3](上下左右),空格隔开:" % self.n_player) - cur = input() - actions = cur.split(" ") - return self.encode(actions) - - def be_eaten(self, snake_pos): - for bean in self.beans_position: - if snake_pos[0] == bean[0] and snake_pos[1] == bean[1]: - self.beans_position.remove(bean) - self.cur_bean_num -= 1 - return True - return False - - def get_action_dim(self): - action_dim = 1 - for i in range(len(self.joint_action_space[0])): - action_dim *= self.joint_action_space[0][i].n - - return action_dim - - def draw_board(self): - cols = [chr(i) for i in range(65, 65 + self.board_width)] - s = ", ".join(cols) - print(" ", s) - for i in range(self.board_height): - # print(i) - print(chr(i + 65), self.current_state[i]) - - @staticmethod - def _render_board(state, board, colors, unit, fix, extra_info): - im = GridGame._render_board(state, board, colors, unit, fix) - draw = ImageDraw.Draw(im) - # fnt = ImageFont.truetype("Courier.dfont", 16) - fnt = ImageFont.load_default() - for i, pos in zip(count(1), extra_info): - x, y = pos - draw.text( - ((y + 1 / 4) * unit, (x + 1 / 4) * unit), - "#{}".format(i), - font=fnt, - fill=(0, 0, 0), - ) - - return im - - def render_board(self): - extra_info = [tuple(x.headPos) for x in self.players] - im_data = np.array( - SnakeEatBeans._render_board( - self.get_render_data(self.current_state), - self.grid, - self.colors, - self.grid_unit, - self.grid_unit_fix, - extra_info, - ) - ) - return im_data - - @staticmethod - def parse_extra_info(data): - # return eval(re.search(r'({.*})', data['info_after']).group(1)).values() - # d = (eval(eval(data)['snakes_position']).values()) - if isinstance(data, str): - d = eval(data)["snakes_position"] - else: - d = data["snakes_position"] - - return [i[0] for i in d] - - -class Snake: - def __init__(self, player_id, board_width, board_height, init_len): - self.actions = [-2, 2, -1, 1] - self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"} - self.direction = random.choice(self.actions) # 方向[-2,2,-1,1]分别表示[上,下,左,右] - self.board_width = board_width - self.board_height = board_height - x = random.randrange(0, board_height) - y = random.randrange(0, board_width) - self.segments = [[x, y]] - self.headPos = self.segments[0] - self.player_id = player_id - self.score = 0 - self.snake_reward = 0 - self.init_len = init_len - - def get_score(self): - return len(self.segments) - self.init_len - - def change_direction(self, act): - if act + self.direction != 0: - self.direction = act - else: - n_direct = random.choice(self.actions) - while n_direct + self.direction == 0: - n_direct = random.choice(self.actions) - self.direction = n_direct - # print("方向不合法,重新生成") - # print("direction", self.actions_name[self.direction]) - - # 超过边界,可以穿越 - def update_position(self, position): - position[0] %= self.board_height - position[1] %= self.board_width - return position - - def move_and_add(self, snakes_position): - cur_head = list(self.headPos) - # 根据方向移动蛇头的坐标 - # 右 - if self.direction == 1: - cur_head[1] += 1 - # 左 - if self.direction == -1: - cur_head[1] -= 1 - # 上 - if self.direction == -2: - cur_head[0] -= 1 - # 下 - if self.direction == 2: - cur_head[0] += 1 - - cur_head = self.update_position(cur_head) - # print("cur head", cur_head) - # print("cur snakes positions", snakes_position) - - self.segments.insert(0, cur_head) - self.headPos = self.segments[0] - return cur_head - - def pop(self): - self.segments.pop() # 在蛇尾减去一格 diff --git a/openrl/envs/toy_envs/__init__.py b/openrl/envs/toy_envs/__init__.py index 4e6588ef..cf785cc5 100644 --- a/openrl/envs/toy_envs/__init__.py +++ b/openrl/envs/toy_envs/__init__.py @@ -18,25 +18,12 @@ from typing import Any from openrl.envs.toy_envs.bit_flipping_env import BitFlippingEnv -from openrl.envs.toy_envs.identity_env import ( - FakeImageEnv, - IdentityEnv, - IdentityEnvBox, - IdentityEnvcontinuous, - IdentityEnvMultiBinary, - IdentityEnvMultiDiscrete, -) -from openrl.envs.toy_envs.multi_input_envs import SimpleMultiObsEnv +from openrl.envs.toy_envs.identity_env import IdentityEnv, IdentityEnvcontinuous __all__ = [ "BitFlippingEnv", - "FakeImageEnv", "IdentityEnv", "IdentityEnvcontinuous", - "IdentityEnvBox", - "IdentityEnvMultiBinary", - "IdentityEnvMultiDiscrete", - "SimpleMultiObsEnv", ] @@ -49,13 +36,8 @@ env_dict = { "BitFlippingEnv": BitFlippingEnv, - "FakeImageEnv": FakeImageEnv, "IdentityEnv": IdentityEnv, "IdentityEnvcontinuous": IdentityEnvcontinuous, - "IdentityEnvBox": IdentityEnvBox, - "IdentityEnvMultiBinary": IdentityEnvMultiBinary, - "IdentityEnvMultiDiscrete": IdentityEnvMultiDiscrete, - "SimpleMultiObsEnv": SimpleMultiObsEnv, } diff --git a/openrl/envs/toy_envs/bit_flipping_env.py b/openrl/envs/toy_envs/bit_flipping_env.py index 0d77ebd4..5534ed37 100644 --- a/openrl/envs/toy_envs/bit_flipping_env.py +++ b/openrl/envs/toy_envs/bit_flipping_env.py @@ -5,8 +5,6 @@ from gymnasium import Env, spaces from gymnasium.envs.registration import EnvSpec -from openrl.utils.type_aliases import GymStepReturn - class BitFlippingEnv(Env): """ @@ -175,7 +173,7 @@ def reset( self.state = self.obs_space.sample() return self._get_obs(), {} - def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + def step(self, action: Union[np.ndarray, int]): if self.continuous: self.state[action > 0] = 1 - self.state[action > 0] else: diff --git a/openrl/envs/toy_envs/identity_env.py b/openrl/envs/toy_envs/identity_env.py index dd653626..c3867756 100644 --- a/openrl/envs/toy_envs/identity_env.py +++ b/openrl/envs/toy_envs/identity_env.py @@ -6,8 +6,6 @@ from gymnasium.envs.registration import EnvSpec from gymnasium.utils import seeding -from openrl.utils.type_aliases import GymStepReturn - T = TypeVar("T", int, np.ndarray) @@ -30,6 +28,7 @@ def __init__( ``dim`` and ``space``. :param ep_length: the length of each episode in time_steps """ + if space is None: if dim is None: dim = 2 @@ -38,12 +37,14 @@ def __init__( assert ( dim is None ), "arguments for both 'dim' and 'space' provided: at most one allowed" + self.dim = dim self.observation_space = spaces.Discrete(1) self.action_space = space self.ep_length = ep_length self.current_step = 0 self.num_resets = -1 # Becomes 0 after __init__ exits. + self.metadata.update({"name": IdentityEnv}) def reset( self, @@ -53,6 +54,8 @@ def reset( ) -> T: if seed is not None: self.seed(seed) + if self._np_random is None: + self.seed(0) self.current_step = 0 self.num_resets += 1 self._choose_next_state() @@ -67,6 +70,7 @@ def step(self, action: T) -> Tuple[T, float, bool, Dict[str, Any]]: def _choose_next_state(self) -> None: # self.state = [self.action_space.sample()] + assert self.dim is not None self.state = [self._np_random.integers(0, self.dim)] def _get_reward(self, action: T) -> float: @@ -153,114 +157,3 @@ def _get_reward(self, action: T) -> float: def render(self, mode: str = "human") -> None: pass - - -# Not Work Yet -class IdentityEnvBox(IdentityEnv[np.ndarray]): - def __init__( - self, - low: float = -1.0, - high: float = 1.0, - eps: float = 0.05, - ep_length: int = 100, - ): - """ - Identity environment for testing purposes - - :param low: the lower bound of the box dim - :param high: the upper bound of the box dim - :param eps: the epsilon bound for correct value - :param ep_length: the length of each episode in timesteps - """ - space = spaces.Box(low=low, high=high, shape=(1,), dtype=np.float32) - super().__init__(ep_length=ep_length, space=space) - self.eps = eps - - def step( - self, action: np.ndarray - ) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]: - reward = self._get_reward(action) - self._choose_next_state() - self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} - - def _get_reward(self, action: np.ndarray) -> float: - return ( - 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0 - ) - - -# Not Work Yet -class IdentityEnvMultiDiscrete(IdentityEnv[np.ndarray]): - def __init__(self, dim: int = 1, ep_length: int = 100) -> None: - """ - Identity environment for testing purposes - - :param dim: the size of the dimensions you want to learn - :param ep_length: the length of each episode in timesteps - """ - space = spaces.MultiDiscrete([dim, dim]) - super().__init__(ep_length=ep_length, space=space) - - -# Not Work Yet -class IdentityEnvMultiBinary(IdentityEnv[np.ndarray]): - def __init__(self, dim: int = 1, ep_length: int = 100) -> None: - """ - Identity environment for testing purposes - - :param dim: the size of the dimensions you want to learn - :param ep_length: the length of each episode in timesteps - """ - space = spaces.MultiBinary(dim) - super().__init__(ep_length=ep_length, space=space) - - -# Not Work Yet -class FakeImageEnv(gym.Env): - """ - Fake image environment for testing purposes, it mimics Atari games. - - :param action_dim: Number of discrete actions - :param screen_height: Height of the image - :param screen_width: Width of the image - :param n_channels: Number of color channels - :param discrete: Create discrete action space instead of continuous - :param channel_first: Put channels on first axis instead of last - """ - - def __init__( - self, - action_dim: int = 6, - screen_height: int = 84, - screen_width: int = 84, - n_channels: int = 1, - discrete: bool = True, - channel_first: bool = False, - ) -> None: - self.observation_shape = (screen_height, screen_width, n_channels) - if channel_first: - self.observation_shape = (n_channels, screen_height, screen_width) - self.observation_space = spaces.Box( - low=0, high=255, shape=self.observation_shape, dtype=np.uint8 - ) - if discrete: - self.action_space = spaces.Discrete(action_dim) - else: - self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32) - self.ep_length = 10 - self.current_step = 0 - - def reset(self) -> np.ndarray: - self.current_step = 0 - return self.observation_space.sample() - - def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: - reward = 0.0 - self.current_step += 1 - done = self.current_step >= self.ep_length - return self.observation_space.sample(), reward, done, {} - - def render(self, mode: str = "human") -> None: - pass diff --git a/openrl/envs/toy_envs/multi_input_envs.py b/openrl/envs/toy_envs/multi_input_envs.py deleted file mode 100644 index 952a5b04..00000000 --- a/openrl/envs/toy_envs/multi_input_envs.py +++ /dev/null @@ -1,187 +0,0 @@ -from typing import Dict, Union - -import gymnasium as gym -import numpy as np -from gymnasium import spaces - -from openrl.utils.type_aliases import GymStepReturn - - -# Not Work Yet -class SimpleMultiObsEnv(gym.Env): - """ - Base class for GridWorld-based MultiObs Environments 4x4 grid world. - - .. code-block:: text - - ____________ - | 0 1 2 3| - | 4|¯5¯¯6¯| 7| - | 8|_9_10_|11| - |12 13 14 15| - ¯¯¯¯¯¯¯¯¯¯¯¯¯¯ - - start is 0 - states 5, 6, 9, and 10 are blocked - goal is 15 - actions are = [left, down, right, up] - - simple linear state env of 15 states but encoded with a vector and an image observation: - each column is represented by a random vector and each row is - represented by a random image, both sampled once at creation time. - - :param num_col: Number of columns in the grid - :param num_row: Number of rows in the grid - :param random_start: If true, agent starts in random position - :param channel_last: If true, the image will be channel last, else it will be channel first - """ - - def __init__( - self, - num_col: int = 4, - num_row: int = 4, - random_start: bool = True, - discrete_actions: bool = True, - channel_last: bool = True, - ): - super().__init__() - - self.vector_size = 5 - if channel_last: - self.img_size = [64, 64, 1] - else: - self.img_size = [1, 64, 64] - - self.random_start = random_start - self.discrete_actions = discrete_actions - if discrete_actions: - self.action_space = spaces.Discrete(4) - else: - self.action_space = spaces.Box(0, 1, (4,)) - - self.observation_space = spaces.Dict( - spaces={ - "vec": spaces.Box(0, 1, (self.vector_size,), dtype=np.float64), - "img": spaces.Box(0, 255, self.img_size, dtype=np.uint8), - } - ) - self.count = 0 - # Timeout - self.max_count = 100 - self.log = "" - self.state = 0 - self.action2str = ["left", "down", "right", "up"] - self.init_possible_transitions() - - self.num_col = num_col - self.state_mapping = [] - self.init_state_mapping(num_col, num_row) - - self.max_state = len(self.state_mapping) - 1 - - def init_state_mapping(self, num_col: int, num_row: int) -> None: - """ - Initializes the state_mapping array which holds the observation values for each state - - :param num_col: Number of columns. - :param num_row: Number of rows. - """ - # Each column is represented by a random vector - col_vecs = np.random.random((num_col, self.vector_size)) - # Each row is represented by a random image - row_imgs = np.random.randint(0, 255, (num_row, 64, 64), dtype=np.uint8) - - for i in range(num_col): - for j in range(num_row): - self.state_mapping.append( - {"vec": col_vecs[i], "img": row_imgs[j].reshape(self.img_size)} - ) - - def get_state_mapping(self) -> Dict[str, np.ndarray]: - """ - Uses the state to get the observation mapping. - - :return: observation dict {'vec': ..., 'img': ...} - """ - return self.state_mapping[self.state] - - def init_possible_transitions(self) -> None: - """ - Initializes the transitions of the environment - The environment exploits the cardinal directions of the grid by noting that - they correspond to simple addition and subtraction from the cell id within the grid - - - up => means moving up a row => means subtracting the length of a column - - down => means moving down a row => means adding the length of a column - - left => means moving left by one => means subtracting 1 - - right => means moving right by one => means adding 1 - - Thus one only needs to specify in which states each action is possible - in order to define the transitions of the environment - """ - self.left_possible = [1, 2, 3, 13, 14, 15] - self.down_possible = [0, 4, 8, 3, 7, 11] - self.right_possible = [0, 1, 2, 12, 13, 14] - self.up_possible = [4, 8, 12, 7, 11, 15] - - def step(self, action: Union[float, np.ndarray]) -> GymStepReturn: - """ - Run one timestep of the environment's dynamics. When end of - episode is reached, you are responsible for calling `reset()` - to reset this environment's state. - Accepts an action and returns a tuple (observation, reward, done, info). - - :param action: - :return: tuple (observation, reward, done, info). - """ - if not self.discrete_actions: - action = np.argmax(action) - else: - action = int(action) - - self.count += 1 - - prev_state = self.state - - reward = -0.1 - # define state transition - if self.state in self.left_possible and action == 0: # left - self.state -= 1 - elif self.state in self.down_possible and action == 1: # down - self.state += self.num_col - elif self.state in self.right_possible and action == 2: # right - self.state += 1 - elif self.state in self.up_possible and action == 3: # up - self.state -= self.num_col - - got_to_end = self.state == self.max_state - reward = 1 if got_to_end else reward - done = self.count > self.max_count or got_to_end - - self.log = ( - f"Went {self.action2str[action]} in state {prev_state}, got to state" - f" {self.state}" - ) - - return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end} - - def render(self, mode: str = "human") -> None: - """ - Prints the log of the environment. - - :param mode: - """ - print(self.log) - - def reset(self) -> Dict[str, np.ndarray]: - """ - Resets the environment state and step count and returns reset observation. - - :return: observation dict {'vec': ..., 'img': ...} - """ - self.count = 0 - if not self.random_start: - self.state = 0 - else: - self.state = np.random.randint(0, self.max_state) - return self.state_mapping[self.state] diff --git a/openrl/envs/vec_env/async_venv.py b/openrl/envs/vec_env/async_venv.py index 02d6fec2..141532ba 100644 --- a/openrl/envs/vec_env/async_venv.py +++ b/openrl/envs/vec_env/async_venv.py @@ -1,4 +1,5 @@ """An async vector environment.""" + import multiprocessing as mp import sys import time @@ -233,10 +234,8 @@ def reset_send( if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - ( - "Calling `reset_send` while waiting for a pending call to" - f" `{self._state.value}` to complete" - ), + "Calling `reset_send` while waiting for a pending call to" + f" `{self._state.value}` to complete", self._state.value, ) @@ -328,10 +327,8 @@ def step_send(self, actions: np.ndarray): self._assert_is_running() if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - ( - "Calling `step_send` while waiting for a pending call to" - f" `{self._state.value}` to complete." - ), + "Calling `step_send` while waiting for a pending call to" + f" `{self._state.value}` to complete.", self._state.value, ) @@ -341,9 +338,7 @@ def step_send(self, actions: np.ndarray): pipe.send(("step", action)) self._state = AsyncState.WAITING_STEP - def step_fetch( - self, timeout: Optional[Union[int, float]] = None - ) -> Union[ + def step_fetch(self, timeout: Optional[Union[int, float]] = None) -> Union[ Tuple[Any, NDArray[Any], NDArray[Any], List[Dict[str, Any]]], Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], List[Dict[str, Any]]], ]: @@ -575,10 +570,8 @@ def call_send(self, name: str, *args, **kwargs): self._assert_is_running() if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - ( - "Calling `call_send` while waiting " - f"for a pending call to `{self._state.value}` to complete." - ), + "Calling `call_send` while waiting " + f"for a pending call to `{self._state.value}` to complete.", str(self._state.value), ) @@ -635,10 +628,8 @@ def exec_func_send(self, func: Callable, indices, *args, **kwargs): self._assert_is_running() if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - ( - "Calling `exec_func_send` while waiting " - f"for a pending call to `{self._state.value}` to complete." - ), + "Calling `exec_func_send` while waiting " + f"for a pending call to `{self._state.value}` to complete.", str(self._state.value), ) @@ -674,6 +665,7 @@ def exec_func_fetch(self, timeout: Union[int, float, None] = None) -> list: ) results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) + self._raise_if_errors(successes) self._state = AsyncState.DEFAULT @@ -715,10 +707,8 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]): if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - ( - "Calling `set_attr` while waiting " - f"for a pending call to `{self._state.value}` to complete." - ), + "Calling `set_attr` while waiting " + f"for a pending call to `{self._state.value}` to complete.", str(self._state.value), ) @@ -732,7 +722,7 @@ def _worker( index: int, env_fn: callable, pipe: Connection, - parent_pipe: Connection, + parent_pipe: Optional[Connection], shared_memory: bool, error_queue: Queue, auto_reset: bool = True, @@ -756,7 +746,8 @@ def prepare_obs(observation): observation = None return observation - parent_pipe.close() + if parent_pipe is not None: + parent_pipe.close() try: while True: command, data = pipe.recv() @@ -837,7 +828,7 @@ def prepare_obs(observation): ) elif command == "_func_exec": function, indices, args, kwargs = data - if index in indices: + if indices is None or index in indices: if callable(function): pipe.send((function(env, *args, **kwargs), True)) else: diff --git a/openrl/envs/vec_env/base_venv.py b/openrl/envs/vec_env/base_venv.py index f2e54744..c10d0d0d 100644 --- a/openrl/envs/vec_env/base_venv.py +++ b/openrl/envs/vec_env/base_venv.py @@ -272,7 +272,7 @@ def exec_func_fetch(self, timeout: Union[int, float, None] = None) -> list: """ def exec_func( - self, func: Callable, indices: List[int], *args, **kwargs + self, func: Callable, indices: Optional[List[int]] = None, *args, **kwargs ) -> List[Any]: """Call a method, or get a property, from each parallel environment. diff --git a/openrl/envs/vec_env/sync_venv.py b/openrl/envs/vec_env/sync_venv.py index a670ec33..1e208e4c 100644 --- a/openrl/envs/vec_env/sync_venv.py +++ b/openrl/envs/vec_env/sync_venv.py @@ -15,6 +15,7 @@ # limitations under the License. """""" +import time from copy import deepcopy from typing import Any, Callable, Iterable, List, Optional, Sequence, Union @@ -202,6 +203,7 @@ def _step(self, actions: ActType): self._truncateds[i], info, ) = returns + need_reset = _need_reset and ( all(self._terminateds[i]) or all(self._truncateds[i]) ) @@ -281,7 +283,9 @@ def env_name(self): else: return self.envs[0].unwrapped.spec.id - def exec_func(self, func: Callable, indices: List[int], *args, **kwargs) -> tuple: + def exec_func( + self, func: Callable, indices: Optional[List[int]] = None, *args, **kwargs + ) -> tuple: """Calls the method with name and applies args and kwargs. Args: @@ -294,7 +298,7 @@ def exec_func(self, func: Callable, indices: List[int], *args, **kwargs) -> tupl """ results = [] for i, env in enumerate(self.envs): - if i in indices: + if indices is None or i in indices: if callable(func): results.append(func(env, *args, **kwargs)) else: diff --git a/openrl/envs/vec_env/wrappers/reward_wrapper.py b/openrl/envs/vec_env/wrappers/reward_wrapper.py index d0a4d630..25cdc424 100644 --- a/openrl/envs/vec_env/wrappers/reward_wrapper.py +++ b/openrl/envs/vec_env/wrappers/reward_wrapper.py @@ -29,8 +29,8 @@ class RewardWrapper(VecEnvWrapper): def __init__(self, env: BaseVecEnv, reward_class: BaseReward): super().__init__(env) self.reward_class = reward_class - if len(self.reward_class.inner_rew_funcs) > 0: - env.call("set_reward", **{"reward_fn": self.reward_class.inner_rew_funcs}) + # if len(self.reward_class.inner_rew_funcs) > 0: + # env.call("set_reward", **{"reward_fn": self.reward_class.inner_rew_funcs}) def step( self, action: ActType, extra_data: Optional[Dict[str, Any]] diff --git a/openrl/envs/wrappers/extra_wrappers.py b/openrl/envs/wrappers/extra_wrappers.py index da819a87..27359d9e 100644 --- a/openrl/envs/wrappers/extra_wrappers.py +++ b/openrl/envs/wrappers/extra_wrappers.py @@ -21,6 +21,9 @@ import gymnasium as gym import numpy as np from gymnasium import spaces +from gymnasium.utils.step_api_compatibility import ( + convert_to_terminated_truncated_step_api, +) from gymnasium.wrappers import AutoResetWrapper, StepAPICompatibility from openrl.envs.wrappers import BaseObservationWrapper, BaseRewardWrapper, BaseWrapper @@ -46,6 +49,76 @@ def step(self, action): return obs, total_reward, term, trunc, info +def convert_to_done_step_api( + step_returns, + is_vector_env: bool = False, +): + if len(step_returns) == 4: + return step_returns + else: + assert len(step_returns) == 5 + observations, rewards, terminated, truncated, infos = step_returns + + # Cases to handle - info single env / info vector env (list) / info vector env (dict) + # if truncated[0]: + # import pdb; + # pdb.set_trace() + + if is_vector_env is False: + if isinstance(terminated, list): + infos["TimeLimit.truncated"] = truncated[0] and not terminated[0] + done_return = np.logical_or(terminated, truncated) + else: + if truncated or terminated: + infos["TimeLimit.truncated"] = truncated and not terminated + done_return = terminated or truncated + return ( + observations, + rewards, + done_return, + infos, + ) + elif isinstance(infos, list): + for info, env_truncated, env_terminated in zip( + infos, truncated, terminated + ): + if env_truncated or env_terminated: + info["TimeLimit.truncated"] = env_truncated and not env_terminated + return ( + observations, + rewards, + np.logical_or(terminated, truncated), + infos, + ) + elif isinstance(infos, dict): + if np.logical_or(np.any(truncated), np.any(terminated)): + infos["TimeLimit.truncated"] = np.logical_and( + truncated, np.logical_not(terminated) + ) + return ( + observations, + rewards, + np.logical_or(terminated, truncated), + infos, + ) + else: + raise TypeError( + "Unexpected value of infos, as is_vector_envs=False, expects `info` to" + f" be a list or dict, actual type: {type(infos)}" + ) + + +def step_api_compatibility( + step_returns, + output_truncation_bool: bool = True, + is_vector_env: bool = False, +): + if output_truncation_bool: + return convert_to_terminated_truncated_step_api(step_returns, is_vector_env) + else: + return convert_to_done_step_api(step_returns, is_vector_env) + + class RemoveTruncated(StepAPICompatibility, BaseWrapper): def __init__( self, @@ -54,6 +127,12 @@ def __init__( output_truncation_bool = False super().__init__(env, output_truncation_bool=output_truncation_bool) + def step(self, action): + step_returns = self.env.step(action) + return step_api_compatibility( + step_returns, self.output_truncation_bool, self.is_vector_env + ) + class FlattenObservation(BaseObservationWrapper): def __init__(self, env: gym.Env): diff --git a/openrl/envs/wrappers/pettingzoo_wrappers.py b/openrl/envs/wrappers/pettingzoo_wrappers.py index 226fdb9f..c571ff79 100644 --- a/openrl/envs/wrappers/pettingzoo_wrappers.py +++ b/openrl/envs/wrappers/pettingzoo_wrappers.py @@ -96,8 +96,9 @@ def last(self, observe: bool = True): winners = None losers = None + for agent in self.terminations: - if self.terminations[agent]: + if self.terminations[agent] or all(self.truncations): if winners is None: winners = self.get_winners() losers = [player for player in self.agents if player not in winners] diff --git a/openrl/envs/wrappers/util.py b/openrl/envs/wrappers/util.py index a0a97576..614a5879 100644 --- a/openrl/envs/wrappers/util.py +++ b/openrl/envs/wrappers/util.py @@ -41,7 +41,9 @@ def nest_expand_dim(input: Any) -> Any: elif input is None: return [input] else: - raise NotImplementedError("Not support type: {}".format(type(input))) + raise NotImplementedError( + "Not support type: {}, value={}".format(type(input), input) + ) def unwrap_wrapper( diff --git a/openrl/modules/common/ppo_net.py b/openrl/modules/common/ppo_net.py index 93dbaa64..7c537c91 100644 --- a/openrl/modules/common/ppo_net.py +++ b/openrl/modules/common/ppo_net.py @@ -15,7 +15,7 @@ # limitations under the License. """""" - +import copy from typing import Any, Dict, Optional, Tuple, Union import gymnasium as gym @@ -30,6 +30,23 @@ from openrl.utils.util import set_seed +def reset_rnn_states( + rnn_states, episode_starts, env_num, agent_num, rnn_layers, hidden_size +): + # First we reshape the episode_starts to match the rnn_states shape + # Since episode_starts affects all agents in the environment, we repeat it agent_num times + episode_starts = np.repeat(copy.copy(episode_starts), agent_num) + # We then need to expand the dimensions of episode_starts to match rnn_states + # The new shape of episode_starts should be (env_num * agent_num, 1, 1) to broadcast correctly + episode_starts = episode_starts[:, None, None] + # Now, episode_starts should broadcast over the last two dimensions of rnn_states when multiplied + # We want to set rnn_states to zero where episode_starts is 1, so we invert the episode_starts as a mask + mask = 1 - episode_starts + # Apply the mask to rnn_states, setting the appropriate states to zero + rnn_states *= mask + return rnn_states + + class PPONet(BaseNet): def __init__( self, @@ -89,7 +106,18 @@ def act( observation: Union[np.ndarray, Dict[str, np.ndarray]], action_masks: Optional[np.ndarray] = None, deterministic: bool = False, + episode_starts: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + if episode_starts is not None: + self.rnn_states_actor = reset_rnn_states( + self.rnn_states_actor, + episode_starts, + self.env.parallel_env_num, + self.env.agent_num, + self.rnn_states_actor.shape[1], + self.rnn_states_actor.shape[2], + ) + actions, self.rnn_states_actor = self.module.act( obs=observation, rnn_states_actor=self.rnn_states_actor, diff --git a/openrl/modules/networks/policy_network.py b/openrl/modules/networks/policy_network.py index 422eaa58..e3ebb025 100644 --- a/openrl/modules/networks/policy_network.py +++ b/openrl/modules/networks/policy_network.py @@ -56,6 +56,9 @@ def __init__( self.use_half = use_half self.tpdv = dict(dtype=torch.float32, device=device) + self._use_fp16 = cfg.use_fp16 + assert not (cfg.use_fp16 and not cfg.use_deepspeed) + policy_obs_shape = get_policy_obs_space(input_space) if "Dict" in policy_obs_shape.__class__.__name__: @@ -135,8 +138,9 @@ def forward_original( policy_obs[key].half() else: policy_obs = check(policy_obs, self.use_half, self.tpdv) - # if self.use_half: - # obs = obs.half() + if self.use_half or self._use_fp16: + policy_obs = policy_obs.half() + rnn_states = check(rnn_states, self.use_half, self.tpdv) masks = check(masks, self.use_half, self.tpdv) @@ -165,6 +169,8 @@ def eval_actions( obs[key] = check(obs[key], self.use_half, self.tpdv) else: obs = check(obs, self.use_half, self.tpdv) + if self._use_fp16: + obs = obs.half() rnn_states = check(rnn_states, self.use_half, self.tpdv) action = check(action, self.use_half, self.tpdv) @@ -202,6 +208,8 @@ def get_policy_values(self, obs, rnn_states, masks): obs[key] = check(obs[key], self.use_half, self.tpdv) else: obs = check(obs).to(**self.tpdv) + if self.use_half or self._use_fp16: + obs = obs.half() rnn_states = check(rnn_states, self.use_half, self.tpdv) masks = check(masks, self.use_half, self.tpdv) diff --git a/openrl/modules/networks/policy_network_gpt.py b/openrl/modules/networks/policy_network_gpt.py new file mode 100644 index 00000000..193094a7 --- /dev/null +++ b/openrl/modules/networks/policy_network_gpt.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2021 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +from typing import Any, Dict, Optional + +import numpy as np +import torch +import torch.nn as nn +from transformers.modeling_utils import unwrap_model + +from openrl.buffers.utils.util import get_policy_obs, get_policy_obs_space +from openrl.envs.nlp.utils.distribution import CategoricalDistribution +from openrl.modules.networks.base_policy_network import BasePolicyNetwork +from openrl.modules.networks.utils.act import ACTLayer +from openrl.modules.networks.utils.cnn import CNNBase +from openrl.modules.networks.utils.mix import MIXBase +from openrl.modules.networks.utils.mlp import MLPBase, MLPLayer +from openrl.modules.networks.utils.popart import PopArt +from openrl.modules.networks.utils.rnn import RNNLayer +from openrl.modules.networks.utils.util import init +from openrl.utils.util import check_v2 as check + + +class PolicyNetworkGPT(BasePolicyNetwork): + def __init__( + self, + cfg, + input_space, + action_space, + device=torch.device("cpu"), + use_half=False, + disable_drop_out: bool = True, + extra_args=None, + ) -> None: + self.device = device + self.use_fp16 = cfg.use_fp16 + self.use_deepspeed = cfg.use_deepspeed + self.use_half = False + self.use_data_parallel = not cfg.use_deepspeed # default to use data parallel + self.use_model_parallel = False + + assert not (self.use_deepspeed and self.use_data_parallel) + assert not (self.use_deepspeed and self.use_model_parallel) + assert not (self.use_data_parallel and self.use_model_parallel) + + super(PolicyNetworkGPT, self).__init__(cfg, device) + + self.disable_drop_out = disable_drop_out + + self._action_dist = CategoricalDistribution(action_space.n) + + from transformers import AutoConfig, AutoModelForCausalLM + + config = AutoConfig.from_pretrained(cfg.model_path) + config_dict = config.to_dict() + for key in config_dict: + if "drop" in key: + config_dict[key] = 0.0 + config = config.from_dict(config_dict) + self._policy_model = AutoModelForCausalLM.from_pretrained( + cfg.model_path, config=config + ) + self._policy_model.config.use_cache = False + + if torch.cuda.is_available(): + if self.use_model_parallel: + self._policy_model.parallelize() + elif self.use_data_parallel: + if self.use_half: + self._policy_model = self._policy_model.half() + self._policy_model = torch.nn.DataParallel(self._policy_model) + self._policy_model = self._policy_model.to(self.device) + + def forward(self, forward_type, *args, **kwargs): + if forward_type == "original": + return self.forward_original(*args, **kwargs) + elif forward_type == "eval_actions": + return self.eval_actions(*args, **kwargs) + else: + raise NotImplementedError + + def _prepare_inputs_for_model( + self, + model: Any, + input_ids: torch.tensor, + model_kwargs: Optional[Dict[str, torch.tensor]] = None, + ): + model_inputs = unwrap_model(model).prepare_inputs_for_generation( + input_ids, **model_kwargs + ) + + if self.use_model_parallel: + model_inputs = { + key: ( + value.to(model.transformer.first_device) + if isinstance(value, torch.Tensor) + and hasattr(model.transformer, "first_device") + else value + ) + for key, value in model_inputs.items() + } + + return model_inputs + + def forward_original( + self, raw_obs, rnn_states, masks, action_masks=None, deterministic=False + ): + for key in raw_obs.keys(): + raw_obs[key] = ( + torch.from_numpy(raw_obs[key]) + if type(raw_obs[key]) == np.ndarray + else raw_obs[key] + ) + rnn_states = check(rnn_states) + + if self.use_half: + input_ids = raw_obs["input_encoded_pt"].int() + attention_mask = raw_obs["input_attention_mask_pt"].int() + else: + input_ids = raw_obs["input_encoded_pt"].long() + attention_mask = raw_obs["input_attention_mask_pt"].long() + + for key in raw_obs.keys(): + if self.use_data_parallel: + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + else: + input_ids = input_ids.to(self._policy_model.device) + attention_mask = attention_mask.to(self._policy_model.device) + + past_model_kwargs = None + + if past_model_kwargs is None: + past_model_kwargs = { + "attention_mask": attention_mask, + } + + model_inputs = self._prepare_inputs_for_model( + self._policy_model, input_ids, past_model_kwargs + ) + + # forward pass to transformers + output = self._policy_model(**model_inputs) + + # compute action probs - policy head + next_token_logits = output.logits[:, -1] + dist = self._action_dist.proba_distribution(action_logits=next_token_logits) + + actions = dist.mode() if deterministic else dist.sample() + action_log_probs = dist.log_prob(actions) + + return actions.unsqueeze(-1), action_log_probs.unsqueeze(-1), rnn_states + + def eval_actions( + self, obs, rnn_states, action, masks, action_masks=None, active_masks=None + ): + for key in obs.keys(): + obs[key] = ( + torch.from_numpy(obs[key]) if type(obs[key]) == np.ndarray else obs[key] + ) + if self.use_data_parallel: + obs[key] = obs[key].to(self.device) + else: + obs[key] = obs[key].to(self._policy_model.device) + if self.use_data_parallel: + action = check(action).to(self.device).squeeze() + else: + action = check(action).to(self._policy_model.device).squeeze() + rnn_states = check(rnn_states) + + if self.half: + input_ids = obs["input_encoded_pt"].int() + attention_mask = obs["input_attention_mask_pt"].int() + else: + input_ids = obs["input_encoded_pt"].long() + attention_mask = obs["input_attention_mask_pt"].long() + + past_model_kwargs = None + + if past_model_kwargs is None: + past_model_kwargs = { + "attention_mask": attention_mask, + } + + model_inputs = self._prepare_inputs_for_model( + self._policy_model, input_ids, past_model_kwargs + ) + + # forward pass to transformers + output = self._policy_model(**model_inputs) + + # compute action probs - policy head + next_token_logits = output.logits[:, -1] + dist = self._action_dist.proba_distribution(action_logits=next_token_logits) + + action_log_probs = dist.log_prob(action) + dist_entropy = dist.entropy() + values = None + + return action_log_probs.unsqueeze(-1), dist_entropy.mean(), values + + def get_policy_values(self, obs, rnn_states, masks): + raise NotImplementedError diff --git a/openrl/modules/networks/policy_value_network_gpt.py b/openrl/modules/networks/policy_value_network_gpt.py index e87e146b..85daef3a 100644 --- a/openrl/modules/networks/policy_value_network_gpt.py +++ b/openrl/modules/networks/policy_value_network_gpt.py @@ -37,6 +37,7 @@ def __init__( self.disable_drop_out = disable_drop_out self._use_valuenorm = cfg.use_valuenorm super(CausalLMActorCriticPolicy, self).__init__( + cfg, input_space, action_space, model_name=cfg.model_path, @@ -45,6 +46,9 @@ def __init__( self.use_half = use_half self.tpdv = dict(dtype=torch.float32, device=device) + self._use_fp16 = cfg.use_fp16 + assert not (cfg.use_fp16 and not cfg.use_deepspeed) + def get_actor_para(self): return self._policy_model.parameters() @@ -66,6 +70,8 @@ def get_actions( ): for key in obs.keys(): obs[key] = check(obs[key], self.use_half, self.tpdv) + if self._use_fp16: + obs[key] = obs[key].half() rnn_states = check(rnn_states, self.use_half, self.tpdv) past_model_kwargs = None @@ -83,6 +89,8 @@ def eval_actions( ): for key in obs.keys(): obs[key] = check(obs[key], self.use_half, self.tpdv) + if self._use_fp16: + obs[key] = obs[key].half() action = check(action, self.use_half, self.tpdv).squeeze() eval_output = super().evaluate_actions(obs, action) @@ -95,20 +103,11 @@ def eval_actions( def get_values(self, obs, rnn_states, masks): for key in obs.keys(): obs[key] = check(obs[key], self.use_half, self.tpdv) + if self._use_fp16: + obs[key] = obs[key].half() rnn_states = check(rnn_states, self.use_half, self.tpdv) value_output = super().forward_value(obs) values = value_output.values return values, rnn_states - - def get_log_probs_ref_model(self, obs, action): - for key in obs.keys(): - obs[key] = check(obs[key], self.use_half, self.tpdv) - action = check(action, self.use_half, self.tpdv) - action = action.squeeze(-1) - - policy_output = super().get_log_probs_ref_model(obs, action) - action_log_probs = policy_output.log_probs - - return action_log_probs.detach().cpu().numpy() diff --git a/openrl/modules/networks/utils/attention.py b/openrl/modules/networks/utils/attention.py index c00a3f24..a05a9a84 100644 --- a/openrl/modules/networks/utils/attention.py +++ b/openrl/modules/networks/utils/attention.py @@ -234,10 +234,13 @@ def forward(self, x, self_idx=-1): K = self.split_shape[i][0] L = self.split_shape[i][1] for j in range(K): - torch.cat((x[i][:, (L * j) : (L * j + L)], self_x), dim=-1) - exec("x1.append(self.fc_{}(temp))".format(i)) - x[self_idx] - exec("x1.append(self.fc_{}(temp))".format(N - 1)) + # torch.cat((x[i][:, (L * j) : (L * j + L)], self_x), dim=-1) + # exec("x1.append(self.fc_{}(temp))".format(i)) + temp = torch.cat((x[i][:, (L * j) : (L * j + L)], self_x), dim=-1) + x1.append(getattr(self, "fc_" + str(i))(temp)) + x1.append(getattr(self, "fc_" + str(N - 1))(self_x)) + # x[self_idx] + # exec("x1.append(self.fc_{}(temp))".format(N - 1)) out = torch.stack(x1, 1) @@ -278,8 +281,10 @@ def forward(self, x, self_idx=None): K = self.split_shape[i][0] L = self.split_shape[i][1] for j in range(K): - x[i][:, (L * j) : (L * j + L)] - exec("x1.append(self.fc_{}(temp))".format(i)) + # x[i][:, (L * j) : (L * j + L)] + # exec("x1.append(self.fc_{}(temp))".format(i)) + temp = x[i][:, (L * j) : (L * j + L)] + x1.append(getattr(self, "fc_" + str(i))(temp)) out = torch.stack(x1, 1) diff --git a/openrl/modules/networks/utils/distributions.py b/openrl/modules/networks/utils/distributions.py index 340015a4..fd3ef8ca 100644 --- a/openrl/modules/networks/utils/distributions.py +++ b/openrl/modules/networks/utils/distributions.py @@ -68,7 +68,7 @@ def init_(m): def forward(self, x, action_masks=None): x = self.linear(x) if action_masks is not None: - x[action_masks == 0] = -1e10 + x[action_masks == 0] = -6e4 # fp16 return FixedCategorical(logits=x) diff --git a/openrl/modules/networks/utils/nlp/base_policy.py b/openrl/modules/networks/utils/nlp/base_policy.py index 9051b886..dd0e2032 100644 --- a/openrl/modules/networks/utils/nlp/base_policy.py +++ b/openrl/modules/networks/utils/nlp/base_policy.py @@ -124,13 +124,14 @@ class GenerationOutputs: class LMActorCriticPolicy(nn.Module): def __init__( self, + cfg: Any, observation_space: DictSpace, action_space: Discrete, model_name: str, optimizer_kwargs: Dict[str, Any] = {}, weight_decay: float = 1e-6, use_sde: bool = None, - apply_model_parallel: bool = True, + # apply_model_parallel: bool = True, optimizer_class: torch.optim.Optimizer = torch.optim.AdamW, generation_kwargs: Dict[str, Any] = {}, prompt_truncation_side: str = "left", @@ -146,16 +147,16 @@ def __init__( optimizer_kwargs (Dict[str, Any], optional): optimizer kwargs. Defaults to {}. weight_decay (float, optional): weight decay. Defaults to 1e-6. use_sde (bool, optional): Use state-dependent exploration. Defaults to None. - apply_model_parallel (bool, optional): whether to apply model parallel. Defaults to True. + apply_model_parallel (bool, optional): default to use model parallel when not using deepspeed. optimizer_class (torch.optim.Optimizer, optional): Optimizer class. Defaults to torch.optim.AdamW. generation_kwargs (Dict[str, Any], optional): generation parameters for rollout. Defaults to {}. prompt_truncation_side (str, optional): truncation side for prompt text. Defaults to "left". """ super().__init__() + self._use_deepspeed = cfg.use_deepspeed self._action_space = action_space - self._apply_model_parallel = apply_model_parallel + self._apply_model_parallel = not cfg.use_deepspeed # TODO self._build_model_heads(model_name, config, device) - self._setup_optimizer(optimizer_kwargs, weight_decay, optimizer_class) self._action_dist = CategoricalDistribution(self._action_space.n) self._generation_kwargs = generation_kwargs self._prompt_truncation_side = prompt_truncation_side diff --git a/openrl/modules/networks/utils/nlp/causal_policy.py b/openrl/modules/networks/utils/nlp/causal_policy.py index dedfc4aa..f0b86d0d 100644 --- a/openrl/modules/networks/utils/nlp/causal_policy.py +++ b/openrl/modules/networks/utils/nlp/causal_policy.py @@ -15,16 +15,13 @@ PolicyType, ValueOutput, ) -from openrl.modules.networks.utils.nlp.hf_generation_utils import ( - override_generation_routines, - unwrap_generation_routines, -) from openrl.modules.utils.valuenorm import ValueNorm class CausalLMActorCriticPolicy(LMActorCriticPolicy): def __init__( self, + cfg: Any, observation_space: DictSpace, action_space: Discrete, model_name: str, @@ -40,6 +37,7 @@ def __init__( device: str = "cpu", ): super().__init__( + cfg, observation_space, action_space, model_name, @@ -65,29 +63,37 @@ def load_from_dict(self, state_dict: dict = None): @property def policy(self): policy_model = self._policy_model - policy_model.__class__ = unwrap_generation_routines(type(policy_model)) return policy_model def _build_model_heads(self, model_name: str, config: str, device: str): if self.disable_drop_out: - config = AutoConfig.from_pretrained(model_name) + if model_name == "test_gpt2": + from transformers import GPT2Config + + config = GPT2Config() + + else: + config = AutoConfig.from_pretrained(model_name) config_dict = config.to_dict() for key in config_dict: if "drop" in key: config_dict[key] = 0.0 config = config.from_dict(config_dict) - self._policy_model = AutoModelForCausalLM.from_pretrained( - model_name, config=config - ) + if model_name == "test_gpt2": + from transformers import GPT2LMHeadModel - self._policy_model.__class__ = override_generation_routines( - type(self._policy_model) - ) + self._policy_model = GPT2LMHeadModel(config) + self._value_model = GPT2LMHeadModel(config) - self._value_model = AutoModelForCausalLM.from_pretrained( - model_name, config=config - ) + else: + self._policy_model = AutoModelForCausalLM.from_pretrained( + model_name, config=config + ) + + self._value_model = AutoModelForCausalLM.from_pretrained( + model_name, config=config + ) self._value_head = nn.Linear( self._value_model.config.hidden_size, 1, bias=False @@ -99,7 +105,17 @@ def _build_model_heads(self, model_name: str, config: str, device: str): torch.multiprocessing.set_sharing_strategy("file_system") # apply model parallel if torch.cuda.is_available(): - if self._apply_model_parallel and self._policy_model.is_parallelizable: + if self._use_deepspeed: + if self.value_normalizer is not None: + import deepspeed + + para = self.value_normalizer.running_mean + deepspeed.zero.register_external_parameter(self, para) + para = self.value_normalizer.running_mean_sq + deepspeed.zero.register_external_parameter(self, para) + para = self.value_normalizer.debiasing_term + deepspeed.zero.register_external_parameter(self, para) + elif self._apply_model_parallel and self._policy_model.is_parallelizable: self._policy_model.parallelize() self._value_model.parallelize() self._value_head = self._value_head.to(self.device) @@ -126,17 +142,18 @@ def _prepare_inputs_for_model( input_ids, **model_kwargs ) - if self._apply_model_parallel and unwrap_model(model).is_parallelizable: - # if model is in parallel mode, move the tensors to the first device - model_inputs = { - key: ( - value.to(model.transformer.first_device) - if isinstance(value, torch.Tensor) - and hasattr(model.transformer, "first_device") - else value - ) - for key, value in model_inputs.items() - } + if not self._use_deepspeed: + if self._apply_model_parallel and unwrap_model(model).is_parallelizable: + # if model is in parallel mode, move the tensors to the first device + model_inputs = { + key: ( + value.to(model.transformer.first_device) + if isinstance(value, torch.Tensor) + and hasattr(model.transformer, "first_device") + else value + ) + for key, value in model_inputs.items() + } return model_inputs def forward_policy( diff --git a/openrl/modules/networks/utils/nlp/hf_generation_utils.py b/openrl/modules/networks/utils/nlp/hf_generation_utils.py deleted file mode 100644 index 37d80875..00000000 --- a/openrl/modules/networks/utils/nlp/hf_generation_utils.py +++ /dev/null @@ -1,4000 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import warnings -from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -from torch import nn -from transformers.generation_beam_constraints import ( - Constraint, - DisjunctiveConstraint, - PhrasalConstraint, -) -from transformers.generation_beam_search import ( - BeamScorer, - BeamSearchScorer, - ConstrainedBeamSearchScorer, -) -from transformers.generation_logits_process import ( - EncoderNoRepeatNGramLogitsProcessor, - ExponentialDecayLengthPenalty, - ForcedBOSTokenLogitsProcessor, - ForcedEOSTokenLogitsProcessor, - HammingDiversityLogitsProcessor, - InfNanRemoveLogitsProcessor, - LogitsProcessorList, - MinLengthLogitsProcessor, - NoBadWordsLogitsProcessor, - NoRepeatNGramLogitsProcessor, - PrefixConstrainedLogitsProcessor, - RepetitionPenaltyLogitsProcessor, - TemperatureLogitsWarper, - TopKLogitsWarper, - TopPLogitsWarper, - TypicalLogitsWarper, -) -from transformers.generation_stopping_criteria import ( - MaxLengthCriteria, - MaxTimeCriteria, - StoppingCriteria, - StoppingCriteriaList, - validate_stopping_criteria, -) -from transformers.generation_utils import GenerationMixin -from transformers.pytorch_utils import torch_int_div -from transformers.utils import ModelOutput, logging - -logger = logging.get_logger(__name__) - - -@dataclass -class GreedySearchDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using greedy search. - - - Args: - sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each - tensor of shape `(batch_size, config.vocab_size)`). - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - scores: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -@dataclass -class GreedySearchEncoderDecoderOutput(ModelOutput): - """ - Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention - weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the - encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) - - - Args: - sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape - `(batch_size, config.vocab_size)`). - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, - sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - scores: Optional[Tuple[torch.FloatTensor]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -@dataclass -class SampleDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using sampling. - - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each - tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`). - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length, - sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - scores: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -@dataclass -class SampleEncoderDecoderOutput(ModelOutput): - """ - Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of - the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states - attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) - - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape - `(batch_size*num_return_sequences, config.vocab_size)`). - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape - `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size*num_return_sequences, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_return_sequences, num_heads, generated_length, - sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - scores: Optional[Tuple[torch.FloatTensor]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -@dataclass -class BeamSearchDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using beam search. - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting - of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape - `(batch_size*num_beams*num_return_sequences, config.vocab_size)`). - beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped - tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - sequences_scores: Optional[torch.FloatTensor] = None - scores: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None - attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -@dataclass -class BeamSearchEncoderDecoderOutput(ModelOutput): - """ - Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights - of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states - attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting - of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams, - config.vocab_size)`). - beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped - tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, - sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, - sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - sequences_scores: Optional[torch.FloatTensor] = None - scores: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -@dataclass -class BeamSampleDecoderOnlyOutput(ModelOutput): - """ - Base class for outputs of decoder-only generation models using beam sample. - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting - of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape - `(batch_size*num_beams*num_return_sequences, config.vocab_size)`). - beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped - tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - sequences_scores: Optional[torch.FloatTensor] = None - scores: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None - attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -@dataclass -class BeamSampleEncoderDecoderOutput(ModelOutput): - """ - Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention - weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the - encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) - - Args: - sequences (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting - of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. - `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams, - config.vocab_size)`). - beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): - Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped - tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, - sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size*num_beams, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`. - """ - - sequences: torch.LongTensor = None - sequences_scores: Optional[torch.FloatTensor] = None - scores: Optional[Tuple[torch.FloatTensor]] = None - beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - - -GreedySearchOutput = Union[ - GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput -] -SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] -BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] -BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] - - -class GenerationMixinWithRawScores: - """ - A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. - - The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for: - - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and - `do_sample=False`. - - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and - `do_sample=True`. - - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and - `do_sample=False`. - - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if - `num_beams>1` and `do_sample=True`. - - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if - `num_beams>1` and `num_beam_groups>1`. - - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`], - if `constraints!=None` or `force_words_ids!=None`. - """ - - def _prepare_model_inputs( - self, - inputs: Optional[torch.Tensor] = None, - bos_token_id: Optional[int] = None, - model_kwargs: Optional[Dict[str, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: - """ - This function extracts the model-specific `inputs` for generation. - """ - # 1. retrieve all kwargs that are non-None or non-model input related. - # some encoder-decoder models have different names for model and encoder - if ( - self.config.is_encoder_decoder - and hasattr(self, "encoder") - and self.encoder.main_input_name != self.main_input_name - ): - input_name = self.encoder.main_input_name - else: - input_name = self.main_input_name - - model_kwargs = { - k: v for k, v in model_kwargs.items() if v is not None or k != input_name - } - - # 2. check whether model_input_name is passed as kwarg - # if yes and `inputs` is None use kwarg inputs - inputs_kwarg = model_kwargs.pop(input_name, None) - if inputs_kwarg is not None and inputs is not None: - raise ValueError( - f"`inputs`: {inputs}` were passed alongside " - f"{input_name} which is not allowed." - f"Make sure to either pass {inputs} or {input_name}=..." - ) - elif inputs_kwarg is not None: - inputs = inputs_kwarg - - # 3. models with `input_ids` can also make use of `inputs_embeds` - if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs): - inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" - - # 4. Only encoder-decoder models can have non `input_ids` input format - if not self.config.is_encoder_decoder and input_name != "input_ids": - raise ValueError( - f"If {input_name} is passed as model-specific keyword " - "input then model has to be an encoder-decoder and not a " - f"{self.__class__.__name__}." - ) - - # 5. if `inputs` is still None, try to create `input_ids` from BOS token - if inputs is None: - inputs = self._prepare_input_ids_for_generation( - bos_token_id, model_kwargs.get("encoder_outputs") - ) - - return inputs, input_name, model_kwargs - - def _can_retrieve_inputs_from_name( - self, - inputs: Optional[torch.Tensor], - name: str, - model_kwargs: Dict[str, torch.Tensor], - ) -> torch.Tensor: - """ - If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved - from name - """ - can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set( - inspect.signature(self.forward).parameters.keys() - ) - - if can_retrieve_inputs and inputs is not None: - raise ValueError( - f"Cannot only pass one of {name} and {self.main_input_name}" - ) - - return can_retrieve_inputs - - def prepare_inputs_for_generation( - self, input_ids: torch.LongTensor, **kwargs - ) -> Dict[str, Any]: - """ - Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method. - """ - return {"input_ids": input_ids} - - def adjust_logits_during_generation( - self, logits: torch.FloatTensor, **kwargs - ) -> torch.FloatTensor: - """ - Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method. - """ - return logits - - def _prepare_input_ids_for_generation( - self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput] - ) -> torch.LongTensor: - if self.config.is_encoder_decoder and encoder_outputs is not None: - # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding - shape = encoder_outputs.last_hidden_state.size()[:-1] - return torch.ones(shape, dtype=torch.long, device=self.device) * -100 - - if bos_token_id is None: - raise ValueError( - "`bos_token_id` has to be defined when no `input_ids` are provided." - ) - return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id - - def _prepare_attention_mask_for_generation( - self, - inputs: torch.Tensor, - pad_token_id: int, - eos_token_id: int, - ) -> torch.LongTensor: - is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [ - torch.int, - torch.long, - ] - is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) - is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( - (eos_token_id is not None) and (pad_token_id != eos_token_id) - ) - # Check if input is input_ids and padded -> only then is attention_mask defined - if ( - is_input_ids - and is_pad_token_in_inputs - and is_pad_token_not_equal_to_eos_token_id - ): - return inputs.ne(pad_token_id).long() - else: - return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device) - - def _prepare_encoder_decoder_kwargs_for_generation( - self, - inputs_tensor: torch.Tensor, - model_kwargs, - model_input_name: Optional[str] = None, - ) -> Dict[str, Any]: - # 1. get encoder - encoder = self.get_encoder() - - # 2. prepare encoder args and encoder kwargs from model kwargs - irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] - encoder_kwargs = { - argument: value - for argument, value in model_kwargs.items() - if not any(argument.startswith(p) for p in irrelevant_prefix) - } - - # 3. make sure that encoder returns `ModelOutput` - model_input_name = ( - model_input_name if model_input_name is not None else self.main_input_name - ) - encoder_kwargs["return_dict"] = True - encoder_kwargs[model_input_name] = inputs_tensor - model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) - - return model_kwargs - - def _prepare_decoder_input_ids_for_generation( - self, - batch_size: int, - decoder_start_token_id: int = None, - bos_token_id: int = None, - model_kwargs: Optional[Dict[str, torch.Tensor]] = None, - ) -> torch.LongTensor: - if model_kwargs is not None and "decoder_input_ids" in model_kwargs: - return model_kwargs.pop("decoder_input_ids") - else: - decoder_start_token_id = self._get_decoder_start_token_id( - decoder_start_token_id, bos_token_id - ) - return ( - torch.ones((batch_size, 1), dtype=torch.long, device=self.device) - * decoder_start_token_id - ) - - def _get_decoder_start_token_id( - self, decoder_start_token_id: int = None, bos_token_id: int = None - ) -> int: - decoder_start_token_id = ( - decoder_start_token_id - if decoder_start_token_id is not None - else self.config.decoder_start_token_id - ) - bos_token_id = ( - bos_token_id if bos_token_id is not None else self.config.bos_token_id - ) - - if decoder_start_token_id is not None: - return decoder_start_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "decoder_start_token_id") - and self.config.decoder.decoder_start_token_id is not None - ): - return self.config.decoder.decoder_start_token_id - elif bos_token_id is not None: - return bos_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "bos_token_id") - and self.config.decoder.bos_token_id is not None - ): - return self.config.decoder.bos_token_id - raise ValueError( - "`decoder_start_token_id` or `bos_token_id` has to be defined for" - " encoder-decoder generation." - ) - - @staticmethod - def _expand_inputs_for_generation( - input_ids: torch.LongTensor, - expand_size: int = 1, - is_encoder_decoder: bool = False, - attention_mask: Optional[torch.LongTensor] = None, - encoder_outputs: Optional[ModelOutput] = None, - **model_kwargs, - ) -> Tuple[torch.LongTensor, Dict[str, Any]]: - expanded_return_idx = ( - torch.arange(input_ids.shape[0]) - .view(-1, 1) - .repeat(1, expand_size) - .view(-1) - .to(input_ids.device) - ) - input_ids = input_ids.index_select(0, expanded_return_idx) - - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = token_type_ids.index_select( - 0, expanded_return_idx - ) - - if attention_mask is not None: - model_kwargs["attention_mask"] = attention_mask.index_select( - 0, expanded_return_idx - ) - - if is_encoder_decoder: - if encoder_outputs is None: - raise ValueError( - "If `is_encoder_decoder` is True, make sure that `encoder_outputs`" - " is defined." - ) - encoder_outputs["last_hidden_state"] = ( - encoder_outputs.last_hidden_state.index_select( - 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) - ) - ) - model_kwargs["encoder_outputs"] = encoder_outputs - return input_ids, model_kwargs - - @staticmethod - def _update_model_kwargs_for_generation( - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - ) -> Dict[str, Any]: - # update past - if "past_key_values" in outputs: - model_kwargs["past"] = outputs.past_key_values - elif "mems" in outputs: - model_kwargs["past"] = outputs.mems - elif "past_buckets_states" in outputs: - model_kwargs["past"] = outputs.past_buckets_states - else: - model_kwargs["past"] = None - - # update token_type_ids with last value - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 - ) - - # update attention mask - if not is_encoder_decoder: - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [ - attention_mask, - attention_mask.new_ones((attention_mask.shape[0], 1)), - ], - dim=-1, - ) - - return model_kwargs - - def _reorder_cache(self, past, beam_idx): - raise NotImplementedError( - "Make sure that a `_reorder_cache` function is correctly implemented in" - f" {self.__class__.__module__} to enable beam search for {self.__class__}" - ) - - def _get_logits_warper( - self, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - typical_p: Optional[float] = None, - temperature: Optional[float] = None, - num_beams: Optional[int] = None, - ) -> LogitsProcessorList: - """ - This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances - used for multinomial sampling. - """ - - # init warp parameters - top_k = top_k if top_k is not None else self.config.top_k - top_p = top_p if top_p is not None else self.config.top_p - typical_p = typical_p if typical_p is not None else self.config.typical_p - temperature = ( - temperature if temperature is not None else self.config.temperature - ) - # instantiate warpers list - warpers = LogitsProcessorList() - - # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files - # all samplers can be found in `generation_utils_samplers.py` - if temperature is not None and temperature != 1.0: - warpers.append(TemperatureLogitsWarper(temperature)) - if top_k is not None and top_k != 0: - warpers.append( - TopKLogitsWarper( - top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1) - ) - ) - if top_p is not None and top_p < 1.0: - warpers.append( - TopPLogitsWarper( - top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1) - ) - ) - if typical_p is not None and typical_p < 1.0: - warpers.append( - TypicalLogitsWarper( - mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1) - ) - ) - return warpers - - def _get_logits_processor( - self, - repetition_penalty: float, - no_repeat_ngram_size: int, - encoder_no_repeat_ngram_size: int, - input_ids_seq_length: int, - encoder_input_ids: torch.LongTensor, - bad_words_ids: List[List[int]], - min_length: int, - max_length: int, - eos_token_id: int, - forced_bos_token_id: int, - forced_eos_token_id: int, - prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], - num_beams: int, - num_beam_groups: int, - diversity_penalty: float, - remove_invalid_values: bool, - exponential_decay_length_penalty: Tuple, - logits_processor: Optional[LogitsProcessorList], - ) -> LogitsProcessorList: - """ - This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] - instances used to modify the scores of the language model head. - """ - processors = LogitsProcessorList() - - # init warp parameters - repetition_penalty = ( - repetition_penalty - if repetition_penalty is not None - else self.config.repetition_penalty - ) - no_repeat_ngram_size = ( - no_repeat_ngram_size - if no_repeat_ngram_size is not None - else self.config.no_repeat_ngram_size - ) - encoder_no_repeat_ngram_size = ( - encoder_no_repeat_ngram_size - if encoder_no_repeat_ngram_size is not None - else self.config.encoder_no_repeat_ngram_size - ) - bad_words_ids = ( - bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids - ) - min_length = min_length if min_length is not None else self.config.min_length - eos_token_id = ( - eos_token_id if eos_token_id is not None else self.config.eos_token_id - ) - diversity_penalty = ( - diversity_penalty - if diversity_penalty is not None - else self.config.diversity_penalty - ) - forced_bos_token_id = ( - forced_bos_token_id - if forced_bos_token_id is not None - else self.config.forced_bos_token_id - ) - forced_eos_token_id = ( - forced_eos_token_id - if forced_eos_token_id is not None - else self.config.forced_eos_token_id - ) - remove_invalid_values = ( - remove_invalid_values - if remove_invalid_values is not None - else self.config.remove_invalid_values - ) - exponential_decay_length_penalty = ( - exponential_decay_length_penalty - if exponential_decay_length_penalty is not None - else self.config.exponential_decay_length_penalty - ) - # instantiate processors list - - # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files - # all samplers can be found in `generation_utils_samplers.py` - if diversity_penalty is not None and diversity_penalty > 0.0: - processors.append( - HammingDiversityLogitsProcessor( - diversity_penalty=diversity_penalty, - num_beams=num_beams, - num_beam_groups=num_beam_groups, - ) - ) - if repetition_penalty is not None and repetition_penalty != 1.0: - processors.append( - RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) - ) - if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: - processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) - if ( - encoder_no_repeat_ngram_size is not None - and encoder_no_repeat_ngram_size > 0 - ): - if self.config.is_encoder_decoder: - processors.append( - EncoderNoRepeatNGramLogitsProcessor( - encoder_no_repeat_ngram_size, encoder_input_ids - ) - ) - else: - raise ValueError( - "It's impossible to use `encoder_no_repeat_ngram_size` with" - " decoder-only architecture" - ) - if bad_words_ids is not None: - processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) - if min_length is not None and eos_token_id is not None and min_length > 0: - processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) - if prefix_allowed_tokens_fn is not None: - processors.append( - PrefixConstrainedLogitsProcessor( - prefix_allowed_tokens_fn, num_beams // num_beam_groups - ) - ) - if forced_bos_token_id is not None: - processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) - if forced_eos_token_id is not None: - processors.append( - ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id) - ) - if remove_invalid_values is True: - processors.append(InfNanRemoveLogitsProcessor()) - if exponential_decay_length_penalty is not None: - processors.append( - ExponentialDecayLengthPenalty( - exponential_decay_length_penalty, eos_token_id, input_ids_seq_length - ) - ) - processors = self._merge_criteria_processor_list(processors, logits_processor) - return processors - - def _get_stopping_criteria( - self, - max_length: Optional[int], - max_time: Optional[float], - stopping_criteria: Optional[StoppingCriteriaList], - ) -> StoppingCriteriaList: - criteria = StoppingCriteriaList() - if max_length is not None: - criteria.append(MaxLengthCriteria(max_length=max_length)) - if max_time is not None: - criteria.append(MaxTimeCriteria(max_time=max_time)) - criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) - return criteria - - def _merge_criteria_processor_list( - self, - default_list: Union[LogitsProcessorList, StoppingCriteriaList], - custom_list: Union[LogitsProcessorList, StoppingCriteriaList], - ) -> Union[LogitsProcessorList, StoppingCriteriaList]: - if len(custom_list) == 0: - return default_list - for default in default_list: - for custom in custom_list: - if type(custom) is type(default): - object_type = ( - "stopping criteria" - if isinstance(custom, StoppingCriteria) - else "logits processor" - ) - raise ValueError( - f"A custom {object_type} of type {type(custom)} with values" - f" {custom} has been passed to `generate`, but it has already" - f" been created with the values {default}. {default} has been" - " created by passing the corresponding arguments to generate" - " or by the model's config default values. If you just want to" - f" change the default values of {object_type} consider passing" - " them as arguments to `generate` instead of using a custom" - f" {object_type}." - ) - default_list.extend(custom_list) - return default_list - - def compute_beam_search_raw_logits( - self, - sequences: torch.Tensor, - scores: Tuple[torch.Tensor], - beam_indices: torch.Tensor, - eos_token_id: int = None, - ): - """Compute raw logits for beam search""" - - if not self.config.is_encoder_decoder: - raise NotImplementedError( - "Beam Search raw logits code is implemented only for enoder-decoder" - " only models" - ) - - # since sequences can be shorter than scores (probably due to beam search finalization) - # we always have to generate raw_logits only for generated sequences - # cut off the start tokens from generated - sequences = sequences.clone() - sequences = sequences[:, 1:] - gen_steps = sequences.shape[1] - - # align scores and beam indices according to gen_steps - # scores(gen_steps x(batch_size * num_beams) x vocab_size) - scores = scores[:gen_steps] - scores = torch.stack(scores) - _, _, vocab_size = scores.shape - - beam_indices = torch.tensor(beam_indices).T.to(scores.device) - beam_indices = beam_indices[:gen_steps, :] - batch_size = beam_indices.shape[1] - - # gen_steps x batch_size x vocab_size - beam_indices = beam_indices.unsqueeze(-1).repeat(1, 1, vocab_size) - step_wise_logits = scores.gather(dim=1, index=beam_indices) - assert step_wise_logits.shape == torch.Size((gen_steps, batch_size, vocab_size)) - - # finally convert to tuples - step_wise_logits = [(step_wise_logits[t], None) for t in range(gen_steps)] - return step_wise_logits - - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - max_length: Optional[int] = None, - min_length: Optional[int] = None, - do_sample: Optional[bool] = None, - early_stopping: Optional[bool] = None, - num_beams: Optional[int] = None, - temperature: Optional[float] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - typical_p: Optional[float] = None, - repetition_penalty: Optional[float] = None, - bad_words_ids: Optional[Iterable[int]] = None, - force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, - bos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - length_penalty: Optional[float] = None, - no_repeat_ngram_size: Optional[int] = None, - encoder_no_repeat_ngram_size: Optional[int] = None, - num_return_sequences: Optional[int] = None, - max_time: Optional[float] = None, - max_new_tokens: Optional[int] = None, - decoder_start_token_id: Optional[int] = None, - use_cache: Optional[bool] = None, - num_beam_groups: Optional[int] = None, - diversity_penalty: Optional[float] = None, - prefix_allowed_tokens_fn: Optional[ - Callable[[int, torch.Tensor], List[int]] - ] = None, - logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), - stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), - constraints: Optional[List[Constraint]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - forced_bos_token_id: Optional[int] = None, - forced_eos_token_id: Optional[int] = None, - remove_invalid_values: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, - **model_kwargs, - ) -> Union[ - GreedySearchOutput, - SampleOutput, - BeamSearchOutput, - BeamSampleOutput, - torch.LongTensor, - ]: - r""" - - Generates sequences of token ids for models with a language modeling head. The method supports the following - generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: - - - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and - `do_sample=False`. - - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and - `do_sample=True`. - - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and - `do_sample=False`. - - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if - `num_beams>1` and `do_sample=True`. - - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if - `num_beams>1` and `num_beam_groups>1`. - - *constrained beam-search decoding* by calling - [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or - `force_words_ids!=None`. - - - - Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as - defined in the model's config (`config.json`) which in turn defaults to the - [`~modeling_utils.PretrainedConfig`] of the model. - - - - Most of these parameters are explained in more detail in [this blog - post](https://huggingface.co/blog/how-to-generate). - - Parameters: - inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): - The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the - method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` - should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of - `input_ids`, `input_values`, `input_features`, or `pixel_values`. - max_length (`int`, *optional*, defaults to `model.config.max_length`): - The maximum length of the sequence to be generated. - max_new_tokens (`int`, *optional*, defaults to None): - The maximum numbers of tokens to generate, ignore the current number of tokens. Use either - `max_new_tokens` or `max_length` but not both, they serve the same purpose. - min_length (`int`, *optional*, defaults to 10): - The minimum length of the sequence to be generated. - do_sample (`bool`, *optional*, defaults to `False`): - Whether or not to use sampling ; use greedy decoding otherwise. - early_stopping (`bool`, *optional*, defaults to `False`): - Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. - num_beams (`int`, *optional*, defaults to 1): - Number of beams for beam search. 1 means no beam search. - temperature (`float`, *optional*, defaults to 1.0): - The value used to module the next token probabilities. - top_k (`int`, *optional*, defaults to 50): - The number of highest probability vocabulary tokens to keep for top-k-filtering. - top_p (`float`, *optional*, defaults to 1.0): - If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher - are kept for generation. - repetition_penalty (`float`, *optional*, defaults to 1.0): - The parameter for repetition penalty. 1.0 means no penalty. See [this - paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - bos_token_id (`int`, *optional*): - The id of the *beginning-of-sequence* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - length_penalty (`float`, *optional*, defaults to 1.0): - Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the - model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer - sequences. - no_repeat_ngram_size (`int`, *optional*, defaults to 0): - If set to int > 0, all ngrams of that size can only occur once. - encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): - If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the - `decoder_input_ids`. - bad_words_ids(`List[List[int]]`, *optional*): - List of token ids that are not allowed to be generated. In order to get the token ids of the words that - should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, - add_special_tokens=False).input_ids`. - force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*): - List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple - list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, - this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), - where one can allow different forms of each word. - num_return_sequences(`int`, *optional*, defaults to 1): - The number of independently computed returned sequences for each element in the batch. - max_time(`float`, *optional*, defaults to None): - The maximum amount of time you allow the computation to run for in seconds. generation will still - finish the current pass after allocated time has been passed. - attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens - that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape - as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask) - decoder_start_token_id (`int`, *optional*): - If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. - use_cache: (`bool`, *optional*, defaults to `True`): - Whether or not the model should use the past last key/values attentions (if applicable to the model) to - speed up decoding. - num_beam_groups (`int`, *optional*, defaults to 1): - Number of groups to divide `num_beams` into in order to ensure diversity among different groups of - beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. - diversity_penalty (`float`, *optional*, defaults to 0.0): - This value is subtracted from a beam's score if it generates a token same as any beam from other group - at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is - enabled. - prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): - If provided, this function constraints the beam search to allowed tokens only at each step. If not - provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and - `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned - on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful - for constrained generation conditioned on the prefix, as described in [Autoregressive Entity - Retrieval](https://arxiv.org/abs/2010.00904). - logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and a - model's config. If a logit processor is passed that is already created with the arguments or a model's - config an error is thrown. This feature is intended for advanced users. - stopping_criteria (`StoppingCriteriaList`, *optional*): - Custom stopping criteria that complement the default stopping criteria built from arguments and a - model's config. If a stopping criteria is passed that is already created with the arguments or a - model's config an error is thrown. This feature is intended for advanced users. - constraints (`List[Constraint]`, *optional*): - Custom constraints that can be added to the generation to ensure that the output will contain the use - of certain tokens as defined by `Constraint` objects, in the most sensible way possible. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - forced_bos_token_id (`int`, *optional*): - The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful - for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be - the target language token. - forced_eos_token_id (`int`, *optional*): - The id of the token to force as the last generated token when `max_length` is reached. - remove_invalid_values (`bool`, *optional*): - Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to - crash. Note that using `remove_invalid_values` can slow down generation. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - exponential_decay_length_penalty (`tuple(int, float)`, *optional*): - This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been - generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates - where penalty starts and `decay_factor` represents the factor of exponential decay - - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model - is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs - should be prefixed with *decoder_*. - - Return: - [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` - or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. - - If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation_utils.GreedySearchDecoderOnlyOutput`], - - [`~generation_utils.SampleDecoderOnlyOutput`], - - [`~generation_utils.BeamSearchDecoderOnlyOutput`], - - [`~generation_utils.BeamSampleDecoderOnlyOutput`] - - If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible - [`~utils.ModelOutput`] types are: - - - [`~generation_utils.GreedySearchEncoderDecoderOutput`], - - [`~generation_utils.SampleEncoderDecoderOutput`], - - [`~generation_utils.BeamSearchEncoderDecoderOutput`], - - [`~generation_utils.BeamSampleEncoderDecoderOutput`] - - Examples: - - Greedy Decoding: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - - >>> prompt = "Today I believe we can finally" - >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids - - >>> # generate up to 30 tokens - >>> outputs = model.generate(input_ids, do_sample=False, max_length=30) - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] - ``` - - Multinomial Sampling: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - - >>> prompt = "Today I believe we can finally" - >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids - - >>> # sample up to 30 tokens - >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT - >>> outputs = model.generate(input_ids, do_sample=True, max_length=30) - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] - ``` - - Beam-search decoding: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM - - >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") - - >>> sentence = "Paris is one of the densest populated areas in Europe." - >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids - - >>> outputs = model.generate(input_ids) - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] - ```""" - # 1. Set generation parameters if not already defined - bos_token_id = ( - bos_token_id if bos_token_id is not None else self.config.bos_token_id - ) - num_beams = num_beams if num_beams is not None else self.config.num_beams - length_penalty = ( - length_penalty if length_penalty is not None else self.config.length_penalty - ) - early_stopping = ( - early_stopping if early_stopping is not None else self.config.early_stopping - ) - num_beam_groups = ( - num_beam_groups - if num_beam_groups is not None - else self.config.num_beam_groups - ) - do_sample = do_sample if do_sample is not None else self.config.do_sample - num_return_sequences = ( - num_return_sequences - if num_return_sequences is not None - else self.config.num_return_sequences - ) - - pad_token_id = ( - pad_token_id if pad_token_id is not None else self.config.pad_token_id - ) - eos_token_id = ( - eos_token_id if eos_token_id is not None else self.config.eos_token_id - ) - - if eos_token_id is None and hasattr(self.config, "decoder"): - eos_token_id = self.config.decoder.eos_token_id - - if pad_token_id is None and eos_token_id is not None: - # special case if pad_token_id is not defined - # logger.warning( - # f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - - pad_token_id = eos_token_id - - output_scores = ( - output_scores if output_scores is not None else self.config.output_scores - ) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.config.return_dict_in_generate - ) - - # 2. Define model inputs - # inputs_tensor has to be defined - # model_input_name is defined if model-specific keyword input is passed - # otherwise model_input_name is None - # all model-specific keyword inputs are removed from `model_kwargs` - inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( - inputs, bos_token_id, model_kwargs - ) - batch_size = inputs_tensor.shape[0] - - # 3. Define other model kwargs - model_kwargs["output_attentions"] = output_attentions - model_kwargs["output_hidden_states"] = output_hidden_states - model_kwargs["use_cache"] = use_cache - - accepts_attention_mask = "attention_mask" in set( - inspect.signature(self.forward).parameters.keys() - ) - requires_attention_mask = "encoder_outputs" not in model_kwargs - - if ( - model_kwargs.get("attention_mask", None) is None - and requires_attention_mask - and accepts_attention_mask - ): - model_kwargs["attention_mask"] = ( - self._prepare_attention_mask_for_generation( - inputs_tensor, pad_token_id, eos_token_id - ) - ) - - if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: - # if model is encoder decoder encoder_outputs are created - # and added to `model_kwargs` - model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, model_kwargs, model_input_name - ) - - # 4. Prepare `input_ids` which will be used for auto-regressive generation - if self.config.is_encoder_decoder: - input_ids = self._prepare_decoder_input_ids_for_generation( - batch_size, - decoder_start_token_id=decoder_start_token_id, - bos_token_id=bos_token_id, - model_kwargs=model_kwargs, - ) - else: - # if decoder-only then inputs_tensor has to be `input_ids` - input_ids = inputs_tensor - - input_ids_seq_length = input_ids.shape[-1] - - # 5. Prepare `max_length` depending on other stopping criteria - # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens` - if max_length is None and max_new_tokens is not None: - max_length = max_new_tokens + input_ids_seq_length - elif max_length is not None and max_new_tokens is not None: - # Both are set, this is odd, raise a warning - warnings.warn( - ( - "Both `max_length` and `max_new_tokens` have been set " - f"but they serve the same purpose. `max_length` {max_length} " - f"will take priority over `max_new_tokens` {max_new_tokens}." - ), - UserWarning, - ) - # default to config if still None - max_length = max_length if max_length is not None else self.config.max_length - - if input_ids_seq_length >= max_length: - input_ids_string = ( - "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - ) - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but" - f" ``max_length`` is set to {max_length}. This can lead to unexpected" - " behavior. You should consider increasing ``config.max_length`` or" - " ``max_length``." - ) - - # 6. determine generation mode - is_constraint_gen_mode = constraints is not None or force_words_ids is not None - is_greedy_gen_mode = ( - (num_beams == 1) - and (num_beam_groups == 1) - and do_sample is False - and not is_constraint_gen_mode - ) - is_sample_gen_mode = ( - (num_beams == 1) - and (num_beam_groups == 1) - and do_sample is True - and not is_constraint_gen_mode - ) - is_beam_gen_mode = ( - (num_beams > 1) - and (num_beam_groups == 1) - and do_sample is False - and not is_constraint_gen_mode - ) - is_beam_sample_gen_mode = ( - (num_beams > 1) - and (num_beam_groups == 1) - and do_sample is True - and not is_constraint_gen_mode - ) - is_group_beam_gen_mode = ( - (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode - ) - - if num_beam_groups > num_beams: - raise ValueError( - "`num_beam_groups` has to be smaller or equal to `num_beams`" - ) - if is_group_beam_gen_mode and do_sample is True: - raise ValueError( - "Diverse beam search cannot be used in sampling mode. Make sure that" - " `do_sample` is set to `False`." - ) - - # 7. prepare distribution pre_processing samplers - logits_processor = self._get_logits_processor( - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=inputs_tensor, - bad_words_ids=bad_words_ids, - min_length=min_length, - max_length=max_length, - eos_token_id=eos_token_id, - forced_bos_token_id=forced_bos_token_id, - forced_eos_token_id=forced_eos_token_id, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - num_beams=num_beams, - num_beam_groups=num_beam_groups, - diversity_penalty=diversity_penalty, - remove_invalid_values=remove_invalid_values, - exponential_decay_length_penalty=exponential_decay_length_penalty, - logits_processor=logits_processor, - ) - - # 8. prepare stopping criteria - stopping_criteria = self._get_stopping_criteria( - max_length=max_length, - max_time=max_time, - stopping_criteria=stopping_criteria, - ) - - # 9. go into different generation modes - if is_greedy_gen_mode: - if num_return_sequences > 1: - raise ValueError( - "num_return_sequences has to be 1, but is" - f" {num_return_sequences} when doing greedy search." - ) - - # 10. run greedy search - return self.greedy_search( - input_ids, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - elif is_sample_gen_mode: - # 10. prepare logits warper - logits_warper = self._get_logits_warper( - top_k=top_k, - top_p=top_p, - typical_p=typical_p, - temperature=temperature, - num_beams=num_beams, - ) - - # 11. expand input_ids with `num_return_sequences` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids, - expand_size=num_return_sequences, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - - # 12. run sample - return self.sample( - input_ids, - logits_processor=logits_processor, - logits_warper=logits_warper, - stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - elif is_beam_gen_mode: - if num_return_sequences > num_beams: - raise ValueError( - "`num_return_sequences` has to be smaller or equal to `num_beams`." - ) - - if stopping_criteria.max_length is None: - raise ValueError( - "`max_length` needs to be a stopping_criteria for now." - ) - - # 10. prepare beam search scorer - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - device=self.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, - ) - # 11. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids, - expand_size=num_beams, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - # 12. run beam search - return self.beam_search( - input_ids, - beam_scorer, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - elif is_beam_sample_gen_mode: - # 10. prepare logits warper - logits_warper = self._get_logits_warper( - top_k=top_k, - top_p=top_p, - typical_p=typical_p, - temperature=temperature, - num_beams=num_beams, - ) - - if stopping_criteria.max_length is None: - raise ValueError( - "`max_length` needs to be a stopping_criteria for now." - ) - # 11. prepare beam search scorer - beam_scorer = BeamSearchScorer( - batch_size=batch_size * num_return_sequences, - num_beams=num_beams, - device=self.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - ) - - # 12. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids, - expand_size=num_beams * num_return_sequences, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - - # 13. run beam sample - return self.beam_sample( - input_ids, - beam_scorer, - logits_processor=logits_processor, - logits_warper=logits_warper, - stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - elif is_group_beam_gen_mode: - if num_return_sequences > num_beams: - raise ValueError( - "`num_return_sequences` has to be smaller or equal to `num_beams`." - ) - - if num_beams % num_beam_groups != 0: - raise ValueError( - "`num_beams` should be divisible by `num_beam_groups` for group" - " beam search." - ) - - if stopping_criteria.max_length is None: - raise ValueError( - "`max_length` needs to be a stopping_criteria for now." - ) - - # 10. prepare beam search scorer - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - max_length=stopping_criteria.max_length, - device=self.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, - num_beam_groups=num_beam_groups, - ) - # 11. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids, - expand_size=num_beams, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - # 12. run beam search - return self.group_beam_search( - input_ids, - beam_scorer, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - elif is_constraint_gen_mode: - if num_return_sequences > num_beams: - raise ValueError( - "`num_return_sequences` has to be smaller or equal to `num_beams`." - ) - - if stopping_criteria.max_length is None: - raise ValueError( - "`max_length` needs to be a stopping_criteria for now." - ) - - if num_beams <= 1: - raise ValueError( - "`num_beams` needs to be greater than 1 for constrained" - " genertation." - ) - - if do_sample: - raise ValueError( - "`do_sample` needs to be false for constrained generation." - ) - - if num_beam_groups is not None and num_beam_groups > 1: - raise ValueError( - "`num_beam_groups` not supported yet for constrained generation." - ) - - final_constraints = [] - if constraints is not None: - final_constraints = constraints - - if force_words_ids is not None: - - def typeerror(): - raise ValueError( - "`force_words_ids` has to either be a `List[List[List[int]]]`" - " or `List[List[int]]`of positive integers, but is" - f" {force_words_ids}." - ) - - if not isinstance(force_words_ids, list) or len(force_words_ids) == 0: - typeerror() - - for word_ids in force_words_ids: - if isinstance(word_ids[0], list): - if not isinstance(word_ids, list) or len(word_ids) == 0: - typeerror() - if any( - not isinstance(token_ids, list) for token_ids in word_ids - ): - typeerror() - if any( - any( - (not isinstance(token_id, int) or token_id < 0) - for token_id in token_ids - ) - for token_ids in word_ids - ): - typeerror() - - constraint = DisjunctiveConstraint(word_ids) - else: - if not isinstance(word_ids, list) or len(word_ids) == 0: - typeerror() - if any( - (not isinstance(token_id, int) or token_id < 0) - for token_id in word_ids - ): - typeerror() - - constraint = PhrasalConstraint(word_ids) - final_constraints.append(constraint) - - # 10. prepare beam search scorer - constrained_beam_scorer = ConstrainedBeamSearchScorer( - constraints=final_constraints, - batch_size=batch_size, - num_beams=num_beams, - device=self.device, - length_penalty=length_penalty, - do_early_stopping=early_stopping, - num_beam_hyps_to_keep=num_return_sequences, - ) - # 11. interleave input_ids with `num_beams` additional sequences per batch - input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids, - expand_size=num_beams, - is_encoder_decoder=self.config.is_encoder_decoder, - **model_kwargs, - ) - # 12. run beam search - return self.constrained_beam_search( - input_ids, - constrained_beam_scorer=constrained_beam_scorer, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - output_scores=output_scores, - return_dict_in_generate=return_dict_in_generate, - synced_gpus=synced_gpus, - **model_kwargs, - ) - - def greedy_search( - self, - input_ids: torch.LongTensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - **model_kwargs, - ) -> Union[GreedySearchOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be - used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - model_kwargs: - Additional model specific keyword arguments will be forwarded to the `forward` function of the model. - If model is an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`] - or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForCausalLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... StoppingCriteriaList, - ... MaxLengthCriteria, - ... ) - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - - >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token - >>> model.config.pad_token_id = model.config.eos_token_id - - >>> input_prompt = "It might be possible to" - >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id), - ... ] - ... ) - >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - - >>> outputs = model.greedy_search( - ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ["It might be possible to get a better understanding of the nature of the problem, but it's not"] - ```""" - # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) - if max_length is not None: - warnings.warn( - ( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])`" - " instead." - ), - UserWarning, - ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) - pad_token_id = ( - pad_token_id if pad_token_id is not None else self.config.pad_token_id - ) - eos_token_id = ( - eos_token_id if eos_token_id is not None else self.config.eos_token_id - ) - output_scores = ( - output_scores if output_scores is not None else self.config.output_scores - ) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.config.return_dict_in_generate - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None - ) - - # keep track of which sequences are already finished - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - cur_len = input_ids.shape[-1] - - this_peer_finished = False # used by synced_gpus only - while True: - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - - next_token_logits = outputs.logits[:, -1, :] - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # pre-process distribution - next_tokens_scores = logits_processor( - input_ids, next_token_logits, model_inputs - ) - - # argmax - next_tokens = torch.argmax(next_tokens_scores, dim=-1) - - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is" - " defined." - ) - next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( - 1 - unfinished_sequences - ) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - cur_len = cur_len + 1 - - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul( - (next_tokens != eos_token_id).long() - ) - - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: - this_peer_finished = True - - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return GreedySearchEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return GreedySearchDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return input_ids - - def sample( - self, - input_ids: torch.LongTensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - **model_kwargs, - ) -> Union[SampleOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and - can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForCausalLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... TopKLogitsWarper, - ... TemperatureLogitsWarper, - ... StoppingCriteriaList, - ... MaxLengthCriteria, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - - >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token - >>> model.config.pad_token_id = model.config.eos_token_id - - >>> input_prompt = "Today is a beautiful day, and" - >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), - ... ] - ... ) - >>> # instantiate logits processors - >>> logits_warper = LogitsProcessorList( - ... [ - ... TopKLogitsWarper(50), - ... TemperatureLogitsWarper(0.7), - ... ] - ... ) - - >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) - - >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT - >>> outputs = model.sample( - ... input_ids, - ... logits_processor=logits_processor, - ... logits_warper=logits_warper, - ... stopping_criteria=stopping_criteria, - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] - ```""" - - # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) - if max_length is not None: - warnings.warn( - ( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" - " instead." - ), - UserWarning, - ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) - logits_warper = ( - logits_warper if logits_warper is not None else LogitsProcessorList() - ) - pad_token_id = ( - pad_token_id if pad_token_id is not None else self.config.pad_token_id - ) - eos_token_id = ( - eos_token_id if eos_token_id is not None else self.config.eos_token_id - ) - output_scores = ( - output_scores if output_scores is not None else self.config.output_scores - ) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.config.return_dict_in_generate - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None - ) - - # keep track of which sequences are already finished - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - cur_len = input_ids.shape[-1] - - this_peer_finished = False # used by synced_gpus only - # auto-regressive generation - while True: - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - - next_token_logits_raw = outputs.logits[:, -1, :].clone() - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor( - input_ids, next_token_logits, model_inputs=model_inputs - ) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += ((next_token_logits_raw, next_token_scores),) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is" - " defined." - ) - next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( - 1 - unfinished_sequences - ) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - cur_len = cur_len + 1 - - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul( - (next_tokens != eos_token_id).long() - ) - - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: - this_peer_finished = True - - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return SampleEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return SampleDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return input_ids - - def beam_search( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - **model_kwargs, - ) -> Union[BeamSearchOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **beam search decoding** and - can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - beam_scorer (`BeamScorer`): - An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... BeamSearchScorer, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - - >>> # lets run beam search using 3 beams - >>> num_beams = 3 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> # instantiate beam scorer - >>> beam_scorer = BeamSearchScorer( - ... batch_size=1, - ... num_beams=num_beams, - ... device=model.device, - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ... ] - ... ) - - >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt bist du?'] - ```""" - # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) - if max_length is not None: - warnings.warn( - ( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" - " instead." - ), - UserWarning, - ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) - if len(stopping_criteria) == 0: - warnings.warn( - ( - "You don't have defined any stopping_criteria, this will likely" - " loop forever" - ), - UserWarning, - ) - pad_token_id = ( - pad_token_id if pad_token_id is not None else self.config.pad_token_id - ) - eos_token_id = ( - eos_token_id if eos_token_id is not None else self.config.eos_token_id - ) - output_scores = ( - output_scores if output_scores is not None else self.config.output_scores - ) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.config.return_dict_in_generate - ) - - batch_size = len(beam_scorer._beam_hyps) - num_beams = beam_scorer.num_beams - - batch_beam_size, cur_len = input_ids.shape - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}," - f" but is {batch_beam_size}." - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - beam_indices = ( - tuple(() for _ in range(batch_beam_size)) - if (return_dict_in_generate and output_scores) - else None - ) - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None - ) - - beam_scores = torch.zeros( - (batch_size, num_beams), dtype=torch.float, device=input_ids.device - ) - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False # used by synced_gpus only - while True: - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - - next_token_logits = outputs.logits[:, -1, :] - next_token_logits_raw = next_token_logits.clone() - - # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` - # cannot be generated both before and after the `nn.functional.log_softmax` operation. - next_token_logits = self.adjust_logits_during_generation( - next_token_logits, cur_len=cur_len - ) - next_token_scores = nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * num_beams, vocab_size) - - next_token_scores_processed = logits_processor( - input_ids, next_token_scores, model_inputs=model_inputs - ) - next_token_scores = next_token_scores_processed + beam_scores[ - :, None - ].expand_as(next_token_scores) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_token_logits_raw,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view( - batch_size, num_beams * vocab_size - ) - - next_token_scores, next_tokens = torch.topk( - next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True - ) - - next_indices = torch_int_div(next_tokens, vocab_size) - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - ) - - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids = torch.cat( - [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1 - ) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - if model_kwargs["past"] is not None: - model_kwargs["past"] = self._reorder_cache( - model_kwargs["past"], beam_idx - ) - - if return_dict_in_generate and output_scores: - beam_indices = tuple( - ( - beam_indices[beam_idx[i]] + (beam_idx[i],) - for i in range(len(beam_indices)) - ) - ) - - # increase cur_len - cur_len = cur_len + 1 - - if beam_scorer.is_done or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: - this_peer_finished = True - - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - else: - num_return_sequences = beam_scorer.num_beam_hyps_to_keep - # return only as many indices as sequences - beam_indices = tuple( - ( - beam_indices[ - i * num_beams : i * num_beams + num_return_sequences - ] - for i in range(batch_size) - ) - ) - beam_indices = sum(beam_indices, ()) - - step_wise_raw_logits = self.compute_beam_search_raw_logits( - sequence_outputs["sequences"].clone(), - scores, - beam_indices, - eos_token_id, - ) - - if self.config.is_encoder_decoder: - return BeamSearchEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=step_wise_raw_logits, # raw logits - beam_indices=beam_indices, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return BeamSearchDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - beam_indices=beam_indices, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return sequence_outputs["sequences"] - - def beam_sample( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - **model_kwargs, - ) -> Union[BeamSampleOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **beam search multinomial - sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - beam_scorer (`BeamScorer`): - A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... TopKLogitsWarper, - ... TemperatureLogitsWarper, - ... BeamSearchScorer, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - >>> # lets run beam search using 3 beams - >>> num_beams = 3 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> # instantiate beam scorer - >>> beam_scorer = BeamSearchScorer( - ... batch_size=1, - ... max_length=model.config.max_length, - ... num_beams=num_beams, - ... device=model.device, - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)] - ... ) - >>> # instantiate logits processors - >>> logits_warper = LogitsProcessorList( - ... [ - ... TopKLogitsWarper(50), - ... TemperatureLogitsWarper(0.7), - ... ] - ... ) - - >>> outputs = model.beam_sample( - ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt bist du?'] - ```""" - # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) - if max_length is not None: - warnings.warn( - ( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" - " instead." - ), - UserWarning, - ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) - pad_token_id = ( - pad_token_id if pad_token_id is not None else self.config.pad_token_id - ) - eos_token_id = ( - eos_token_id if eos_token_id is not None else self.config.eos_token_id - ) - output_scores = ( - output_scores if output_scores is not None else self.config.output_scores - ) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.config.return_dict_in_generate - ) - - batch_size = len(beam_scorer._beam_hyps) - num_beams = beam_scorer.num_beams - - batch_beam_size, cur_len = input_ids.shape - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - beam_indices = ( - tuple(() for _ in range(batch_beam_size)) - if (return_dict_in_generate and output_scores) - else None - ) - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None - ) - - beam_scores = torch.zeros( - (batch_size, num_beams), dtype=torch.float, device=input_ids.device - ) - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False # used by synced_gpus only - while True: - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - - next_token_logits_raw = outputs.logits[:, -1, :] - - # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` - # cannot be generated both before and after the `nn.functional.log_softmax` operation. - next_token_logits = self.adjust_logits_during_generation( - next_token_logits_raw, cur_len=cur_len - ) - next_token_scores = nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * num_beams, vocab_size) - - next_token_scores_processed = logits_processor( - input_ids, next_token_logits, model_inputs=model_inputs - ) - next_token_scores = next_token_scores_processed + beam_scores[ - :, None - ].expand_as(next_token_scores) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - # return raw scores instead of post-processed - scores += ((next_token_logits_raw, next_token_scores),) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view( - batch_size, num_beams * vocab_size - ) - - probs = nn.functional.softmax(next_token_scores, dim=-1) - - next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) - next_token_scores = torch.gather(next_token_scores, -1, next_tokens) - - next_token_scores, _indices = torch.sort( - next_token_scores, descending=True, dim=1 - ) - next_tokens = torch.gather(next_tokens, -1, _indices) - - next_indices = torch_int_div(next_tokens, vocab_size) - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - ) - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids = torch.cat( - [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1 - ) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - if model_kwargs["past"] is not None: - model_kwargs["past"] = self._reorder_cache( - model_kwargs["past"], beam_idx - ) - - if return_dict_in_generate and output_scores: - beam_indices = tuple( - ( - beam_indices[beam_idx[i]] + (beam_idx[i],) - for i in range(len(beam_indices)) - ) - ) - - # increase cur_len - cur_len = cur_len + 1 - - if beam_scorer.is_done or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: - this_peer_finished = True - - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - else: - num_return_sequences = beam_scorer.num_beam_hyps_to_keep - # return only as many indices as sequences - beam_indices = tuple( - ( - beam_indices[ - i * num_beams : i * num_beams + num_return_sequences - ] - for i in range(batch_size) - ) - ) - beam_indices = sum(beam_indices, ()) - - if self.config.is_encoder_decoder: - return BeamSampleEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - beam_indices=beam_indices, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return BeamSampleDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - beam_indices=beam_indices, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return sequence_outputs["sequences"] - - def group_beam_search( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - **model_kwargs, - ): - r""" - Generates sequences of token ids for models with a language modeling head using **diverse beam search - decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - beam_scorer (`BeamScorer`): - An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - - model_kwargs: - Additional model specific kwargs that will be forwarded to the `forward` function of the model. If - model is an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.BeamSearchDecoderOnlyOutput`] if [`~generation_utils.BeamSearchDecoderOnlyOutput`] if - `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a - [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... HammingDiversityLogitsProcessor, - ... BeamSearchScorer, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - - >>> # lets run diverse beam search using 6 beams - >>> num_beams = 6 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> # instantiate beam scorer - >>> beam_scorer = BeamSearchScorer( - ... batch_size=1, - ... max_length=model.config.max_length, - ... num_beams=num_beams, - ... device=model.device, - ... num_beam_groups=3, - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3), - ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ... ] - ... ) - - >>> outputs = model.group_beam_search( - ... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt bist du?'] - ```""" - # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) - if max_length is not None: - warnings.warn( - ( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" - " instead." - ), - UserWarning, - ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) - pad_token_id = ( - pad_token_id if pad_token_id is not None else self.config.pad_token_id - ) - eos_token_id = ( - eos_token_id if eos_token_id is not None else self.config.eos_token_id - ) - output_scores = ( - output_scores if output_scores is not None else self.config.output_scores - ) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.config.return_dict_in_generate - ) - - batch_size = len(beam_scorer._beam_hyps) - num_beams = beam_scorer.num_beams - num_beam_groups = beam_scorer.num_beam_groups - num_sub_beams = num_beams // num_beam_groups - device = input_ids.device - - batch_beam_size, cur_len = input_ids.shape - - if return_dict_in_generate and output_scores: - beam_indices = [ - tuple(() for _ in range(num_sub_beams * batch_size)) - for _ in range(num_beam_groups) - ] - else: - beam_indices = None - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}," - f" but is {batch_beam_size}." - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None - ) - - beam_scores = torch.full( - (batch_size, num_beams), -1e9, dtype=torch.float, device=device - ) - # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in - # the same group don't produce same tokens everytime. - beam_scores[:, ::num_sub_beams] = 0 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False # used by synced_gpus only - while True: - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - # predicted tokens in cur_len step - current_tokens = torch.zeros( - batch_size * num_beams, dtype=input_ids.dtype, device=device - ) - - # indices which will form the beams in the next time step - reordering_indices = torch.zeros( - batch_size * num_beams, dtype=torch.long, device=device - ) - - # do one decoder step on all beams of all sentences in batch - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - - if output_scores: - processed_score = torch.zeros_like(outputs.logits[:, -1, :]) - - for beam_group_idx in range(num_beam_groups): - group_start_idx = beam_group_idx * num_sub_beams - group_end_idx = min(group_start_idx + num_sub_beams, num_beams) - group_size = group_end_idx - group_start_idx - - # indices of beams of current group among all sentences in batch - batch_group_indices = [] - - for batch_idx in range(batch_size): - batch_group_indices.extend( - [ - batch_idx * num_beams + idx - for idx in range(group_start_idx, group_end_idx) - ] - ) - group_input_ids = input_ids[batch_group_indices] - - # select outputs of beams of current group only - next_token_logits_raw = outputs.logits[batch_group_indices, -1, :] - - # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` - # cannot be generated both before and after the `nn.functional.log_softmax` operation. - next_token_logits = self.adjust_logits_during_generation( - next_token_logits_raw, cur_len=cur_len - ) - next_token_scores = nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * group_size, vocab_size) - vocab_size = next_token_scores.shape[-1] - - next_token_scores_processed = logits_processor( - group_input_ids, - next_token_scores, - current_tokens=current_tokens, - beam_group_idx=beam_group_idx, - model_inputs=model_inputs, - ) - next_token_scores = next_token_scores_processed + beam_scores[ - batch_group_indices - ].unsqueeze(-1) - next_token_scores = next_token_scores.expand_as( - next_token_scores_processed - ) - - if output_scores: - processed_score[batch_group_indices] = next_token_logits_raw - - # reshape for beam search - next_token_scores = next_token_scores.view( - batch_size, group_size * vocab_size - ) - - next_token_scores, next_tokens = torch.topk( - next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True - ) - - next_indices = torch_int_div(next_tokens, vocab_size) - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = beam_scorer.process( - group_input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - ) - beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - if return_dict_in_generate and output_scores: - beam_indices[beam_group_idx] = tuple( - beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) - for i in range(len(beam_indices[0])) - ) - - input_ids[batch_group_indices] = group_input_ids[beam_idx] - group_input_ids = torch.cat( - [group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], - dim=-1, - ) - current_tokens[batch_group_indices] = group_input_ids[:, -1] - - # (beam_idx // group_size) -> batch_idx - # (beam_idx % group_size) -> offset of idx inside the group - reordering_indices[batch_group_indices] = ( - num_beams * torch_int_div(beam_idx, group_size) - + group_start_idx - + (beam_idx % group_size) - ) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (processed_score,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - if model_kwargs["past"] is not None: - model_kwargs["past"] = self._reorder_cache( - model_kwargs["past"], reordering_indices - ) - - # increase cur_len - cur_len = cur_len + 1 - - if beam_scorer.is_done or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: - this_peer_finished = True - - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - else: - beam_indices = sum(beam_indices, ()) - num_return_sequences = beam_scorer.num_beam_hyps_to_keep - # return only as many indices as sequences - beam_indices = tuple( - ( - beam_indices[ - i * num_beams : i * num_beams + num_return_sequences - ] - for i in range(batch_size) - ) - ) - beam_indices = sum(beam_indices, ()) - - if self.config.is_encoder_decoder: - return BeamSearchEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - beam_indices=beam_indices, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return BeamSearchDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return sequence_outputs["sequences"] - - def constrained_beam_search( - self, - input_ids: torch.LongTensor, - constrained_beam_scorer: ConstrainedBeamSearchScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = None, - **model_kwargs, - ) -> Union[BeamSearchOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **constrained beam search - decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - constrained_beam_scorer (`ConstrainedBeamSearchScorer`): - A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation, while satisfying a list of positive constraints. For more information, the - documentation of [`ConstrainedBeamSearchScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`int`, *optional*): - The id of the *end-of-sequence* token. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. - output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... ConstrainedBeamSearchScorer, - ... PhrasalConstraint, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - - >>> # lets run beam search using 3 beams - >>> num_beams = 3 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> constraint_str = "Sie" - >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # slice to remove eos token - >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)] - - - >>> # instantiate beam scorer - >>> beam_scorer = ConstrainedBeamSearchScorer( - ... batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ... ] - ... ) - - >>> outputs = model.constrained_beam_search( - ... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs - ... ) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt sind Sie?'] - ```""" - # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) - if max_length is not None: - warnings.warn( - ( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`" - " instead." - ), - UserWarning, - ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) - if len(stopping_criteria) == 0: - warnings.warn( - ( - "You don't have defined any stopping_criteria, this will likely" - " loop forever" - ), - UserWarning, - ) - pad_token_id = ( - pad_token_id if pad_token_id is not None else self.config.pad_token_id - ) - eos_token_id = ( - eos_token_id if eos_token_id is not None else self.config.eos_token_id - ) - output_scores = ( - output_scores if output_scores is not None else self.config.output_scores - ) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.config.return_dict_in_generate - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None - ) - - batch_size = len(constrained_beam_scorer._beam_hyps) - num_beams = constrained_beam_scorer.num_beams - - batch_beam_size, cur_len = input_ids.shape - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}," - f" but is {batch_beam_size}." - ) - - beam_scores = torch.zeros( - (batch_size, num_beams), dtype=torch.float, device=input_ids.device - ) - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False # used by synced_gpus only - while True: - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - - next_token_logits_raw = outputs.logits[:, -1, :] - # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` - # cannot be generated both before and after the `nn.functional.log_softmax` operation. - next_token_logits = self.adjust_logits_during_generation( - next_token_logits_raw, cur_len=cur_len - ) - next_token_scores = nn.functional.log_softmax( - next_token_logits, dim=-1 - ) # (batch_size * num_beams, vocab_size) - - next_token_scores_processed = logits_processor( - input_ids, next_token_scores, model_inputs=model_inputs - ) - - scores_for_all_vocab = next_token_scores_processed.clone() - - next_token_scores = next_token_scores_processed + beam_scores[ - :, None - ].expand_as(next_token_scores) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += ((next_token_logits_raw, next_token_scores),) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view( - batch_size, num_beams * vocab_size - ) - - next_token_scores, next_tokens = torch.topk( - next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True - ) - - next_indices = (next_tokens / vocab_size).long() - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = constrained_beam_scorer.process( - input_ids, - next_token_scores, - next_tokens, - next_indices, - scores_for_all_vocab, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - ) - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids = torch.cat( - [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1 - ) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - if model_kwargs["past"] is not None: - model_kwargs["past"] = self._reorder_cache( - model_kwargs["past"], beam_idx - ) - - # increase cur_len - cur_len = cur_len + 1 - - if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores): - if not synced_gpus: - break - else: - this_peer_finished = True - - sequence_outputs = constrained_beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - ) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - if self.config.is_encoder_decoder: - return BeamSearchEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return BeamSearchDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return sequence_outputs["sequences"] - - -def top_k_top_p_filtering( - logits: torch.FloatTensor, - top_k: int = 0, - top_p: float = 1.0, - filter_value: float = -float("Inf"), - min_tokens_to_keep: int = 1, -) -> torch.FloatTensor: - """ - Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - - Args: - logits: logits distribution shape (batch size, vocabulary size) - top_k (`int`, *optional*, defaults to 0): - If > 0, only keep the top k tokens with highest probability (top-k filtering) - top_p (`float`, *optional*, defaults to 1.0): - If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus - filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimumber of tokens we keep per batch example in the output. - - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 - """ - if top_k > 0: - logits = TopKLogitsWarper( - top_k=top_k, - filter_value=filter_value, - min_tokens_to_keep=min_tokens_to_keep, - )(None, logits) - - if 0 <= top_p <= 1.0: - logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)( - None, logits - ) - - return logits - - -def override_generation_routines(cls): - bases = list(cls.__bases__) - for base_ix in range(len(bases)): - if bases[base_ix] == GenerationMixin: - bases[base_ix] = GenerationMixinWithRawScores - - # recursively look up - if bases[base_ix] != object: - bases[base_ix] = override_generation_routines(bases[base_ix]) - - cls.__bases__ = tuple(bases) - return cls - - -def unwrap_generation_routines(cls): - bases = list(cls.__bases__) - for base_ix in range(len(bases)): - if bases[base_ix] == GenerationMixinWithRawScores: - bases[base_ix] = GenerationMixin - - # recursively look up - if bases[base_ix] != object: - bases[base_ix] = unwrap_generation_routines(bases[base_ix]) - - cls.__bases__ = tuple(bases) - return cls diff --git a/openrl/modules/networks/value_network.py b/openrl/modules/networks/value_network.py index 187eb465..bce574c5 100644 --- a/openrl/modules/networks/value_network.py +++ b/openrl/modules/networks/value_network.py @@ -49,6 +49,7 @@ def __init__( self._use_recurrent_policy = cfg.use_recurrent_policy self._use_influence_policy = cfg.use_influence_policy self._use_popart = cfg.use_popart + self._use_fp16 = cfg.use_fp16 and cfg.use_deepspeed self._influence_layer_N = cfg.influence_layer_N self._recurrent_N = cfg.recurrent_N self.tpdv = dict(dtype=torch.float32, device=device) @@ -118,6 +119,9 @@ def forward(self, critic_obs, rnn_states, masks): rnn_states = check(rnn_states).to(**self.tpdv) masks = check(masks).to(**self.tpdv) + if self._use_fp16: + critic_obs = critic_obs.half() + critic_features = self.base(critic_obs) if self._use_naive_recurrent_policy or self._use_recurrent_policy: diff --git a/openrl/modules/networks/value_network_gpt.py b/openrl/modules/networks/value_network_gpt.py new file mode 100644 index 00000000..0c5b1154 --- /dev/null +++ b/openrl/modules/networks/value_network_gpt.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2021 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +from typing import Any, Dict, Optional + +import numpy as np +import torch +import torch.nn as nn +from transformers.modeling_utils import unwrap_model + +from openrl.buffers.utils.util import get_critic_obs_space +from openrl.modules.networks.base_value_network import BaseValueNetwork +from openrl.modules.networks.utils.cnn import CNNBase +from openrl.modules.networks.utils.mix import MIXBase +from openrl.modules.networks.utils.mlp import MLPBase, MLPLayer +from openrl.modules.networks.utils.popart import PopArt +from openrl.modules.networks.utils.rnn import RNNLayer +from openrl.modules.networks.utils.util import init +from openrl.modules.utils.valuenorm import ValueNorm +from openrl.utils.util import check_v2 as check + + +class ValueNetworkGPT(BaseValueNetwork): + def __init__( + self, + cfg, + input_space, + action_space=None, + use_half=False, + device=torch.device("cpu"), + extra_args=None, + ): + self.device = device + + self.use_fp16 = cfg.use_fp16 + self.use_deepspeed = cfg.use_deepspeed + self.use_half = False + self.use_data_parallel = not cfg.use_deepspeed + self.use_model_parallel = False + assert not (self.use_deepspeed and self.use_data_parallel) + assert not (self.use_deepspeed and self.use_model_parallel) + assert not (self.use_data_parallel and self.use_model_parallel) + + super(ValueNetworkGPT, self).__init__(cfg, device) + + from transformers import AutoModelForCausalLM + + self._value_model = AutoModelForCausalLM.from_pretrained(cfg.model_path) + self._value_model.config.use_cache = False + self._value_head = nn.Linear( + self._value_model.config.n_embd, + 1, + bias=False, # gpt2 + # self._value_model.config.word_embed_proj_dim, 1, bias=False # opt-x + ) + self.value_normalizer = ( + ValueNorm(1, device=device) if self._use_valuenorm else None + ) + + if self.use_deepspeed: + self._value_head.to(self.device) + else: + if self.use_model_parallel: + self._value_model.parallelize() + elif self.use_data_parallel: + if self.use_half: + self._value_model = self._value_model.half() + self._value_head = self._value_head.half() + self._value_model = torch.nn.DataParallel(self._value_model) + self._value_model = self._value_model.to(self.device) + self._value_head = torch.nn.DataParallel(self._value_head) + self._value_head = self._value_head.to(self.device) + + def _prepare_inputs_for_model( + self, + model: Any, + input_ids: torch.tensor, + model_kwargs: Optional[Dict[str, torch.tensor]] = None, + ): + model_inputs = unwrap_model(model).prepare_inputs_for_generation( + input_ids, **model_kwargs + ) + + if self.use_model_parallel: + model_inputs = { + key: ( + value.to(model.transformer.first_device) + if isinstance(value, torch.Tensor) + and hasattr(model.transformer, "first_device") + else value + ) + for key, value in model_inputs.items() + } + + return model_inputs + + def forward(self, critic_obs, rnn_states, masks): + for key in critic_obs.keys(): + critic_obs[key] = ( + torch.from_numpy(critic_obs[key]) + if type(critic_obs[key]) == np.ndarray + else critic_obs[key] + ) + if self.use_data_parallel: + critic_obs[key] = critic_obs[key].to(self.device) + else: + critic_obs[key] = critic_obs[key].to(self._value_model.device) + + rnn_states = check(rnn_states) + + if self.use_half: + input_ids = critic_obs["input_encoded_pt"].int() + attention_mask = critic_obs["input_attention_mask_pt"].int() + else: + input_ids = critic_obs["input_encoded_pt"].long() + attention_mask = critic_obs["input_attention_mask_pt"].long() + + past_model_kwargs = None + if not past_model_kwargs: + past_model_kwargs = { + "attention_mask": attention_mask, + } + + model_inputs = self._prepare_inputs_for_model( + self._value_model, input_ids, past_model_kwargs + ) + output = self._value_model(output_hidden_states=True, **model_inputs) + last_tokens_hidden = output.hidden_states[-1][:, -1] + + if self.use_model_parallel: + last_tokens_hidden = last_tokens_hidden.to(self.device) + + values = self._value_head.forward(last_tokens_hidden) + + return values, rnn_states diff --git a/openrl/modules/rl_module.py b/openrl/modules/rl_module.py index 430e60d6..7b7e390e 100644 --- a/openrl/modules/rl_module.py +++ b/openrl/modules/rl_module.py @@ -55,6 +55,8 @@ def __init__( self.rank = rank self.world_size = world_size + self.use_deepspeed = cfg.use_deepspeed + use_half_actor = self.program_type == "actor" and cfg.use_half_actor if model_configs is None: @@ -70,18 +72,57 @@ def __init__( use_half=use_half_actor, extra_args=model_cg["extra_args"] if "extra_args" in model_cg else None, ) - self.models.update({model_key: model}) if self.program_type == "actor": continue - optimizer = torch.optim.Adam( - model.parameters(), - lr=model_cg["lr"], - eps=cfg.opti_eps, - weight_decay=cfg.weight_decay, - ) - self.optimizers.update({model_key: optimizer}) + if not self.use_deepspeed: + optimizer = torch.optim.Adam( + model.parameters(), + lr=model_cg["lr"], + eps=cfg.opti_eps, + weight_decay=cfg.weight_decay, + ) + self.models.update({model_key: model}) + self.optimizers.update({model_key: optimizer}) + else: + import json + + import deepspeed + from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam + from transformers import get_constant_schedule + + self.use_fp16 = cfg.use_fp16 + self.use_offload = cfg.use_offload + + # Check for inconsistencies in configuration files + assert not (self.use_fp16 and not self.use_deepspeed) + assert not (self.use_offload and not self.use_deepspeed) + assert cfg.deepspeed_config is not None + with open(cfg.deepspeed_config) as file: + ds_config = json.load(file) + if "fp16" in ds_config: + assert ds_config["fp16"]["enabled"] == self.use_fp16 + + AdamOptimizer = DeepSpeedCPUAdam if self.use_offload else FusedAdam + optim_params = filter(lambda p: p.requires_grad, model.parameters()) + optim = AdamOptimizer( + optim_params, lr=model_cg["lr"], betas=(0.9, 0.95) + ) + + # LR Scheduler + lr_scheduler = get_constant_schedule( + optimizer=optim, + ) + + engine, *_ = deepspeed.initialize( + args=cfg, + model=model, + optimizer=optim, + lr_scheduler=lr_scheduler, + ) + self.models.update({model_key: engine}) + self.optimizers.update({model_key: engine}) if cfg.use_amp: self.scaler = torch.cuda.amp.GradScaler() diff --git a/openrl/modules/utils/valuenorm.py b/openrl/modules/utils/valuenorm.py index bed1d705..43aaad9c 100644 --- a/openrl/modules/utils/valuenorm.py +++ b/openrl/modules/utils/valuenorm.py @@ -24,15 +24,21 @@ def __init__( self.per_element_update = per_element_update self.tpdv = dict(dtype=torch.float32, device=device) - # self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) - # self.running_mean_sq = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) - # self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(**self.tpdv) - - self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False) + self.running_mean = nn.Parameter( + torch.zeros(input_shape), requires_grad=False + ).to(**self.tpdv) self.running_mean_sq = nn.Parameter( torch.zeros(input_shape), requires_grad=False + ).to(**self.tpdv) + self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to( + **self.tpdv ) - self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False) + + # self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False) + # self.running_mean_sq = nn.Parameter( + # torch.zeros(input_shape), requires_grad=False + # ) + # self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False) self.reset_parameters() diff --git a/openrl/modules/vdn_module.py b/openrl/modules/vdn_module.py index 32987372..10a9b541 100644 --- a/openrl/modules/vdn_module.py +++ b/openrl/modules/vdn_module.py @@ -68,6 +68,8 @@ def __init__( device=device, ) self.cfg = cfg + self.obs_space = input_space + self.act_space = act_space def lr_decay(self, episode, episodes): update_linear_schedule(self.optimizers["q_net"], episode, episodes, self.lr) diff --git a/openrl/rewards/nlp_reward.py b/openrl/rewards/nlp_reward.py index c653c7c8..38cd306a 100644 --- a/openrl/rewards/nlp_reward.py +++ b/openrl/rewards/nlp_reward.py @@ -10,20 +10,34 @@ class NLPReward(BaseReward): - def __init__(self, env: Env, ref_model: str, intent_model: str): + def __init__( + self, + env: Env, + ref_model: str, + intent_model: str, + use_deepspeed: bool = False, + ref_ds_config: str = "default", + intent_ds_config: str = "default", + ): self.rew_infos = [] self.env_infos = [] - meteor_config = { - "meteor_coeff": 0.5, - } - self.inner_rew_funcs = { - "meteor": Meteor(**meteor_config), - } + # bug unfixed + self.inner_rew_funcs = dict() + + # meteor_config = { + # "meteor_coeff": 0.5, + # "test": ref_model == "builtin_ref", + # } + # self.inner_rew_funcs = { + # "meteor": Meteor(**meteor_config), + # } kl_config = { "action_space": env.action_space, "ref_model": ref_model, + "use_deepspeed": use_deepspeed, + "ds_config": ref_ds_config, } self.step_rew_funcs = { "kl_pen": KLPenalty(**kl_config), @@ -32,6 +46,8 @@ def __init__(self, env: Env, ref_model: str, intent_model: str): intent_config = { "intent_model": intent_model, "intent_coeff": 0.5, + "use_deepspeed": use_deepspeed, + "ds_config": intent_ds_config, } self.batch_rew_funcs = { "intent_acc": Intent(**intent_config), diff --git a/openrl/runners/common/ppo_agent.py b/openrl/runners/common/ppo_agent.py index ad7d0a84..414ff409 100644 --- a/openrl/runners/common/ppo_agent.py +++ b/openrl/runners/common/ppo_agent.py @@ -136,6 +136,7 @@ def act( observation: Union[np.ndarray, Dict[str, np.ndarray]], info: Optional[List[Dict[str, Any]]] = None, deterministic: bool = True, + episode_starts: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: assert self.net is not None, "net is None" observation = ObsData.prepare_input(observation) @@ -149,6 +150,7 @@ def act( observation, action_masks=action_masks, deterministic=deterministic, + episode_starts=episode_starts, ) action = np.array(np.split(_t2n(action), self.env_num)) diff --git a/openrl/selfplay/callbacks/selfplay_api.py b/openrl/selfplay/callbacks/selfplay_api.py index 3d148749..e2214ecb 100644 --- a/openrl/selfplay/callbacks/selfplay_api.py +++ b/openrl/selfplay/callbacks/selfplay_api.py @@ -50,14 +50,17 @@ def _init_callback(self) -> None: ) self.bind = SelfplayAPIServer.bind() - serve.run(self.bind) + serve.run(self.bind, route_prefix="/selfplay") success = False try_time = 10 while not success: success = self.api_client.set_sample_strategy(self.sample_strategy) try_time -= 1 if try_time <= 0: - raise RuntimeError("Failed to set sample strategy.") + raise RuntimeError( + f"Failed to set sample strategy: {self.sample_strategy}. host:" + f" {self.host}, port: {self.port}" + ) def _on_step(self) -> bool: # print("To send request to API server.") @@ -72,5 +75,6 @@ def _on_training_end(self) -> None: print(f"deleting {application_name}") serve.delete(application_name) del self.bind + serve.shutdown() if self.verbose >= 2: print(f"delete {application_name} done!") diff --git a/openrl/selfplay/opponents/random_opponent.py b/openrl/selfplay/opponents/random_opponent.py index 1f396c34..501d571a 100644 --- a/openrl/selfplay/opponents/random_opponent.py +++ b/openrl/selfplay/opponents/random_opponent.py @@ -47,11 +47,20 @@ def _sample_random_action( action = [] for obs, space in zip(observation, action_space): - mask = obs.get("action_mask", None) - action.append(space.sample(mask)) + if termination or truncation: + action.append(None) + else: + if isinstance(obs, dict): + mask = obs.get("action_mask", None) + else: + mask = None + action.append(space.sample(mask)) else: - mask = observation.get("action_mask", None) - action = action_space.sample(mask) + if termination or truncation: + action = None + else: + mask = observation.get("action_mask", None) + action = action_space.sample(mask) return action def _load(self, opponent_path: Union[str, Path]): diff --git a/openrl/selfplay/opponents/utils.py b/openrl/selfplay/opponents/utils.py index d1d983d5..42ddbb2b 100644 --- a/openrl/selfplay/opponents/utils.py +++ b/openrl/selfplay/opponents/utils.py @@ -28,6 +28,9 @@ def check_opponent_template(opponent_template: Union[str, Path]): + assert isinstance(opponent_template, Path) or isinstance( + opponent_template, str + ), f"opponent_template {opponent_template} must be a Path or str" if isinstance(opponent_template, str): opponent_template = Path(opponent_template) assert ( diff --git a/openrl/selfplay/selfplay_api/opponent_model.py b/openrl/selfplay/selfplay_api/opponent_model.py index af9ec6b9..b836519b 100644 --- a/openrl/selfplay/selfplay_api/opponent_model.py +++ b/openrl/selfplay/selfplay_api/opponent_model.py @@ -49,7 +49,7 @@ def get_battle_info(self) -> Dict[str, Any]: result = {} result["win_rate"] = float(self.num_wins) / max(self.num_games, 1) result["draw_rate"] = float(self.num_draws) / max(self.num_games, 1) - result["loss_rate"] = float(self.num_losses) / max(self.num_games, 1) + result["lose_rate"] = float(self.num_losses) / max(self.num_games, 1) result["total_games"] = self.num_games return result diff --git a/openrl/selfplay/selfplay_api/selfplay_api.py b/openrl/selfplay/selfplay_api/selfplay_api.py index 2c346b46..307c4fcc 100644 --- a/openrl/selfplay/selfplay_api/selfplay_api.py +++ b/openrl/selfplay/selfplay_api/selfplay_api.py @@ -33,7 +33,7 @@ from openrl.selfplay.selfplay_api.opponent_model import BattleResult -@serve.deployment(route_prefix="/selfplay") +@serve.deployment() @serve.ingress(app) class SelfplayAPIServer(BaseSelfplayAPIServer): @app.post("/set_sample_strategy") diff --git a/openrl/selfplay/strategies/__init__.py b/openrl/selfplay/strategies/__init__.py deleted file mode 100644 index 2908f8b4..00000000 --- a/openrl/selfplay/strategies/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright 2023 The OpenRL Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""""" -from openrl.selfplay.strategies.strategies import ( - NaiveSelfplayStrategy, - OnlyLatestSelfplayStrategy, - VarExistEnemySelfplayStrategy, - WeightExistEnemySelfplayStrategy, - WeightSelfplayStrategy, - WinRateSelfplayStrategy, -) - - -def make_strategy(strategy_name): - if strategy_name == "Naive": - selfplay_strategy = NaiveSelfplayStrategy - elif strategy_name == "OnlyLatest": - selfplay_strategy = OnlyLatestSelfplayStrategy - elif strategy_name == "Weight": - selfplay_strategy = WeightSelfplayStrategy - elif strategy_name == "WinRate": - selfplay_strategy = WinRateSelfplayStrategy - elif strategy_name == "VarExistEnemy": - selfplay_strategy = VarExistEnemySelfplayStrategy - elif strategy_name == "WeightExistEnemy": - selfplay_strategy = WeightExistEnemySelfplayStrategy - return selfplay_strategy diff --git a/openrl/selfplay/strategies/base_strategy.py b/openrl/selfplay/strategies/base_strategy.py deleted file mode 100644 index 4e280b13..00000000 --- a/openrl/selfplay/strategies/base_strategy.py +++ /dev/null @@ -1,39 +0,0 @@ -from abc import abstractmethod - - -class BaseSelfplayStrategy: - @abstractmethod - def __init__(self, all_args, nenvs, exist_enemy_num): - raise NotImplementedError - - @abstractmethod - def getcnt(self): - raise NotImplementedError - - @abstractmethod - def update_enemy_ids(self, new_enemy_ids): - raise NotImplementedError - - @abstractmethod - def restore(self, model_dir): - raise NotImplementedError - - @abstractmethod - def get_qlist(self): - raise NotImplementedError - - @abstractmethod - def update_weight(self, enemy_loses): - raise NotImplementedError - - @abstractmethod - def update_win_rate(self, dones, enemy_wins): - raise NotImplementedError - - @abstractmethod - def push_newone(self): - raise NotImplementedError - - @abstractmethod - def get_plist(self): - raise NotImplementedError diff --git a/openrl/selfplay/strategies/strategies.py b/openrl/selfplay/strategies/strategies.py deleted file mode 100644 index 28e492ec..00000000 --- a/openrl/selfplay/strategies/strategies.py +++ /dev/null @@ -1,413 +0,0 @@ -import json - -import numpy as np - -from openrl.selfplay.strategies.base_strategy import BaseSelfplayStrategy - - -class SelfplayStrategy(BaseSelfplayStrategy): - def __init__(self, all_args, nenvs, exist_enemy_num): - # qlist和history_cnt的数据结构 - self.all_args = all_args - self.qlist = [] - self.history_cnt = 0 - self.enemy_ids = [0] * nenvs - self.length = nenvs - - def getcnt(self): - return self.history_cnt - - def update_enemy_ids(self, new_enemy_ids): - self.enemy_ids = new_enemy_ids - - def restore(self, model_dir): - with open(model_dir + "/enemy_history_info.json") as f_obj: - enemy_info = json.load(f_obj) - self.qlist = enemy_info["qlist"] - self.history_cnt = enemy_info["history_cnt"] - - def get_qlist(self): - return self.qlist - - def update_weight(self, enemy_loses): - pass - - def update_win_rate(self, dones, enemy_wins): - pass - - def push_newone(self): - pass - - -class RatioSelfplayStrategy(SelfplayStrategy): - def __init__(self, all_args, nenvs, exist_enemy_num): - super(RatioSelfplayStrategy, self).__init__(all_args, nenvs) - - def push_newone(self): - self.history_cnt += 1 - - def get_plist(self): - if self.history_cnt == 1: - return [1] - temp_plist = np.logspace( - 0, self.history_cnt - 1, self.history_cnt, endpoint=True, base=1.5 - ) - temp_plist[-1] = sum(temp_plist[:-1]) * 4 - temp_plist /= sum(temp_plist) - return temp_plist - - -class NaiveSelfplayStrategy(SelfplayStrategy): - def __init__(self, all_args, nenvs, exist_enemy_num): - super(NaiveSelfplayStrategy, self).__init__(all_args, nenvs, exist_enemy_num) - - def push_newone(self): - self.history_cnt += 1 - - def get_plist(self): - return [1] * (self.history_cnt - 1) + [4 * (self.history_cnt - 1)] - - def save_new_one(self): - return True - - -class OnlyLatestSelfplayStrategy(SelfplayStrategy): - def __init__(self, all_args, nenvs, exist_enemy_num): - super(OnlyLatestSelfplayStrategy, self).__init__( - all_args, nenvs, exist_enemy_num - ) - self.play_list = [] - self.max_play_num = all_args.max_play_num - self.least_win_rate = all_args.least_win_rate - - def push_newone(self): - self.play_list.append([]) - self.history_cnt += 1 - - def get_plist(self): - return [0] * (self.history_cnt - 1) + [1] - - def save_new_one(self, least_win_rate): - if sum(np.array(self.play_list[-1]) == -1) >= least_win_rate * ( - len(self.play_list[-1]) + 1 - ) and len(self.play_list[-1]) >= (self.max_play_num - 10): - return True - - def update_play_list(self, win_enemy_ids, tie_enemy_ids, lose_enemy_ids): - for win_enemy_id in win_enemy_ids: - self.play_list[win_enemy_id].append(1) - for tie_enemy_id in tie_enemy_ids: - self.play_list[tie_enemy_id].append(0) - for lose_enemy_id in lose_enemy_ids: - self.play_list[lose_enemy_id].append(-1) - self.cut_overflow() - - def update_win_rate(self, enemy_wins, enemy_ties, enemy_loses): - win_enemy_ids = np.array(self.enemy_ids)[enemy_wins] - tie_enemy_ids = np.array(self.enemy_ids)[enemy_ties] - lose_enemy_ids = np.array(self.enemy_ids)[enemy_loses] - self.update_play_list(win_enemy_ids, tie_enemy_ids, lose_enemy_ids) - - def cut_overflow(self): - for index in range(len(self.play_list)): - if len(self.play_list[index]) > self.max_play_num: - self.play_list[index] = self.play_list[index][ - (-1) * self.max_play_num : - ] - - def get_info_list(self, info_list): - return_info = [] - for info in info_list: - if info == "win": - equal_num = 1 - elif info == "tie": - equal_num = 0 - elif info == "lose": - equal_num = -1 - num_list = [] - for enemy_play_list in self.play_list: - if info == "play": - num_list.append(len(enemy_play_list)) - else: - num_list.append(int(sum(np.array(enemy_play_list) == equal_num))) - return_info.append(num_list) - return tuple(return_info) - - def get_enemy_play_dict(self): - win_num_list, tie_num_list, lose_num_list, play_num_list = self.get_info_list( - ["win", "tie", "lose", "play"] - ) - return { - "win_num_list": list(win_num_list), - "tie_num_list": list(tie_num_list), - "lose_num_list": list(lose_num_list), - "play_num_list": list(play_num_list), - } - - -class WeightSelfplayStrategy(SelfplayStrategy): - def __init__(self, all_args, nenvs, exist_enemy_num): - super(WeightSelfplayStrategy, self).__init__(all_args, nenvs, exist_enemy_num) - self.recent_weight = 0.8 - self.recent_num = 3 - self.gama = 1 / (nenvs) - - def push_newone(self): - self.history_cnt += 1 - if self.history_cnt <= self.recent_num: - return - elif self.history_cnt == self.recent_num + 1: - self.qlist = [1] - else: - self.qlist.append(max(self.qlist)) - - def get_plist(self): - temp_plist = np.zeros([self.history_cnt]) - temp_plist[: (-1 * self.recent_num)] = ( - np.exp(self.qlist) / sum(np.exp(self.qlist)) * (1 - self.recent_weight) - ) - temp_plist[(-1 * self.recent_num) :] = self.recent_weight / self.recent_num - return temp_plist - - def update_weight(self, enemy_loses): - if self.history_cnt < self.recent_num + 2: - return - lose_enemy_ids = np.array(self.enemy_ids)[ - enemy_loses - ] # 输了的enemy_ids,进行更新,其中可能有重复的enemy_id - for enemy_id in lose_enemy_ids: - if enemy_id <= len(self.qlist) - 1: - divide_num = ( - len(self.qlist) - * np.exp(self.qlist[enemy_id]) - / sum(np.exp(self.qlist)) - ) - next_weight = self.qlist[enemy_id] - self.gama / divide_num - self.qlist[enemy_id] = next_weight - - -class WinRateSelfplayStrategy(SelfplayStrategy): - def __init__(self, all_args, nenvs, exist_enemy_num): - super(WinRateSelfplayStrategy, self).__init__(all_args, nenvs, exist_enemy_num) - self.max_play_num = all_args.max_play_num - self.play_list = ( - [] - ) # 在该list中,每个对手维护一个长度不超过max_play_num的列表,1为该对手获胜, 0为平, -1为我方获胜 - self.recent_list = [] - self.recent_list_max_len = all_args.recent_list_max_len - self.latest_weight = all_args.latest_weight - self.least_win_rate = all_args.least_win_rate - self.stage2_least_win_rate = all_args.least_win_rate - self.stage = 1 - self.newest_pos = all_args.newest_pos - self.newest_weight = all_args.newest_weight - - def push_newone(self): - self.play_list.append([]) - self.history_cnt += 1 - - def get_info_list(self, info_list): - return_info = [] - for info in info_list: - if info == "win": - equal_num = 1 - elif info == "tie": - equal_num = 0 - elif info == "lose": - equal_num = -1 - num_list = [] - for enemy_play_list in self.play_list: - if info == "play": - num_list.append(len(enemy_play_list)) - else: - num_list.append(int(sum(np.array(enemy_play_list) == equal_num))) - return_info.append(num_list) - return tuple(return_info) - - def get_plist(self): - def f_hard(win_rate_list): - p = 1 - return win_rate_list**p - - def f_var(win_rate_list): - return (1 - win_rate_list) * win_rate_list - - win_num_list, tie_num_list, play_num_list = self.get_info_list( - ["win", "tie", "play"] - ) - win_rate_list = ( - np.array(win_num_list) + 0.5 * np.array(tie_num_list) + 0.5 - ) / (np.array(play_num_list) + 1) - return f_hard(win_rate_list) - - def update_play_list(self, win_enemy_ids, tie_enemy_ids, lose_enemy_ids): - if self.stage == 2: - win_enemy_num = (np.array(win_enemy_ids) != self.newest_pos).sum() - tie_enemy_num = (np.array(tie_enemy_ids) != self.newest_pos).sum() - lose_enemy_num = (np.array(lose_enemy_ids) != self.newest_pos).sum() - self.recent_list += ( - [1] * win_enemy_num + [0] * tie_enemy_num + [-1] * lose_enemy_num - ) - for win_enemy_id in win_enemy_ids: - self.play_list[win_enemy_id].append(1) - for tie_enemy_id in tie_enemy_ids: - self.play_list[tie_enemy_id].append(0) - for lose_enemy_id in lose_enemy_ids: - self.play_list[lose_enemy_id].append(-1) - self.cut_overflow() - - def update_win_rate(self, enemy_wins, enemy_ties, enemy_loses): - win_enemy_ids = np.array(self.enemy_ids)[enemy_wins] - tie_enemy_ids = np.array(self.enemy_ids)[enemy_ties] - lose_enemy_ids = np.array(self.enemy_ids)[enemy_loses] - self.update_play_list(win_enemy_ids, tie_enemy_ids, lose_enemy_ids) - - def restore(self, model_dir): - with open(model_dir + "/enemy_history_info.json") as f_obj: - enemy_info = json.load(f_obj) - self.history_cnt = enemy_info["history_cnt"] - self.play_list = enemy_info["play_list"] - - def get_enemy_play_dict(self): - win_num_list, tie_num_list, lose_num_list, play_num_list = self.get_info_list( - ["win", "tie", "lose", "play"] - ) - return { - "win_num_list": list(win_num_list), - "tie_num_list": list(tie_num_list), - "lose_num_list": list(lose_num_list), - "play_num_list": list(play_num_list), - } - - def update_win_info(self, data): - win_enemy_ids, tie_enemy_ids, lose_enemy_ids = ( - data["win_enemy_ids"], - data["tie_enemy_ids"], - data["lose_enemy_ids"], - ) - self.update_play_list(win_enemy_ids, tie_enemy_ids, lose_enemy_ids) - - def cut_overflow(self): - for index in range(len(self.play_list)): - if len(self.play_list[index]) > self.max_play_num: - self.play_list[index] = self.play_list[index][ - (-1) * self.max_play_num : - ] - if len(self.recent_list) > self.recent_list_max_len: - self.recent_list = self.recent_list[(-1) * self.recent_list_max_len :] - - def save_new_one(self, least_win_rate): - if self.stage == 1: - if sum(np.array(self.play_list[-1]) == -1) >= least_win_rate * ( - len(self.play_list[-1]) + 1 - ) and len(self.play_list[-1]) >= (self.max_play_num - 10): - if self.getcnt() - self.all_args.exist_enemy_num == 1: - return True - self.stage = 2 - print("switch to stage 2") - if self.stage == 2: - if sum(np.array(self.recent_list) == -1) >= self.stage2_least_win_rate * ( - len(self.recent_list) + 1 - ) and len(self.recent_list) >= (self.recent_list_max_len - 10): - self.stage = 1 - self.recent_list = [] - return True - return False - - -class ExistEnemySelfplayStrategy(WinRateSelfplayStrategy): - def __init__(self, all_args, nenvs, exist_enemy_num): - super(ExistEnemySelfplayStrategy, self).__init__( - all_args, nenvs, exist_enemy_num - ) - self.all_args = all_args - self.enemy_ids = [0] * nenvs # 第一个step就会更新,所以初始化无所谓 - # 列表的前exist_enemy_num个为已存在的对手 - if exist_enemy_num > 0: - self.play_list = [[]] * exist_enemy_num - self.history_cnt = exist_enemy_num - self.exist_enemy_num = exist_enemy_num - self.max_enemy_num = all_args.max_enemy_num - - def get_final_plist(self, f_hard, f_var): - raise NotImplementedError - - def get_plist(self): - def f_hard(win_rate_list): - p = 2 - return win_rate_list**p - - def f_var(win_rate_list): - return (1 - win_rate_list) * win_rate_list - - plist = self.get_final_plist(f_hard, f_var) - if self.max_enemy_num != -1: - if self.history_cnt - self.exist_enemy_num > self.max_enemy_num: - mask_index = np.array( - list( - range( - self.exist_enemy_num, self.history_cnt - self.max_enemy_num - ) - ) - ) - zero_vec = np.zeros( - self.history_cnt - self.exist_enemy_num - self.max_enemy_num - ) - plist[mask_index] = zero_vec - - return plist - - -class VarExistEnemySelfplayStrategy(ExistEnemySelfplayStrategy): - def __init__(self, all_args, nenvs, exist_enemy_num): - super(VarExistEnemySelfplayStrategy, self).__init__( - all_args, nenvs, exist_enemy_num - ) - - def get_final_plist(self, f_hard, f_var): - win_num_list, tie_num_list, play_num_list = self.get_info_list( - ["win", "tie", "play"] - ) - win_rate_list = ( - np.array(win_num_list) + 0.5 * np.array(tie_num_list) + 0.5 - ) / (np.array(play_num_list) + 1) - win_rate_list = f_var(win_rate_list) - - return win_rate_list - - -class WeightExistEnemySelfplayStrategy(ExistEnemySelfplayStrategy): - def __init__(self, all_args, nenvs, exist_enemy_num): - super(WeightExistEnemySelfplayStrategy, self).__init__( - all_args, nenvs, exist_enemy_num - ) - - def get_final_plist(self, f_hard, f_var): - win_num_list, tie_num_list, play_num_list = self.get_info_list( - ["win", "tie", "play"] - ) - win_rate_list = ( - np.array(win_num_list) + 0.5 * np.array(tie_num_list) + 0.5 - ) / (np.array(play_num_list) + 1) - - if self.stage == 1: - win_rate_list = f_hard(win_rate_list)[:-1] - # if self.newest_pos != -1: - # win_rate_list[self.newest_pos] = 0 - win_rate_list = ( - win_rate_list / (sum(win_rate_list) + 1e-8) * (1 - self.latest_weight) - ) - return list(win_rate_list) + [self.latest_weight] - elif self.stage == 2: - win_rate_list = f_hard(win_rate_list) - if self.newest_pos != -1: - win_rate_list[self.newest_pos] = self.newest_weight - index_without_newest = list(range(self.history_cnt)) - index_without_newest.remove(self.newest_pos) - win_rate_list[index_without_newest] /= sum( - win_rate_list[index_without_newest] - ) - win_rate_list[index_without_newest] *= 1 - self.newest_weight - else: - win_rate_list /= sum(win_rate_list) - return win_rate_list diff --git a/openrl/selfplay/wrappers/base_multiplayer_wrapper.py b/openrl/selfplay/wrappers/base_multiplayer_wrapper.py index a3de3c0f..ca8d1e95 100644 --- a/openrl/selfplay/wrappers/base_multiplayer_wrapper.py +++ b/openrl/selfplay/wrappers/base_multiplayer_wrapper.py @@ -104,6 +104,7 @@ def reset(self, *, seed: Optional[int] = None, **kwargs): action = self.get_opponent_action( player_name, observation, reward, termination, truncation, info ) + self.env.step(action) def on_episode_end( @@ -147,10 +148,18 @@ def _step(self, action): if termination or truncation: return ( copy.copy(self.env.observe(self.self_player)), - self.env.rewards[self.self_player], + ( + self.env.rewards[self.self_player] + if self.self_player in self.env.rewards + else 0 + ), termination, truncation, - self.env.infos[self.self_player], + ( + self.env.infos[self.self_player] + if self.self_player in self.env.rewards + else {} + ), ) else: diff --git a/openrl/utils/callbacks/checkpoint_callback.py b/openrl/utils/callbacks/checkpoint_callback.py index a4b3f5b6..56bf31b8 100644 --- a/openrl/utils/callbacks/checkpoint_callback.py +++ b/openrl/utils/callbacks/checkpoint_callback.py @@ -72,9 +72,7 @@ def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> st """ return os.path.join( self.save_path, - ( - f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}" - ), + f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}", ) def _on_step(self) -> bool: diff --git a/openrl/utils/evaluation.py b/openrl/utils/evaluation.py index d603daa5..c008c437 100644 --- a/openrl/utils/evaluation.py +++ b/openrl/utils/evaluation.py @@ -68,12 +68,10 @@ def evaluate_policy( if not is_monitor_wrapped and warn: warnings.warn( - ( - "Evaluation environment is not wrapped with a ``Monitor`` wrapper. This" - " may result in reporting modified episode lengths and rewards, if" - " other wrappers happen to modify these. Consider wrapping environment" - " first with ``Monitor`` wrapper." - ), + "Evaluation environment is not wrapped with a ``Monitor`` wrapper. This" + " may result in reporting modified episode lengths and rewards, if" + " other wrappers happen to modify these. Consider wrapping environment" + " first with ``Monitor`` wrapper.", UserWarning, ) @@ -97,9 +95,13 @@ def evaluate_policy( episode_starts = np.ones((env.parallel_env_num,), dtype=bool) while (episode_counts < episode_count_targets).any(): + if not np.all(episode_starts == 0): + episode_starts_tmp = episode_starts + else: + episode_starts_tmp = None + actions, states = agent.act( - observations, - deterministic=deterministic, + observations, deterministic=deterministic, episode_starts=episode_starts_tmp ) observations, rewards, dones, infos = env.step(actions) rewards = np.squeeze(rewards, axis=-1) diff --git a/openrl/utils/logger.py b/openrl/utils/logger.py index 0f2f0e2e..d9c49f34 100644 --- a/openrl/utils/logger.py +++ b/openrl/utils/logger.py @@ -32,9 +32,9 @@ class Logger: def __init__( self, cfg, - project_name: str, - scenario_name: str, - wandb_entity: str, + project_name: str = "openrl", + scenario_name: str = "openrl", + wandb_entity: str = "openrl", exp_name: Optional[str] = None, log_path: Optional[str] = None, use_wandb: bool = False, @@ -46,6 +46,10 @@ def __init__( self.use_wandb = use_wandb self.use_tensorboard = use_tensorboard + self.skip_logging = False + if cfg.use_deepspeed and cfg.local_rank != 0: + self.skip_logging = True + self.log_level = log_level self.log_path = log_path self.project_name = project_name @@ -126,20 +130,21 @@ def _init(self) -> None: ) if self.use_wandb: - wandb.init( - config=self.cfg, - project=self.project_name, - entity=self.wandb_entity, - notes=socket.gethostname(), - name=self.scenario_name - + "_" - + str(self.exp_name) - + "_seed" - + str(self.cfg.seed), - dir=str(run_dir), - job_type="training", - reinit=True, - ) + if not self.skip_logging: + wandb.init( + config=self.cfg, + project=self.project_name, + entity=self.wandb_entity, + notes=socket.gethostname(), + name=self.scenario_name + + "_" + + str(self.exp_name) + + "_seed" + + str(self.cfg.seed), + dir=str(run_dir), + job_type="training", + reinit=True, + ) elif self.use_tensorboard: from tensorboardX import SummaryWriter @@ -152,7 +157,8 @@ def _init(self) -> None: def close(self): if self.use_wandb: - wandb.finish() + if not self.skip_logging: + wandb.finish() def info(self, msg: str): logging.info(msg) @@ -167,7 +173,8 @@ def log_learner_info( return for k, v in infos.items(): if self.use_wandb: - wandb.log({"Learner_{}/{}".format(leaner_id, k): v}, step=step) + if not self.skip_logging: + wandb.log({"Learner_{}/{}".format(leaner_id, k): v}, step=step) elif self.use_tensorboard: self.writter.add_scalars( "Learner_{}/{}".format(leaner_id, k), @@ -192,7 +199,8 @@ def log_info( logging_info_str += f"\t{k}: {v}\n" if self.use_wandb: - wandb.log({k: v}, step=step) + if not self.skip_logging: + wandb.log({k: v}, step=step) elif self.use_tensorboard: self.writter.add_scalars(k, {k: v}, step) if self.log_to_terminal: diff --git a/openrl/utils/type_aliases.py b/openrl/utils/type_aliases.py index 25991e24..d9012d7d 100644 --- a/openrl/utils/type_aliases.py +++ b/openrl/utils/type_aliases.py @@ -13,9 +13,7 @@ GymEnv = Union[gym.Env, vec_env.BaseVecEnv] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] -GymStepReturn = Union[ - Tuple[GymObs, float, bool, Dict], Tuple[GymObs, float, bool, bool, Dict] -] + TensorDict = Dict[Union[str, int], th.Tensor] OptimizerStateDict = Dict[str, Any] MaybeCallback = Union[ diff --git a/setup.py b/setup.py index 494e4c4c..28cffd3c 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ def get_install_requires() -> list: return [ "setuptools>=67.0", - "gymnasium", + "gymnasium>=0.29", "click", "termcolor", "gym", @@ -60,16 +60,39 @@ def get_extra_requires() -> dict: "mpe": ["pyglet==1.5.27"], "nlp": [ "transformers==4.18.0", - "datasets", + "datasets==2.13", "nltk", "evaluate", "icetk", ], - "selfplay": ["ray[default]", "ray[serve]", "pettingzoo[classic]", "trueskill"], + "nlp_test": [ + "transformers", + "datasets==2.13", + "evaluate", + ], + "selfplay": [ + "ray[default]>=2.7", + "ray[serve]", + "async_timeout", + "pettingzoo[classic]", + "trueskill", + ], + "selfplay_test": [ + "ray[default]>=2.7", + "ray[serve]", + "async_timeout", + "fastapi", + "pettingzoo[mpe]", + "pettingzoo[butterfly]", + ], "retro": ["gym-retro"], "super_mario": ["gym-super-mario-bros"], + "atari": ["gymnasium[atari]", "gymnasium[accept-rom-license]"], } req["test"].extend(req["selfplay"]) + req["test"].extend(req["selfplay_test"]) + req["test"].extend(req["atari"]) + req["test"].extend(req["nlp_test"]) return req diff --git a/tests/test_algorithm/test_a2c_algorithm.py b/tests/test_algorithm/test_a2c_algorithm.py new file mode 100644 index 00000000..0f4f7226 --- /dev/null +++ b/tests/test_algorithm/test_a2c_algorithm.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import numpy as np +import pytest +from gymnasium import spaces + + +@pytest.fixture +def obs_space(): + return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture +def act_space(): + return spaces.Discrete(2) + + +@pytest.fixture( + scope="module", params=["--use_share_model false", "--use_share_model true"] +) +def config(request): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.fixture +def amp_config(): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args("") + return cfg + + +@pytest.fixture +def init_module(config, obs_space, act_space): + from openrl.modules.ppo_module import PPOModule + + module = PPOModule( + config, + policy_input_space=obs_space, + critic_input_space=obs_space, + act_space=act_space, + share_model=config.use_share_model, + ) + return module + + +@pytest.fixture +def buffer_data(config, obs_space, act_space): + from openrl.buffers.normal_buffer import NormalReplayBuffer + + buffer = NormalReplayBuffer( + config, + num_agents=1, + obs_space=obs_space, + act_space=act_space, + data_client=None, + episode_length=100, + ) + return buffer.data + + +@pytest.mark.unittest +def test_a2c_algorithm(config, init_module, buffer_data): + from openrl.algorithms.a2c import A2CAlgorithm + + a2c_algo = A2CAlgorithm(config, init_module) + + a2c_algo.train(buffer_data) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_algorithm/test_bc_algorithm.py b/tests/test_algorithm/test_bc_algorithm.py new file mode 100644 index 00000000..fa073174 --- /dev/null +++ b/tests/test_algorithm/test_bc_algorithm.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import numpy as np +import pytest +from gymnasium import spaces + + +@pytest.fixture +def obs_space(): + return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture +def act_space(): + return spaces.Discrete(2) + + +@pytest.fixture(scope="module", params=["", "--use_share_model true"]) +def config(request): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.fixture +def init_module(config, obs_space, act_space): + from openrl.modules.bc_module import BCModule + + module = BCModule( + config, + policy_input_space=obs_space, + critic_input_space=obs_space, + act_space=act_space, + share_model=config.use_share_model, + ) + return module + + +@pytest.fixture +def buffer_data(config, obs_space, act_space): + from openrl.buffers.normal_buffer import NormalReplayBuffer + + buffer = NormalReplayBuffer( + config, + num_agents=1, + obs_space=obs_space, + act_space=act_space, + data_client=None, + episode_length=100, + ) + return buffer.data + + +@pytest.mark.unittest +def test_bc_algorithm(config, init_module, buffer_data): + from openrl.algorithms.behavior_cloning import BCAlgorithm + + bc_algo = BCAlgorithm(config, init_module) + + bc_algo.train(buffer_data) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_algorithm/test_ddpg_algorithm.py b/tests/test_algorithm/test_ddpg_algorithm.py new file mode 100644 index 00000000..b31a56df --- /dev/null +++ b/tests/test_algorithm/test_ddpg_algorithm.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import numpy as np +import pytest +from gymnasium import spaces + + +@pytest.fixture +def obs_space(): + return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture +def act_space(): + return spaces.box.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture(scope="module", params=[""]) +def config(request): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.fixture +def amp_config(): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args("") + return cfg + + +@pytest.fixture +def init_module(config, obs_space, act_space): + from openrl.modules.ddpg_module import DDPGModule + + module = DDPGModule( + config, + input_space=obs_space, + act_space=act_space, + ) + return module + + +@pytest.fixture +def buffer_data(config, obs_space, act_space): + from openrl.buffers.offpolicy_buffer import OffPolicyReplayBuffer + + buffer = OffPolicyReplayBuffer( + config, + num_agents=1, + obs_space=obs_space, + act_space=act_space, + data_client=None, + episode_length=5000, + ) + return buffer.data + + +@pytest.mark.unittest +def test_ddpg_algorithm(config, init_module, buffer_data): + from openrl.algorithms.ddpg import DDPGAlgorithm + + ddpg_algo = DDPGAlgorithm(config, init_module) + + ddpg_algo.train(buffer_data) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_algorithm/test_ppo_algorithm.py b/tests/test_algorithm/test_ppo_algorithm.py index 8ac5c865..98a8a5d4 100644 --- a/tests/test_algorithm/test_ppo_algorithm.py +++ b/tests/test_algorithm/test_ppo_algorithm.py @@ -33,7 +33,9 @@ def act_space(): return spaces.Discrete(2) -@pytest.fixture(scope="module", params=["", "--use_share_model true"]) +@pytest.fixture( + scope="module", params=["--use_share_model false", "--use_share_model true"] +) def config(request): from openrl.configs.config import create_config_parser diff --git a/tests/test_algorithm/test_sac_algorithm.py b/tests/test_algorithm/test_sac_algorithm.py new file mode 100644 index 00000000..80447a3a --- /dev/null +++ b/tests/test_algorithm/test_sac_algorithm.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import numpy as np +import pytest +from gymnasium import spaces + + +@pytest.fixture +def obs_space(): + return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture +def act_space(): + return spaces.box.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32) + + +@pytest.fixture(scope="module", params=[""]) +def config(request): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.fixture +def amp_config(): + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args("") + return cfg + + +@pytest.fixture +def init_module(config, obs_space, act_space): + from openrl.modules.sac_module import SACModule + + module = SACModule( + config, + input_space=obs_space, + act_space=act_space, + ) + return module + + +@pytest.fixture +def buffer_data(config, obs_space, act_space): + from openrl.buffers.offpolicy_buffer import OffPolicyReplayBuffer + + buffer = OffPolicyReplayBuffer( + config, + num_agents=1, + obs_space=obs_space, + act_space=act_space, + data_client=None, + episode_length=5000, + ) + return buffer.data + + +@pytest.mark.unittest +def test_sac_algorithm(config, init_module, buffer_data): + from openrl.algorithms.sac import SACAlgorithm + + sac_algo = SACAlgorithm(config, init_module) + + sac_algo.train(buffer_data) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_arena/test_new_envs.py b/tests/test_arena/test_new_envs.py new file mode 100644 index 00000000..7a5dc01d --- /dev/null +++ b/tests/test_arena/test_new_envs.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import pytest +from pettingzoo.butterfly import cooperative_pong_v5 +from pettingzoo.classic import connect_four_v3, go_v5, texas_holdem_no_limit_v6 +from pettingzoo.mpe import simple_push_v3 + +from examples.custom_env.rock_paper_scissors import RockPaperScissors +from openrl.arena import make_arena +from openrl.arena.agents.local_agent import LocalAgent +from openrl.arena.agents.random_agent import RandomAgent +from openrl.envs.PettingZoo.registration import register +from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner + + +def ConnectFourEnv(render_mode, **kwargs): + return connect_four_v3.env(render_mode) + + +def RockPaperScissorsEnv(render_mode, **kwargs): + return RockPaperScissors(render_mode) + + +def GoEnv(render_mode, **kwargs): + return go_v5.env(render_mode=render_mode, board_size=5, komi=7.5) + + +def TexasHoldemEnv(render_mode, **kwargs): + return texas_holdem_no_limit_v6.env(render_mode=render_mode) + + +# MPE +def SimplePushEnv(render_mode, **kwargs): + return simple_push_v3.env(render_mode=render_mode) + + +def CooperativePongEnv(render_mode, **kwargs): + return cooperative_pong_v5.env(render_mode=render_mode) + + +def register_new_envs(): + new_env_dict = { + "connect_four_v3": ConnectFourEnv, + "RockPaperScissors": RockPaperScissorsEnv, + "go_v5": GoEnv, + "texas_holdem_no_limit_v6": TexasHoldemEnv, + "simple_push_v3": SimplePushEnv, + "cooperative_pong_v5": CooperativePongEnv, + } + + for env_id, env in new_env_dict.items(): + register(env_id, env) + return new_env_dict.keys() + + +def run_arena( + env_id: str, + parallel: bool = True, + seed=0, + total_games: int = 10, + max_game_onetime: int = 5, +): + env_wrappers = [RecordWinner] + + arena = make_arena(env_id, env_wrappers=env_wrappers, use_tqdm=False) + + agent1 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent") + agent2 = RandomAgent() + + arena.reset( + agents={"agent1": agent1, "agent2": agent2}, + total_games=total_games, + max_game_onetime=max_game_onetime, + seed=seed, + ) + result = arena.run(parallel=parallel) + arena.close() + return result + + +@pytest.mark.unittest +def test_new_envs(): + env_ids = register_new_envs() + seed = 0 + for env_id in env_ids: + run_arena(env_id=env_id, seed=seed, parallel=False, total_games=1) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_arena/test_reproducibility.py b/tests/test_arena/test_reproducibility.py index 0d186ab0..9ced525c 100644 --- a/tests/test_arena/test_reproducibility.py +++ b/tests/test_arena/test_reproducibility.py @@ -22,6 +22,7 @@ from openrl.arena import make_arena from openrl.arena.agents.local_agent import LocalAgent +from openrl.arena.agents.random_agent import RandomAgent from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner @@ -41,7 +42,7 @@ def run_arena( arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=False) agent1 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent") - agent2 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent") + agent2 = RandomAgent() arena.reset( agents={"agent1": agent1, "agent2": agent2}, diff --git a/tests/test_buffer/test_generator.py b/tests/test_buffer/test_generator.py new file mode 100644 index 00000000..27763635 --- /dev/null +++ b/tests/test_buffer/test_generator.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import pytest + +from openrl.envs.common import make +from openrl.modules.common import PPONet as Net +from openrl.runners.common import PPOAgent as Agent + + +@pytest.fixture(scope="module", params=["--episode_length 10"]) +def episode_length(request): + return request.param + + +@pytest.fixture( + scope="module", + params=[ + "--use_recurrent_policy true --use_joint_action_loss true", + "--use_recurrent_policy true --use_joint_action_loss false", + "--use_recurrent_policy false --use_naive_recurrent true", + "--use_recurrent_policy false --use_naive_recurrent false", + ], +) +def generator_type(request): + return request.param + + +@pytest.fixture(scope="module", params=["--use_gae true", "--use_gae false"]) +def use_gae(request): + return request.param + + +@pytest.fixture( + scope="module", + params=["--use_proper_time_limits true", "--use_proper_time_limits false"], +) +def use_proper_time_limits(request): + return request.param + + +@pytest.fixture( + scope="module", + params=[ + "--use_popart true --use_valuenorm false", + "--use_popart false --use_valuenorm true", + "--use_popart false --use_valuenorm false", + ], +) +def use_popart(request): + return request.param + + +@pytest.fixture(scope="module") +def config(use_proper_time_limits, use_popart, use_gae, generator_type, episode_length): + config_str = ( + use_proper_time_limits + + " " + + use_popart + + " " + + use_gae + + " " + + generator_type + + " " + + episode_length + ) + + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(config_str.split()) + return cfg + + +@pytest.mark.unittest +def test_buffer_generator(config): + env = make("CartPole-v1", env_num=2) + agent = Agent(Net(env, cfg=config)) + agent.train(total_time_steps=50) + env.close() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_buffer/test_offpolicy_generator.py b/tests/test_buffer/test_offpolicy_generator.py new file mode 100644 index 00000000..ec960973 --- /dev/null +++ b/tests/test_buffer/test_offpolicy_generator.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import pytest + +from openrl.envs.common import make +from openrl.modules.common import DQNNet as Net +from openrl.runners.common import DQNAgent as Agent + + +@pytest.fixture(scope="module", params=["--episode_length 10"]) +def episode_length(request): + return request.param + + +@pytest.fixture( + scope="module", + params=[ + "--use_recurrent_policy false --use_joint_action_loss false", + ], +) +def generator_type(request): + return request.param + + +@pytest.fixture(scope="module", params=["--use_proper_time_limits false"]) +def use_proper_time_limits(request): + return request.param + + +@pytest.fixture(scope="module", params=["--use_popart false --use_valuenorm false"]) +def use_popart(request): + return request.param + + +@pytest.fixture(scope="module") +def config(use_proper_time_limits, use_popart, generator_type, episode_length): + config_str = ( + use_proper_time_limits + + " " + + use_popart + + " " + + generator_type + + " " + + episode_length + ) + + from openrl.configs.config import create_config_parser + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(config_str.split()) + return cfg + + +@pytest.mark.unittest +def test_buffer_generator(config): + env = make("CartPole-v1", env_num=2) + agent = Agent(Net(env, cfg=config)) + agent.train(total_time_steps=50) + env.close() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_dataset/test_expert_dataset.py b/tests/test_dataset/test_expert_dataset.py new file mode 100644 index 00000000..1eed9125 --- /dev/null +++ b/tests/test_dataset/test_expert_dataset.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import pytest +import torch + +from openrl.datasets.expert_dataset import ExpertDataset +from openrl.envs.common import make +from openrl.envs.vec_env.wrappers.gen_data import GenDataWrapper +from openrl.envs.wrappers.monitor import Monitor + +env_wrappers = [ + Monitor, +] + + +def gen_data(total_episode, data_save_path): + # begin to test + # Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human. + + env = make( + "IdentityEnv", + env_num=1, + asynchronous=True, + env_wrappers=env_wrappers, + ) + + env = GenDataWrapper( + env, data_save_path=data_save_path, total_episode=total_episode + ) + env.reset() + done = False + ep_length = 0 + while not done: + obs, r, done, info = env.step(env.random_action()) + ep_length += 1 + env.close() + return ep_length + + +@pytest.mark.unittest +def test_expert_dataset(tmp_path): + total_episode = 1 + data_save_path = tmp_path / "data.pkl" + ep_length = gen_data(total_episode, data_save_path) + + dataset = ExpertDataset( + data_save_path, + num_trajectories=None, + subsample_frequency=1, + seed=None, + env_id=0, + env_num=1, + ) + assert len(dataset) == ep_length, "len(dataset)={},data_length={}".format( + len(dataset), ep_length + ) + assert len(dataset[0]) == 2, "len(dataset[0])={}".format(len(dataset[0])) + + data_loader = torch.utils.data.DataLoader( + dataset=dataset, batch_size=1, shuffle=False, drop_last=True + ) + + step = 0 + for batch_data in data_loader: + assert len(batch_data) == 2, "len(batch_data)={}".format(len(batch_data)) + step += 1 + assert step == ep_length, "step={},ep_length={}".format(step, ep_length) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_env/test_mpe_env.py b/tests/test_env/test_mpe_env.py index 2dd664b6..0b555bb2 100644 --- a/tests/test_env/test_mpe_env.py +++ b/tests/test_env/test_mpe_env.py @@ -18,15 +18,16 @@ import os import sys +import numpy as np import pytest +from openrl.envs.common import make + @pytest.mark.unittest def test_mpe(): - from openrl.envs.common import make - - env_num = 6 - env = make("simple_spread", env_num=6) + env_num = 3 + env = make("simple_spread", env_num=env_num) obs, info = env.reset() obs, reward, done, info = env.step(env.random_action()) assert env.agent_num == 3 @@ -34,5 +35,27 @@ def test_mpe(): env.close() +@pytest.mark.unittest +def test_mpe_render(): + render_model = "human" + env_num = 2 + env = make( + "simple_spread", render_mode=render_model, env_num=env_num, asynchronous=False + ) + + env.reset(seed=0) + done = False + step = 0 + total_reward = 0 + while not np.any(done): + # Based on environmental observation input, predict next action. + + obs, r, done, info = env.step(env.random_action()) + step += 1 + total_reward += np.mean(r) + + env.close() + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_env/test_nlp/test_DailyDialogEnv.py b/tests/test_env/test_nlp/test_DailyDialogEnv.py new file mode 100644 index 00000000..6f0ac1df --- /dev/null +++ b/tests/test_env/test_nlp/test_DailyDialogEnv.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +import os +import sys + +import pytest + +from openrl.configs.config import create_config_parser +from openrl.envs.common import make + + +@pytest.fixture( + scope="module", + params=["--env.args {'data_path':None,'tokenizer_path':'builtin_BPE'}"], +) +def config(request): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.mark.unittest +def test_DailyDialogEnv(config): + env = make("daily_dialog", env_num=1, asynchronous=False, cfg=config) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_env/test_snake_env.py b/tests/test_env/test_snake_env.py new file mode 100644 index 00000000..8e558231 --- /dev/null +++ b/tests/test_env/test_snake_env.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import gymnasium as gym +import numpy as np +import pytest +from gymnasium import spaces + +from openrl.envs.common import make +from openrl.envs.wrappers.base_wrapper import BaseObservationWrapper +from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper + + +class ConvertObs(BaseObservationWrapper): + def __init__(self, env: gym.Env): + BaseObservationWrapper.__init__(self, env) + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(576,), dtype=np.float32 + ) + + def observation(self, observation): + new_obs = np.zeros((len(observation), 576), dtype=int) + return new_obs + + +@pytest.mark.unittest +def test_snake(): + env_num = 2 + for i in [1, 3]: + env = make( + f"snakes_{i}v{i}", + env_num=env_num, + asynchronous=False, + opponent_wrappers=[RandomOpponentWrapper], + env_wrappers=[ConvertObs], + auto_reset=False, + ) + ep_num = 3 + for ep_now in range(ep_num): + obs, info = env.reset() + done = False + step = 0 + + while not np.any(done): + obs, r, done, info = env.step(env.random_action()) + step += 1 + + env.close() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_env/test_vec_env/test_async_env.py b/tests/test_env/test_vec_env/test_async_env.py new file mode 100644 index 00000000..2c2301d3 --- /dev/null +++ b/tests/test_env/test_vec_env/test_async_env.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +import multiprocessing as mp +import os +import sys + +import pytest +from gymnasium.wrappers import EnvCompatibility + +from openrl.envs.toy_envs import make_toy_envs +from openrl.envs.vec_env.async_venv import AsyncVectorEnv, _worker + + +class CustomEnvCompatibility(EnvCompatibility): + def reset(self, **kwargs): + return super().reset(**kwargs)[0] + + +def init_envs(): + env_wrappers = [CustomEnvCompatibility] + env_fns = make_toy_envs( + id="IdentityEnv", + env_num=2, + env_wrappers=env_wrappers, + ) + return env_fns + + +def assert_env_name(env, env_name): + if isinstance(env.metadata["name"], str): + assert env.metadata["name"] == env_name + else: + assert env.metadata["name"].__name__ == env_name + + +@pytest.mark.unittest +def test_async_env(): + env_name = "IdentityEnv" + env = AsyncVectorEnv(init_envs(), shared_memory=True) + assert ( + env._env_name == env_name + ), "AsyncVectorEnv should have the same metadata as the wrapped env" + env.exec_func(assert_env_name, indices=None, env_name=env_name) + env.call("render") + env_name_new = "IdentityEnvNew" + env.set_attr("metadata", {"name": env_name_new}) + env.exec_func(assert_env_name, indices=None, env_name=env_name_new) + + +def main_control(parent_pipe, child_pipe): + child_pipe.close() + + parent_pipe.send(("reset", {"seed": 0})) + result, success = parent_pipe.recv() + assert success, result + + parent_pipe.send(("step", [0])) + result, success = parent_pipe.recv() + assert success, result + + parent_pipe.send(("_call", ("render", [], {}))) + result, success = parent_pipe.recv() + assert success, result + + parent_pipe.send(("_setattr", ("metadata", {"name": "IdentityEnvNew"}))) + result, success = parent_pipe.recv() + assert success, result + + parent_pipe.send( + ("_func_exec", (assert_env_name, None, [], {"env_name": "IdentityEnvNew"})) + ) + result, success = parent_pipe.recv() + assert success, result + + parent_pipe.send(("close", None)) + result, success = parent_pipe.recv() + assert success, result + + +@pytest.mark.unittest +def test_worker(): + for auto_reset in [True, False]: + ctx = mp.get_context(None) + parent_pipe, child_pipe = ctx.Pipe() + + error_queue = ctx.Queue() + + process = ctx.Process( + target=main_control, + name="test", + args=(parent_pipe, child_pipe), + ) + process.daemon = True + process.start() + _worker(0, init_envs()[0], child_pipe, None, False, error_queue, auto_reset) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_env/test_vec_env/test_sync_env.py b/tests/test_env/test_vec_env/test_sync_env.py new file mode 100644 index 00000000..fb3d5d0b --- /dev/null +++ b/tests/test_env/test_vec_env/test_sync_env.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +import os +import sys + +import pytest +from gymnasium.wrappers import EnvCompatibility + +from openrl.envs.toy_envs import make_toy_envs +from openrl.envs.vec_env.sync_venv import SyncVectorEnv + + +class CustomEnvCompatibility(EnvCompatibility): + def reset(self, **kwargs): + return super().reset(**kwargs)[0] + + +def init_envs(): + env_wrappers = [CustomEnvCompatibility] + env_fns = make_toy_envs( + id="IdentityEnv", + env_num=2, + env_wrappers=env_wrappers, + ) + return env_fns + + +def assert_env_name(env, env_name): + assert env.metadata["name"].__name__ == env_name + + +@pytest.mark.unittest +def test_sync_env(): + env_name = "IdentityEnv" + env = SyncVectorEnv(init_envs()) + env.exec_func(assert_env_name, indices=None, env_name=env_name) + env.call("render") + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_env/test_vec_env/test_vec_wrappers.py b/tests/test_env/test_vec_env/test_vec_wrappers.py new file mode 100644 index 00000000..9b5735c9 --- /dev/null +++ b/tests/test_env/test_vec_env/test_vec_wrappers.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import pickle +import sys + +import numpy as np +import pytest + +from openrl.envs.common import make +from openrl.envs.vec_env.wrappers.gen_data import GenDataWrapper, GenDataWrapper_v1 +from openrl.envs.vec_env.wrappers.zero_reward_wrapper import ZeroRewardWrapper +from openrl.envs.wrappers.monitor import Monitor + + +@pytest.mark.unittest +def test_zero_reward_wrapper(): + env = make("IdentityEnv", env_num=1) + env = ZeroRewardWrapper(env) + env.reset(seed=0) + while True: + obs, reward, done, info = env.step(env.random_action()) + assert np.all(reward == 0), "reward should be zero" + if done: + break + env.close() + + +@pytest.mark.unittest +def test_gen_data(tmp_path): + total_episode = 4 + env = make("IdentityEnv", env_wrappers=[Monitor], env_num=1) + data_save_path = tmp_path / "data.pkl" + env = GenDataWrapper( + env, data_save_path=str(data_save_path), total_episode=total_episode + ) + obs, info = env.reset(seed=0) + done = False + while not done: + obs, r, done, info = env.step(env.random_action()) + env.close() + + save_data = pickle.load(open(data_save_path, "rb")) + assert len(save_data["episode_lengths"]) == total_episode, ( + f"episode_lengths {len(save_data['episode_lengths'])} " + f"should be equal to total_episode {total_episode}" + ) + + +@pytest.mark.unittest +def test_gen_data_old(tmp_path): + total_episode = 4 + env = make("IdentityEnv", env_wrappers=[Monitor], env_num=1) + data_save_path = tmp_path / "data.pkl" + env = GenDataWrapper_v1( + env, data_save_path=str(data_save_path), total_episode=total_episode + ) + obs, info = env.reset(seed=0) + done = False + while not done: + obs, r, done, info = env.step(env.random_action()) + env.close() + + save_data = pickle.load(open(data_save_path, "rb")) + assert save_data["total_episode"] == total_episode, ( + f"episode_lengths {save_data['total_episode']} " + f"should be equal to total_episode {total_episode}" + ) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_env/test_wrappers.py b/tests/test_env/test_wrappers.py new file mode 100644 index 00000000..6042eccf --- /dev/null +++ b/tests/test_env/test_wrappers.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import pytest + + +@pytest.mark.unittest +def test_atari_wrappers(): + import gymnasium + + from openrl.envs.wrappers.atari_wrappers import ( + ClipRewardEnv, + EpisodicLifeEnv, + FireResetEnv, + NoopResetEnv, + WarpFrame, + ) + + env = gymnasium.make("ALE/Breakout-v5") + env = FireResetEnv(EpisodicLifeEnv(ClipRewardEnv(WarpFrame(NoopResetEnv(env))))) + env.reset(seed=0) + while True: + obs, reward, done, truncated, info = env.step(0) + if done: + break + env.close() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_examples/test_nlp.py b/tests/test_examples/test_nlp.py index 99111fdc..1524ae65 100644 --- a/tests/test_examples/test_nlp.py +++ b/tests/test_examples/test_nlp.py @@ -17,20 +17,26 @@ # """""" # +import os +import sys +import pytest + +from openrl.configs.config import create_config_parser from openrl.envs.common import make from openrl.modules.common import PPONet as Net from openrl.runners.common import PPOAgent as Agent -def config(): - from openrl.configs.config import create_config_parser - +# @pytest.fixture(scope="module", params=["--env.args {'data_path':None,'tokenizer_path':'builtin_BPE'}"]) +@pytest.fixture(scope="module", params=[""]) +def config(request): cfg_parser = create_config_parser() - cfg = cfg_parser.parse_args() + cfg = cfg_parser.parse_args(request.param.split()) return cfg +@pytest.mark.unittest def test_train_nlp(config): env = make("fake_dialog_data", env_num=3, cfg=config) agent = Agent(Net(env)) @@ -38,5 +44,4 @@ def test_train_nlp(config): if __name__ == "__main__": - cfg = config() - test_train_nlp(cfg) + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_examples/test_train_atari.py b/tests/test_examples/test_train_atari.py new file mode 100644 index 00000000..4d8d2166 --- /dev/null +++ b/tests/test_examples/test_train_atari.py @@ -0,0 +1,74 @@ +"""""" + +import os +import sys + +import numpy as np +import pytest + +from openrl.configs.config import create_config_parser +from openrl.envs.common import make +from openrl.envs.wrappers.atari_wrappers import ( + ClipRewardEnv, + FireResetEnv, + NoopResetEnv, + WarpFrame, +) +from openrl.envs.wrappers.image_wrappers import TransposeImage +from openrl.envs.wrappers.monitor import Monitor +from openrl.modules.common import PPONet as Net +from openrl.runners.common import PPOAgent as Agent + +env_wrappers = [ + Monitor, + NoopResetEnv, + FireResetEnv, + WarpFrame, + ClipRewardEnv, + TransposeImage, +] + + +@pytest.fixture( + scope="module", + params=[ + "--episode_length 5 --use_recurrent_policy false --vec_info_class.id" + " EPS_RewardInfo --use_valuenorm true --use_adv_normalize true" + " --use_share_model True --entropy_coef 0.01" + ], +) +def config(request): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.mark.unittest +def test_train_atari(config): + env_num = 2 + env = make( + "ALE/Pong-v5", + env_num=env_num, + cfg=config, + asynchronous=True, + env_wrappers=env_wrappers, + ) + net = Net(env, cfg=config) + agent = Agent(net) + agent.train(total_time_steps=30) + agent.save("./ppo_agent/") + agent.load("./ppo_agent/") + agent.set_env(env) + obs, info = env.reset(seed=0) + step = 0 + while step < 5: + action, _ = agent.act(obs, deterministic=True) + obs, r, done, info = env.step(action) + if np.any(done): + break + step += 1 + env.close() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_examples/test_train_gail.py b/tests/test_examples/test_train_gail.py new file mode 100644 index 00000000..656ff2d0 --- /dev/null +++ b/tests/test_examples/test_train_gail.py @@ -0,0 +1,75 @@ +"""""" + +import os +import sys + +import pytest + +from openrl.configs.config import create_config_parser +from openrl.envs.common import make +from openrl.envs.vec_env.wrappers.gen_data import GenDataWrapper +from openrl.envs.wrappers.extra_wrappers import ZeroRewardWrapper +from openrl.envs.wrappers.monitor import Monitor +from openrl.modules.common import GAILNet as Net +from openrl.modules.common import PPONet +from openrl.runners.common import GAILAgent as Agent +from openrl.runners.common import PPOAgent + + +@pytest.fixture(scope="function") +def gen_data(tmpdir): + tmp_data_path = os.path.join(tmpdir, "data.pkl") + env_wrappers = [ + Monitor, + ] + print("generate data....") + env = make( + "CartPole-v1", + env_num=2, + asynchronous=True, + env_wrappers=env_wrappers, + ) + agent = PPOAgent(PPONet(env)) + env = GenDataWrapper(env, data_save_path=tmp_data_path, total_episode=5) + obs, info = env.reset() + done = False + while not done: + # Based on environmental observation input, predict next action. + action, _ = agent.act(obs, deterministic=True) + obs, r, done, info = env.step(action) + env.close() + print("generate data done!") + return tmp_data_path + + +@pytest.fixture( + scope="function", params=[" --gail_use_action false", " --gail_use_action true"] +) +def config(request, gen_data): + input_str = ( + "--episode_length 5 --use_recurrent_policy true --use_joint_action_loss true" + " --use_valuenorm true --use_adv_normalize true --reward_class.id GAILReward" + ) + input_str += request.param + input_str += " --expert_data " + gen_data + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(input_str.split()) + return cfg + + +@pytest.mark.unittest +def test_train_gail(config): + env = make("CartPole-v1", env_num=2, cfg=config, env_wrappers=[ZeroRewardWrapper]) + + net = Net( + env, + cfg=config, + ) + # initialize the trainer + agent = Agent(net) + agent.train(total_time_steps=200) + env.close() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_examples/test_train_mpe.py b/tests/test_examples/test_train_mpe.py index 419b3dab..36e3e689 100644 --- a/tests/test_examples/test_train_mpe.py +++ b/tests/test_examples/test_train_mpe.py @@ -1,4 +1,5 @@ """""" + import os import sys diff --git a/tests/test_modules/test_common/test_ddpg_net.py b/tests/test_modules/test_common/test_ddpg_net.py new file mode 100644 index 00000000..a4c03354 --- /dev/null +++ b/tests/test_modules/test_common/test_ddpg_net.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +import os +import sys + +import pytest + +from openrl.configs.config import create_config_parser +from openrl.envs.common import make +from openrl.envs.wrappers.extra_wrappers import AddStep +from openrl.modules.common import DDPGNet as Net +from openrl.runners.common import DDPGAgent as Agent + +env_wrappers = [AddStep] + + +@pytest.fixture(scope="module", params=[""]) +def config(request): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +def train(Agent, Net, env_name, env_num, total_time_steps, config): + cfg = config + env = make(env_name, env_num=env_num, cfg=cfg, env_wrappers=env_wrappers) + + net = Net( + env, + cfg=cfg, + ) + # initialize the trainer + agent = Agent(net) + # start training, set total number of training steps to 20000 + agent.train(total_time_steps=total_time_steps) + env.close() + + +@pytest.mark.unittest +def test_ddpg_net(config): + train(Agent, Net, "IdentityEnvcontinuous", 2, 100, config) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_modules/test_common/test_dqn_net.py b/tests/test_modules/test_common/test_dqn_net.py new file mode 100644 index 00000000..292c08b4 --- /dev/null +++ b/tests/test_modules/test_common/test_dqn_net.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +import os +import sys + +import pytest + +from openrl.configs.config import create_config_parser +from openrl.envs.common import make +from openrl.envs.wrappers.extra_wrappers import AddStep +from openrl.modules.common import DQNNet as Net +from openrl.runners.common import DQNAgent as Agent + +env_wrappers = [AddStep] + + +@pytest.fixture(scope="module", params=[""]) +def config(request): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +def train(Agent, Net, env_name, env_num, total_time_steps, config): + cfg = config + env = make(env_name, env_num=env_num, cfg=cfg, env_wrappers=env_wrappers) + + net = Net( + env, + cfg=cfg, + ) + # initialize the trainer + agent = Agent(net) + # start training, set total number of training steps to 20000 + agent.train(total_time_steps=total_time_steps) + env.close() + + +@pytest.mark.unittest +def test_dqn_net(config): + train(Agent, Net, "IdentityEnv", 2, 100, config) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_modules/test_common/test_sac_net.py b/tests/test_modules/test_common/test_sac_net.py new file mode 100644 index 00000000..8839986e --- /dev/null +++ b/tests/test_modules/test_common/test_sac_net.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +import os +import sys + +import pytest + +from openrl.configs.config import create_config_parser +from openrl.envs.common import make +from openrl.envs.wrappers.extra_wrappers import AddStep +from openrl.modules.common import SACNet as Net +from openrl.runners.common import SACAgent as Agent + +env_wrappers = [AddStep] + + +@pytest.fixture(scope="module", params=[""]) +def config(request): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +def train(Agent, Net, env_name, env_num, total_time_steps, config): + cfg = config + env = make(env_name, env_num=env_num, cfg=cfg, env_wrappers=env_wrappers) + + net = Net( + env, + cfg=cfg, + ) + # initialize the trainer + agent = Agent(net) + # start training, set total number of training steps to 20000 + agent.train(total_time_steps=total_time_steps) + env.close() + + +@pytest.mark.unittest +def test_sac_net(config): + train(Agent, Net, "IdentityEnvcontinuous", 2, 100, config) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_modules/test_common/test_vdn_net.py b/tests/test_modules/test_common/test_vdn_net.py new file mode 100644 index 00000000..29f1f58f --- /dev/null +++ b/tests/test_modules/test_common/test_vdn_net.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +import os +import sys + +import pytest + +from openrl.configs.config import create_config_parser +from openrl.envs.common import make +from openrl.envs.wrappers.mat_wrapper import MATWrapper +from openrl.modules.common import VDNNet +from openrl.runners.common import VDNAgent as Agent + + +@pytest.fixture(scope="module", params=[""]) +def config(request): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.mark.unittest +def test_vdn_net(config): + env_num = 2 + env = make( + "simple_spread", + env_num=env_num, + asynchronous=True, + ) + env = MATWrapper(env) + + net = VDNNet(env, cfg=config) + # initialize the trainer + agent = Agent(net) + # start training + agent.train(total_time_steps=100) + env.close() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_modules/test_networks/test_MAT_network.py b/tests/test_modules/test_networks/test_MAT_network.py new file mode 100644 index 00000000..25b4fd00 --- /dev/null +++ b/tests/test_modules/test_networks/test_MAT_network.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +import os +import sys + +import numpy as np +import pytest +from gymnasium import spaces + +from openrl.configs.config import create_config_parser +from openrl.modules.networks.MAT_network import MultiAgentTransformer + + +@pytest.fixture(scope="module", params=[""]) +def config(request): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.mark.unittest +def test_MAT_network(config): + net = MultiAgentTransformer( + config, + input_space=spaces.Discrete(2), + action_space=spaces.Discrete(2), + ) + net.get_actor_para() + net.get_critic_para() + + obs = np.zeros([1, 2]) + rnn_states = np.zeros(2) + masks = np.zeros(2) + action = np.zeros(1) + net.get_actions(obs=obs, masks=masks) + net.eval_actions( + obs=obs, rnn_states=rnn_states, action=action, masks=masks, action_masks=None + ) + net.get_values(critic_obs=obs, masks=masks) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_modules/test_networks/test_attention.py b/tests/test_modules/test_networks/test_attention.py new file mode 100644 index 00000000..599d0000 --- /dev/null +++ b/tests/test_modules/test_networks/test_attention.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +import os +import sys + +import pytest +import torch + +from openrl.configs.config import create_config_parser +from openrl.modules.networks.utils.attention import Encoder + + +@pytest.fixture( + scope="module", params=["--use_average_pool True", "--use_average_pool False"] +) +def config(request): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.mark.unittest +def test_attention(config): + for cat_self in [False, True]: + net = Encoder(cfg=config, split_shape=[[1, 1], [1, 1]], cat_self=cat_self) + net(torch.zeros((1, 1))) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_modules/test_networks/test_policy_value_network_gpt.py b/tests/test_modules/test_networks/test_policy_value_network_gpt.py new file mode 100644 index 00000000..66a6caaa --- /dev/null +++ b/tests/test_modules/test_networks/test_policy_value_network_gpt.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" + +import os +import sys + +import numpy as np +import pytest +from gymnasium import spaces + +from openrl.configs.config import create_config_parser +from openrl.modules.networks.policy_value_network_gpt import ( + PolicyValueNetworkGPT as PolicyValueNetwork, +) + + +@pytest.fixture(scope="module", params=["--model_path test_gpt2"]) +def config(request): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.mark.unittest +def test_gpt_network(config): + net = PolicyValueNetwork( + cfg=config, + input_space=spaces.Discrete(2), + action_space=spaces.Discrete(2), + ) + + net.get_actor_para() + net.get_critic_para() + + obs = { + "input_encoded_pt": np.zeros([1, 2]), + "input_attention_mask_pt": np.zeros([1, 2]), + } + rnn_states = np.zeros(2) + masks = np.zeros(2) + action = np.zeros(1) + net.get_actions(obs=obs, rnn_states=rnn_states, masks=masks) + net.eval_actions( + obs=obs, rnn_states=rnn_states, action=action, masks=masks, action_masks=None + ) + net.get_values(obs=obs, rnn_states=rnn_states, masks=masks) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_rewards/test_nlp_reward.py b/tests/test_rewards/test_nlp_reward.py new file mode 100644 index 00000000..739943ef --- /dev/null +++ b/tests/test_rewards/test_nlp_reward.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import numpy as np +import pytest + +from openrl.buffers.normal_buffer import NormalReplayBuffer +from openrl.configs.config import create_config_parser +from openrl.envs.common import make +from openrl.rewards import RewardFactory + + +@pytest.fixture( + scope="module", + params=[ + "--reward_class.id NLPReward --reward_class.args" + " {'intent_model':'builtin_intent','ref_model':'builtin_ref','use_deepspeed':False}" + ], +) +def config(request): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(request.param.split()) + return cfg + + +@pytest.mark.unittest +def test_nlp_reward(config): + env = make("fake_dialog_data", env_num=1) + reward = RewardFactory.get_reward_class(config.reward_class, env) + data = {} + data["rewards"] = np.zeros(32) + env_info = {} + env_info["final_info"] = { + "prompt_texts": "hello", + "generated_texts": "hello", + "meta_infos": {"intent": [1]}, + } + data["infos"] = [env_info] * 32 + data["step"] = 0 + data["actions"] = [0] + data["action_log_probs"] = np.zeros(32) + buffer = NormalReplayBuffer( + config, + num_agents=env.agent_num, + obs_space=env.observation_space, + act_space=env.action_space, + data_client=None, + episode_length=1, + ) + data["buffer"] = buffer + reward.step_reward(data=data) + reward.batch_rewards(buffer=buffer) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_selfplay/test_selfplay_strategy.py b/tests/test_selfplay/test_selfplay_strategy.py deleted file mode 100644 index 61b04052..00000000 --- a/tests/test_selfplay/test_selfplay_strategy.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright 2023 The OpenRL Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""""" -import os -import sys - -import pytest - -from openrl.selfplay.strategies import ( - NaiveSelfplayStrategy, - OnlyLatestSelfplayStrategy, - VarExistEnemySelfplayStrategy, - WeightExistEnemySelfplayStrategy, - WeightSelfplayStrategy, - WinRateSelfplayStrategy, -) - - -@pytest.fixture(scope="module", params=[""]) -def config(request): - from openrl.configs.config import create_config_parser - - cfg_parser = create_config_parser() - cfg = cfg_parser.parse_args(request.param.split()) - return cfg - - -@pytest.mark.unittest -def test_naive_selfplay(config): - strategy = NaiveSelfplayStrategy(config, 1, 1) - strategy.get_plist() - strategy.update_weight(enemy_loses=1) - strategy.update_win_rate(dones=True, enemy_wins=1) - strategy.push_newone() - - -@pytest.mark.unittest -def test_only_latest_selfplay(config): - strategy = OnlyLatestSelfplayStrategy(config, 1, 1) - strategy.get_plist() - strategy.update_weight(enemy_loses=1) - strategy.push_newone() - - -@pytest.mark.unittest -def test_weight_selfplay(config): - strategy = WeightSelfplayStrategy(config, 1, 1) - strategy.get_plist() - strategy.update_weight(enemy_loses=1) - strategy.push_newone() - - -@pytest.mark.unittest -def test_win_rate_selfplay(config): - strategy = WinRateSelfplayStrategy(config, 1, 1) - strategy.get_plist() - strategy.update_weight(enemy_loses=1) - - -@pytest.mark.unittest -def test_var_exist_enemy_selfplay(config): - strategy = VarExistEnemySelfplayStrategy(config, 1, 1) - strategy.get_plist() - strategy.update_weight(enemy_loses=1) - strategy.push_newone() - - -@pytest.mark.unittest -def test_weight_exist_enemy_selfplay(config): - strategy = WeightExistEnemySelfplayStrategy(config, 1, 1) - strategy.get_plist() - strategy.update_weight(enemy_loses=1) - strategy.push_newone() - - -if __name__ == "__main__": - sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_selfplay/test_train_selfplay.py b/tests/test_selfplay/test_train_selfplay.py new file mode 100644 index 00000000..a67ea964 --- /dev/null +++ b/tests/test_selfplay/test_train_selfplay.py @@ -0,0 +1,134 @@ +import os +import socket +import sys + +import numpy as np +import pytest +import torch + +from openrl.configs.config import create_config_parser +from openrl.envs.common import make +from openrl.envs.wrappers import FlattenObservation +from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner +from openrl.modules.common import PPONet as Net +from openrl.runners.common import PPOAgent as Agent +from openrl.selfplay.wrappers.opponent_pool_wrapper import OpponentPoolWrapper +from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper + + +def find_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +@pytest.fixture( + scope="module", + params=[ + {"port": find_free_port(), "strategy": "RandomOpponent"}, + {"port": find_free_port(), "strategy": "LastOpponent"}, + ], +) +def config(request): + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args(["--config", "./examples/selfplay/selfplay.yaml"]) + cfg.selfplay_api.port = request.param["port"] + print("port:", request.param["port"]) + for i, c in enumerate(cfg.callbacks): + if c["id"] == "SelfplayCallback": + c["args"][ + "opponent_template" + ] = "./examples/selfplay/opponent_templates/tictactoe_opponent" + port = c["args"]["api_address"].split(":")[-1].split("/")[0] + c["args"]["api_address"] = c["args"]["api_address"].replace( + port, str(request.param["port"]) + ) + cfg.callbacks[i] = c + elif c["id"] == "SelfplayAPI": + c["args"]["sample_strategy"] = request.param["strategy"] + c["args"]["port"] = request.param["port"] + cfg.callbacks[i] = c + + else: + pass + + return cfg + + +def train(cfg): + # Create environment + env_num = 2 + render_model = None + env = make( + "tictactoe_v3", + render_mode=render_model, + env_num=env_num, + asynchronous=False, + opponent_wrappers=[RecordWinner, OpponentPoolWrapper], + env_wrappers=[FlattenObservation], + cfg=cfg, + ) + # Create neural network + + net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu") + # Create agent + agent = Agent(net) + # Begin training + agent.train(total_time_steps=10) + env.close() + agent.save("./selfplay_agent/") + return agent + + +def evaluation(): + print("Evaluation...") + env_num = 1 + env = make( + "tictactoe_v3", + env_num=env_num, + asynchronous=True, + opponent_wrappers=[RandomOpponentWrapper], + env_wrappers=[FlattenObservation], + auto_reset=False, + ) + + cfg_parser = create_config_parser() + cfg = cfg_parser.parse_args([]) + net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu") + + agent = Agent(net) + + agent.load("./selfplay_agent/") + agent.set_env(env) + env.reset(seed=0) + + total_reward = 0.0 + ep_num = 2 + for ep_now in range(ep_num): + obs, info = env.reset() + done = False + step = 0 + + while not np.any(done): + # predict next action based on the observation + action, _ = agent.act(obs, info, deterministic=True) + obs, r, done, info = env.step(action) + step += 1 + + if np.any(done): + total_reward += np.mean(r) > 0 + print(f"{ep_now}/{ep_num}: reward: {np.mean(r)}") + print(f"win rate: {total_reward/ep_num}") + env.close() + print("Evaluation finished.") + + +@pytest.mark.unittest +def test_train_selfplay(config): + train(config) + evaluation() + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))