diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml
index cac1ec21..9671f935 100644
--- a/.github/workflows/unit_test.yml
+++ b/.github/workflows/unit_test.yml
@@ -17,6 +17,10 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
+ - name: Install system dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y xvfb libglu1-mesa-dev python3-opengl
- name: Upgrade pip
run: |
python -m pip install --upgrade pip setuptools wheel
@@ -27,7 +31,7 @@ jobs:
- name: do_unittest
timeout-minutes: 40
run: |
- python3 -m pytest tests --cov=openrl --cov-report=xml -m unittest --cov-report=term-missing --durations=0 -v --color=yes
+ xvfb-run -s "-screen 0 1400x900x24" python3 -m pytest tests --cov=openrl --cov-report=xml -m unittest --cov-report=term-missing --durations=0 -v --color=yes -s
- name: Upload coverage reports to Codecov with GitHub Action
uses: codecov/codecov-action@v3
with:
diff --git a/.gitignore b/.gitignore
index c92a6657..469ab373 100644
--- a/.gitignore
+++ b/.gitignore
@@ -153,10 +153,11 @@ run_results/
api_docs
.vscode
*.pkl
-api_docs
*.json
opponent_pool
!/examples/selfplay/opponent_templates/tictactoe_opponent/info.json
+!/examples/nlp/ds_config.json
+!/examples/nlp/eval_ds_config.json
wandb_run
examples/dmc/new.gif
/examples/snake/submissions/rl/actor_2000.pth
diff --git a/Project.md b/Project.md
index d7c455f5..38a8d1c7 100644
--- a/Project.md
+++ b/Project.md
@@ -18,7 +18,7 @@ However, in many practical applications, it is important to develop reasonable a
In this paper, we propose an on-policy framework for discovering multiple strategies for the same task.
Experimental results show that our method efficiently finds diverse strategies in a wide variety of reinforcement learning tasks.
-- Paper: [DGPO: Discovering Multiple Strategies with Diversity-Guided Policy Optimization](https://arxiv.org/abs/2207.05631)(AAMAS Extended Abstract 2023)
-- Authors: Wenze Chen, Shiyu Huang, Yuan Chiang, Ting Chen, Jun Zhu
+- Paper: [DGPO: Discovering Multiple Strategies with Diversity-Guided Policy Optimization](https://arxiv.org/abs/2207.05631) (AAAI 2024)
+- Authors: Wenze Chen, Shiyu Huang, Yuan Chiang, Tim Pearce, Wei-Wei Tu, Ting Chen, Jun Zhu
diff --git a/README.md b/README.md
index 741fdef2..c76e3691 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-
+
---
@@ -25,10 +25,10 @@
[![Contributors](https://img.shields.io/github/contributors/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/graphs/contributors)
[![GitHub license](https://img.shields.io/github/license/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/blob/master/LICENSE)
-[![Embark](https://img.shields.io/badge/discord-OpenRL-%237289da.svg?logo=discord)](https://discord.gg/guvAS2up)
+[![Embark](https://img.shields.io/badge/discord-OpenRL-%237289da.svg?logo=discord)](https://discord.gg/qMbVT2qBhr)
[![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg)
-OpenRL-v0.1.7 is updated on Sep 21, 2023
+OpenRL-v0.2.0 is updated on Dec 20, 2023
The main branch is the latest version of OpenRL, which is under active development. If you just want to have a try with
OpenRL, you can switch to the stable branch.
@@ -58,6 +58,8 @@ Currently, the features supported by OpenRL include:
- Reinforcement learning training support for natural language tasks (such as dialogue)
+- Support [DeepSpeed](https://github.com/microsoft/DeepSpeed)
+
- Support [Arena](https://openrl-docs.readthedocs.io/en/latest/arena/index.html) , which allows convenient evaluation of
various agents (even submissions for [JiDi](https://openrl-docs.readthedocs.io/en/latest/arena/index.html#performing-local-evaluation-of-agents-submitted-to-the-jidi-platform-using-openrl)) in a competitive environment.
@@ -160,19 +162,19 @@ Here we provide a table for the comparison of OpenRL and existing popular RL lib
OpenRL employs a modular design and high-level abstraction, allowing users to accomplish training for various tasks
through a unified and user-friendly interface.
-| Library | NLP/RLHF | Multi-agent | Self-Play Training | Offline RL | Bilingual Document |
-|:------------------------------------------------------------------:|:------------------:|:--------------------:|:--------------------:|:------------------:|:------------------:|
-| **[OpenRL](https://github.com/OpenRL-Lab/openrl)** | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
-| [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) | :x: | :x: | :x: | :x: | :x: |
-| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: |
-| [DI-engine](https://github.com/opendilab/DI-engine/) | :x: | :heavy_check_mark: | not fullly supported | :heavy_check_mark: | :heavy_check_mark: |
-| [Tianshou](https://github.com/thu-ml/tianshou) | :x: | not fullly supported | not fullly supported | :heavy_check_mark: | :heavy_check_mark: |
-| [MARLlib](https://github.com/Replicable-MARL/MARLlib) | :x: | :heavy_check_mark: | not fullly supported | :x: | :x: |
-| [MAPPO Benchmark](https://github.com/marlbenchmark/on-policy) | :x: | :heavy_check_mark: | :x: | :x: | :x: |
-| [RL4LMs](https://github.com/allenai/RL4LMs) | :heavy_check_mark: | :x: | :x: | :x: | :x: |
-| [trlx](https://github.com/CarperAI/trlx) | :heavy_check_mark: | :x: | :x: | :x: | :x: |
-| [trl](https://github.com/huggingface/trl) | :heavy_check_mark: | :x: | :x: | :x: | :x: |
-| [TimeChamber](https://github.com/inspirai/TimeChamber) | :x: | :x: | :heavy_check_mark: | :x: | :x: |
+| Library | NLP/RLHF | Multi-agent | Self-Play Training | Offline RL | [DeepSpeed](https://github.com/microsoft/DeepSpeed) |
+|:------------------------------------------------------------------:|:------------------:|:--------------------:|:--------------------:|:------------------:|:--------------------:|
+| **[OpenRL](https://github.com/OpenRL-Lab/openrl)** | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) | :x: | :x: | :x: | :x: | :x: |
+| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: |
+| [DI-engine](https://github.com/opendilab/DI-engine/) | :x: | :heavy_check_mark: | not fully supported | :heavy_check_mark: | :x: |
+| [Tianshou](https://github.com/thu-ml/tianshou) | :x: | not fully supported | not fully supported | :heavy_check_mark: | :x: |
+| [MARLlib](https://github.com/Replicable-MARL/MARLlib) | :x: | :heavy_check_mark: | not fully supported | :x: | :x: |
+| [MAPPO Benchmark](https://github.com/marlbenchmark/on-policy) | :x: | :heavy_check_mark: | :x: | :x: | :x: |
+| [RL4LMs](https://github.com/allenai/RL4LMs) | :heavy_check_mark: | :x: | :x: | :x: | :x: |
+| [trlx](https://github.com/CarperAI/trlx) | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
+| [trl](https://github.com/huggingface/trl) | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
+| [TimeChamber](https://github.com/inspirai/TimeChamber) | :x: | :x: | :heavy_check_mark: | :x: | :x: |
## Installation
@@ -333,7 +335,7 @@ If you are using OpenRL in your research project, you are also welcome to join t
- Join the [slack](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg) group to discuss
OpenRL usage and development with us.
-- Join the [Discord](https://discord.gg/guvAS2up) group to discuss OpenRL usage and development with us.
+- Join the [Discord](https://discord.gg/qMbVT2qBhr) group to discuss OpenRL usage and development with us.
- Send an E-mail to: [huangshiyu@4paradigm.com](huangshiyu@4paradigm.com)
- Join the [GitHub Discussion](https://github.com/orgs/OpenRL-Lab/discussions).
diff --git a/README_zh.md b/README_zh.md
index c8fb4619..e75c76af 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -1,5 +1,5 @@
-
+
@@ -26,10 +26,10 @@
[![Contributors](https://img.shields.io/github/contributors/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/graphs/contributors)
[![GitHub license](https://img.shields.io/github/license/OpenRL-Lab/openrl)](https://github.com/OpenRL-Lab/openrl/blob/master/LICENSE)
-[![Embark](https://img.shields.io/badge/discord-OpenRL-%237289da.svg?logo=discord)](https://discord.gg/guvAS2up)
+[![Embark](https://img.shields.io/badge/discord-OpenRL-%237289da.svg?logo=discord)](https://discord.gg/qMbVT2qBhr)
[![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg)
-OpenRL-v0.1.7 is updated on Sep 21, 2023
+OpenRL-v0.2.0 is updated on Dec 20, 2023
The main branch is the latest version of OpenRL, which is under active development. If you just want to have a try with
OpenRL, you can switch to the stable branch.
@@ -51,6 +51,7 @@ OpenRL基于PyTorch进行开发,目标是为强化学习研究社区提供一
- 支持通过专家数据进行离线强化学习训练
- 支持自博弈训练
- 支持自然语言任务(如对话任务)的强化学习训练
+- 支持[DeepSpeed](https://github.com/microsoft/DeepSpeed)
- 支持[竞技场](https://openrl-docs.readthedocs.io/zh/latest/arena/index.html)功能,可以在多智能体对抗性环境中方便地对各种智能体(甚至是[及第平台](https://openrl-docs.readthedocs.io/zh/latest/arena/index.html#openrl)上提交的智能体)进行评测。
- 支持从[Hugging Face](https://huggingface.co/)上导入模型和数据。支持加载Hugging Face上[Stable-baselines3的模型](https://openrl-docs.readthedocs.io/zh/latest/sb3/index.html)来进行测试和训练。
- 提供用户自有环境接入OpenRL的[详细教程](https://openrl-docs.readthedocs.io/zh/latest/custom_env/index.html).
@@ -128,18 +129,18 @@ OpenRL-Lab将持续维护和更新OpenRL,欢迎大家加入我们的[开源社
这里我们提供了一个表格,比较了OpenRL和其他常用的强化学习库。 OpenRL采用模块化设计和高层次的抽象,使得用户可以通过统一的简单易用的接口完成各种任务的训练。
-| 强化学习库 | 自然语言任务/RLHF | 多智能体训练 | 自博弈训练 | 离线强化学习 | 双语文档 |
+| 强化学习库 | 自然语言任务/RLHF | 多智能体训练 | 自博弈训练 | 离线强化学习 | [DeepSpeed](https://github.com/microsoft/DeepSpeed) |
|:------------------------------------------------------------------:|:------------------:|:--------------------:|:--------------------:|:------------------:|:------------------:|
| **[OpenRL](https://github.com/OpenRL-Lab/openrl)** | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) | :x: | :x: | :x: | :x: | :x: |
| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: |
-| [DI-engine](https://github.com/opendilab/DI-engine/) | :x: | :heavy_check_mark: | not fullly supported | :heavy_check_mark: | :heavy_check_mark: |
-| [Tianshou](https://github.com/thu-ml/tianshou) | :x: | not fullly supported | not fullly supported | :heavy_check_mark: | :heavy_check_mark: |
+| [DI-engine](https://github.com/opendilab/DI-engine/) | :x: | :heavy_check_mark: | not fully supported | :heavy_check_mark: | :x: |
+| [Tianshou](https://github.com/thu-ml/tianshou) | :x: | not fully supported | not fully supported | :heavy_check_mark: | :x: |
| [MARLlib](https://github.com/Replicable-MARL/MARLlib) | :x: | :heavy_check_mark: | not fullly supported | :x: | :x: |
| [MAPPO Benchmark](https://github.com/marlbenchmark/on-policy) | :x: | :heavy_check_mark: | :x: | :x: | :x: |
| [RL4LMs](https://github.com/allenai/RL4LMs) | :heavy_check_mark: | :x: | :x: | :x: | :x: |
-| [trlx](https://github.com/CarperAI/trlx) | :heavy_check_mark: | :x: | :x: | :x: | :x: |
-| [trl](https://github.com/huggingface/trl) | :heavy_check_mark: | :x: | :x: | :x: | :x: |
+| [trlx](https://github.com/CarperAI/trlx) | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
+| [trl](https://github.com/huggingface/trl) | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| [TimeChamber](https://github.com/inspirai/TimeChamber) | :x: | :x: | :heavy_check_mark: | :x: | :x: |
## 安装
@@ -293,7 +294,7 @@ openrl --mode train --env CartPole-v1
- 加入 [slack](https://join.slack.com/t/openrlhq/shared_invite/zt-1tqwpvthd-Eeh0IxQ~DIaGqYXoW2IUQg)
群组,与我们一起讨论OpenRL的使用和开发。
-- 加入 [Discord](https://discord.gg/guvAS2up) 群组,与我们一起讨论OpenRL的使用和开发。
+- 加入 [Discord](https://discord.gg/qMbVT2qBhr) 群组,与我们一起讨论OpenRL的使用和开发。
- 发送邮件到: [huangshiyu@4paradigm.com](huangshiyu@4paradigm.com)
- 加入 [GitHub Discussion](https://github.com/orgs/OpenRL-Lab/discussions)
diff --git a/examples/arena/README.md b/examples/arena/README.md
index e9d59b91..940bea33 100644
--- a/examples/arena/README.md
+++ b/examples/arena/README.md
@@ -3,6 +3,7 @@
```bash
pip install "openrl[selfplay]"
+pip install "pettingzoo[mpe]","pettingzoo[butterfly]"
```
### Usage
@@ -15,3 +16,11 @@ python run_arena.py
### Evaluate Google Research Football submissions for JiDi locally
If you want to evaluate your Google Research Football submissions for JiDi locally, please use tizero as illustrated [here](https://github.com/OpenRL-Lab/TiZero#evaluate-jidi-submissions-locally).
+
+### Evaluate more environments
+
+We also provide a script for evaluating agents on more environments, including MPE, Go, Texas Hold'em, and Butterfly. You can run it as follows:
+
+```shell
+python evaluate_more_envs.py
+```
\ No newline at end of file
diff --git a/examples/arena/evaluate_more_envs.py b/examples/arena/evaluate_more_envs.py
new file mode 100644
index 00000000..3b7bfe07
--- /dev/null
+++ b/examples/arena/evaluate_more_envs.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+from pettingzoo.butterfly import cooperative_pong_v5
+from pettingzoo.classic import connect_four_v3, go_v5, rps_v2, texas_holdem_no_limit_v6
+from pettingzoo.mpe import simple_push_v3
+
+from openrl.arena import make_arena
+from openrl.arena.agents.local_agent import LocalAgent
+from openrl.arena.agents.random_agent import RandomAgent
+from openrl.envs.PettingZoo.registration import register
+from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner
+
+
+def ConnectFourEnv(render_mode, **kwargs):
+ return connect_four_v3.env(render_mode=render_mode)
+
+
+def RockPaperScissorsEnv(render_mode, **kwargs):
+ return rps_v2.env(num_actions=3, max_cycles=15)
+
+
+def GoEnv(render_mode, **kwargs):
+ return go_v5.env(render_mode=render_mode, board_size=5, komi=7.5)
+
+
+def TexasHoldemEnv(render_mode, **kwargs):
+ return texas_holdem_no_limit_v6.env(render_mode=render_mode)
+
+
+# MPE
+def SimplePushEnv(render_mode, **kwargs):
+ return simple_push_v3.env(render_mode=render_mode)
+
+
+def CooperativePongEnv(render_mode, **kwargs):
+ return cooperative_pong_v5.env(render_mode=render_mode)
+
+
+def register_new_envs():
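+ # Map environment ids to constructors and register them with OpenRL's PettingZoo registry so make_arena can build them by id.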
+ new_env_dict = {
+ "connect_four_v3": ConnectFourEnv,
+ "RockPaperScissors": RockPaperScissorsEnv,
+ "go_v5": GoEnv,
+ "texas_holdem_no_limit_v6": TexasHoldemEnv,
+ "simple_push_v3": SimplePushEnv,
+ "cooperative_pong_v5": CooperativePongEnv,
+ }
+
+ for env_id, env in new_env_dict.items():
+ register(env_id, env)
+ return new_env_dict.keys()
+
+
+def run_arena(
+ env_id: str,
+ parallel: bool = True,
+ seed=0,
+ total_games: int = 10,
+ max_game_onetime: int = 5,
+):
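+ # The RecordWinner wrapper records win/loss information so the arena can aggregate results across games.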
+ env_wrappers = [RecordWinner]
+
+ arena = make_arena(env_id, env_wrappers=env_wrappers, use_tqdm=False)
+
+ agent1 = LocalAgent("../selfplay/opponent_templates/random_opponent")
+ agent2 = RandomAgent()
+
+ arena.reset(
+ agents={"agent1": agent1, "agent2": agent2},
+ total_games=total_games,
+ max_game_onetime=max_game_onetime,
+ seed=seed,
+ )
+ result = arena.run(parallel=parallel)
+ arena.close()
+ print(result)
+ return result
+
+
+def test_new_envs():
+ env_ids = register_new_envs()
+ seed = 0
+ for env_id in env_ids:
+ run_arena(env_id=env_id, seed=seed, parallel=False, total_games=1)
+
+
+if __name__ == "__main__":
+ test_new_envs()
diff --git a/examples/arena/run_arena.py b/examples/arena/run_arena.py
index e880884c..fdc0776a 100644
--- a/examples/arena/run_arena.py
+++ b/examples/arena/run_arena.py
@@ -17,6 +17,7 @@
""""""
from openrl.arena import make_arena
from openrl.arena.agents.local_agent import LocalAgent
+from openrl.arena.agents.random_agent import RandomAgent
from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner
@@ -37,7 +38,7 @@ def run_arena(
arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=use_tqdm)
agent1 = LocalAgent("../selfplay/opponent_templates/random_opponent")
- agent2 = LocalAgent("../selfplay/opponent_templates/random_opponent")
+ agent2 = RandomAgent()
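+ # RandomAgent plays random actions directly, so no opponent template needs to be loaded from disk.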
arena.reset(
agents={"agent1": agent1, "agent2": agent2},
@@ -52,5 +53,12 @@ def run_arena(
if __name__ == "__main__":
- run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10)
- # run_arena(render=False, parallel=False, seed=1, total_games=1, max_game_onetime=1,use_tqdm=False)
+ # run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10)
+ run_arena(
+ render=False,
+ parallel=False,
+ seed=1,
+ total_games=300,
+ max_game_onetime=1,
+ use_tqdm=False,
+ )
diff --git a/examples/atari/train_ppo.py b/examples/atari/train_ppo.py
index 4f122c40..5920e819 100644
--- a/examples/atari/train_ppo.py
+++ b/examples/atari/train_ppo.py
@@ -59,7 +59,6 @@ def train():
agent = Agent(net, use_wandb=True)
# start training, set total number of training steps to 20000
- # agent.train(total_time_steps=1000)
agent.train(total_time_steps=5000000)
env.close()
agent.save("./ppo_agent/")
diff --git a/examples/behavior_cloning/test_env.py b/examples/behavior_cloning/test_env.py
index 60b272c6..fe0fa1b7 100644
--- a/examples/behavior_cloning/test_env.py
+++ b/examples/behavior_cloning/test_env.py
@@ -1,4 +1,5 @@
""""""
+
import numpy as np
from openrl.configs.config import create_config_parser
diff --git a/examples/behavior_cloning/train_bc.py b/examples/behavior_cloning/train_bc.py
index 0d562dee..16d2ef2f 100644
--- a/examples/behavior_cloning/train_bc.py
+++ b/examples/behavior_cloning/train_bc.py
@@ -1,4 +1,5 @@
""""""
+
import numpy as np
from openrl.configs.config import create_config_parser
diff --git a/examples/cartpole/train_a2c.py b/examples/cartpole/train_a2c.py
index 415f0bba..35ca95a9 100644
--- a/examples/cartpole/train_a2c.py
+++ b/examples/cartpole/train_a2c.py
@@ -1,4 +1,5 @@
""""""
+
import numpy as np
import torch
diff --git a/examples/cartpole/train_dqn_beta.py b/examples/cartpole/train_dqn_beta.py
index 2dffaa81..3e32ec28 100644
--- a/examples/cartpole/train_dqn_beta.py
+++ b/examples/cartpole/train_dqn_beta.py
@@ -1,4 +1,5 @@
""""""
+
import numpy as np
from openrl.configs.config import create_config_parser
diff --git a/examples/cartpole/train_ppo.py b/examples/cartpole/train_ppo.py
index ee11f871..77e41008 100644
--- a/examples/cartpole/train_ppo.py
+++ b/examples/cartpole/train_ppo.py
@@ -1,4 +1,5 @@
""""""
+
import numpy as np
from openrl.configs.config import create_config_parser
diff --git a/examples/custom_env/pettingzoo_env.py b/examples/custom_env/pettingzoo_env.py
index 8b173449..d5644b7b 100644
--- a/examples/custom_env/pettingzoo_env.py
+++ b/examples/custom_env/pettingzoo_env.py
@@ -25,6 +25,7 @@
from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper
register("RockPaperScissors", RockPaperScissors)
+
env = make(
"RockPaperScissors",
env_num=10,
diff --git a/examples/custom_env/rock_paper_scissors.py b/examples/custom_env/rock_paper_scissors.py
index 2811a1ff..f18e1841 100644
--- a/examples/custom_env/rock_paper_scissors.py
+++ b/examples/custom_env/rock_paper_scissors.py
@@ -18,6 +18,7 @@
import functools
+import time
import gymnasium
import numpy as np
@@ -54,7 +55,7 @@ class RockPaperScissors(AECEnv):
metadata = {"render_modes": ["human"], "name": "rps_v2"}
- def __init__(self, render_mode=None):
+ def __init__(self, id, render_mode=None):
"""
The init method takes in environment arguments and
should define the following attributes:
@@ -122,8 +123,8 @@ def observe(self, agent):
"""
# observation of one agent is the previous state of the other
# return np.array(self.observations[agent])
- obs = np.zeros(4, dtype=np.int64)
- obs[self.observations[agent]] = 1
+ obs = np.zeros([1, 4], dtype=np.int64)
+ obs[0, self.observations[agent]] = 1
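+ # Return the one-hot observation with an explicit leading dimension, i.e. shape (1, 4) instead of (4,).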
return obs
def close(self):
@@ -182,6 +183,7 @@ def step(self, action):
# handles stepping an agent which is already dead
# accepts a None action for the one agent, and moves the agent_selection to
# the next dead agent, or if there are no more dead agents, to the next live agent
+ action = None
self._was_dead_step(action)
return
diff --git a/examples/ddpg/train_ddpg_beta.py b/examples/ddpg/train_ddpg_beta.py
index 2a19f557..7ba61ee0 100644
--- a/examples/ddpg/train_ddpg_beta.py
+++ b/examples/ddpg/train_ddpg_beta.py
@@ -1,4 +1,5 @@
""""""
+
import numpy as np
from openrl.configs.config import create_config_parser
diff --git a/examples/envpool/README.md b/examples/envpool/README.md
new file mode 100644
index 00000000..e9a16389
--- /dev/null
+++ b/examples/envpool/README.md
@@ -0,0 +1,20 @@
+## Installation
+
+
+Install envpool with:
+
+``` shell
+pip install envpool
+```
+
+Note: envpool only supports the Linux operating system.
+
+## Usage
+
+You can use `OpenRL` to train Cartpole (envpool) via:
+
+``` shell
+PYTHON_PATH train_ppo.py
+```
+
+You can also add custom wrappers in `envpool_wrappers.py`. Currently, we provide the `VecAdapter` and `VecMonitor` wrappers.
\ No newline at end of file
diff --git a/examples/envpool/envpool_wrappers.py b/examples/envpool/envpool_wrappers.py
new file mode 100644
index 00000000..bf975166
--- /dev/null
+++ b/examples/envpool/envpool_wrappers.py
@@ -0,0 +1,181 @@
+import time
+import warnings
+from typing import Optional
+
+import gym
+import gymnasium
+import numpy as np
+from envpool.python.protocol import EnvPool
+from packaging import version
+from stable_baselines3.common.vec_env import VecEnvWrapper as BaseWrapper
+from stable_baselines3.common.vec_env import VecMonitor
+from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs, VecEnvStepReturn
+
+is_legacy_gym = version.parse(gym.__version__) < version.parse("0.26.0")
+
+
+class VecEnvWrapper(BaseWrapper):
+ @property
+ def agent_num(self):
+ if self.is_original_envpool_env():
+ return 1
+ else:
+ return self.env.agent_num
+
+ def is_original_envpool_env(self):
+ return not hasattr(self.venv, "agent_num")
+
+
+class VecAdapter(VecEnvWrapper):
+ """
+ Convert EnvPool object to a Stable-Baselines3 (SB3) VecEnv.
+
+ :param venv: The envpool object.
+ """
+
+ def __init__(self, venv: EnvPool):
+ venv.num_envs = venv.spec.config.num_envs
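+ # EnvPool exposes legacy gym spaces; convert them to gymnasium spaces before handing the env to OpenRL.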
+ observation_space = venv.observation_space
+ new_observation_space = gymnasium.spaces.Box(
+ low=observation_space.low,
+ high=observation_space.high,
+ dtype=observation_space.dtype,
+ )
+ action_space = venv.action_space
+ if isinstance(action_space, gym.spaces.Discrete):
+ new_action_space = gymnasium.spaces.Discrete(action_space.n)
+ elif isinstance(action_space, gym.spaces.MultiDiscrete):
+ new_action_space = gymnasium.spaces.MultiDiscrete(action_space.nvec)
+ elif isinstance(action_space, gym.spaces.MultiBinary):
+ new_action_space = gymnasium.spaces.MultiBinary(action_space.n)
+ elif isinstance(action_space, gym.spaces.Box):
+ new_action_space = gymnasium.spaces.Box(
+ low=action_space.low,
+ high=action_space.high,
+ dtype=action_space.dtype,
+ )
+ else:
+ raise NotImplementedError(f"Action space {action_space} is not supported")
+ super().__init__(
+ venv=venv,
+ observation_space=new_observation_space,
+ action_space=new_action_space,
+ )
+
+ def step_async(self, actions: np.ndarray) -> None:
+ self.actions = actions
+
+ def reset(self) -> VecEnvObs:
+ if is_legacy_gym:
+ return self.venv.reset(), {}
+ else:
+ return self.venv.reset()
+
+ def step_wait(self) -> VecEnvStepReturn:
+ if is_legacy_gym:
+ obs, rewards, dones, info_dict = self.venv.step(self.actions)
+ else:
+ obs, rewards, terms, truncs, info_dict = self.venv.step(self.actions)
+ dones = terms + truncs
+ rewards = rewards
+ infos = []
+ for i in range(self.num_envs):
+ infos.append(
+ {
+ key: info_dict[key][i]
+ for key in info_dict.keys()
+ if isinstance(info_dict[key], np.ndarray)
+ }
+ )
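+ # When a sub-environment is done, keep its terminal observation in the info dict and reset only that environment.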
+ if dones[i]:
+ infos[i]["terminal_observation"] = obs[i]
+ if is_legacy_gym:
+ obs[i] = self.venv.reset(np.array([i]))
+ else:
+ obs[i] = self.venv.reset(np.array([i]))[0]
+ return obs, rewards, dones, infos
+
+
+class VecMonitor(VecEnvWrapper):
+ def __init__(
+ self,
+ venv,
+ filename: Optional[str] = None,
+ info_keywords=(),
+ ):
+ # Avoid circular import
+ from stable_baselines3.common.monitor import Monitor, ResultsWriter
+
+ try:
+ is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
+ except AttributeError:
+ is_wrapped_with_monitor = False
+
+ if is_wrapped_with_monitor:
+ warnings.warn(
+ "The environment is already wrapped with a `Monitor` wrapper but you are"
+ " wrapping it with a `VecMonitor` wrapper; the `Monitor` statistics"
+ " will be overwritten by the `VecMonitor` ones.",
+ UserWarning,
+ )
+
+ VecEnvWrapper.__init__(self, venv)
+ self.episode_count = 0
+ self.t_start = time.time()
+
+ env_id = None
+ if hasattr(venv, "spec") and venv.spec is not None:
+ env_id = venv.spec.id
+
+ self.results_writer: Optional[ResultsWriter] = None
+ if filename:
+ self.results_writer = ResultsWriter(
+ filename,
+ header={"t_start": self.t_start, "env_id": str(env_id)},
+ extra_keys=info_keywords,
+ )
+
+ self.info_keywords = info_keywords
+ self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
+ self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+
+ def reset(self, **kwargs) -> VecEnvObs:
+ obs, info = self.venv.reset()
+ self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
+ self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+ return obs, info
+
+ def step_wait(self) -> VecEnvStepReturn:
+ obs, rewards, dones, infos = self.venv.step_wait()
+ self.episode_returns += rewards
+ self.episode_lengths += 1
+ new_infos = list(infos[:])
+ for i in range(len(dones)):
+ if dones[i]:
+ info = infos[i].copy()
+ episode_return = self.episode_returns[i]
+ episode_length = self.episode_lengths[i]
+ episode_info = {
+ "r": episode_return,
+ "l": episode_length,
+ "t": round(time.time() - self.t_start, 6),
+ }
+ for key in self.info_keywords:
+ episode_info[key] = info[key]
+ info["episode"] = episode_info
+ self.episode_count += 1
+ self.episode_returns[i] = 0
+ self.episode_lengths[i] = 0
+ if self.results_writer:
+ self.results_writer.write_row(episode_info)
+ new_infos[i] = info
+ rewards = np.expand_dims(rewards, 1)
+ return obs, rewards, dones, new_infos
+
+ def close(self) -> None:
+ if self.results_writer:
+ self.results_writer.close()
+ return self.venv.close()
+
+
+__all__ = ["VecAdapter", "VecMonitor"]
diff --git a/examples/envpool/make_env.py b/examples/envpool/make_env.py
new file mode 100644
index 00000000..669ca67a
--- /dev/null
+++ b/examples/envpool/make_env.py
@@ -0,0 +1,131 @@
+import copy
+import inspect
+from typing import Callable, Iterable, List, Optional, Union
+
+import envpool
+from gymnasium import Env
+
+from openrl.envs.vec_env import (
+ AsyncVectorEnv,
+ RewardWrapper,
+ SyncVectorEnv,
+ VecMonitorWrapper,
+)
+from openrl.envs.vec_env.vec_info import VecInfoFactory
+from openrl.envs.wrappers.base_wrapper import BaseWrapper
+from openrl.rewards import RewardFactory
+
+
+def build_envs(
+ make,
+ id: str,
+ env_num: int = 1,
+ wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None,
+ need_env_id: bool = False,
+ **kwargs,
+) -> List[Callable[[], Env]]:
+ cfg = kwargs.get("cfg", None)
+
+ def create_env(env_id: int, env_num: int, need_env_id: bool) -> Callable[[], Env]:
+ def _make_env() -> Env:
+ new_kwargs = copy.deepcopy(kwargs)
+ if need_env_id:
+ new_kwargs["env_id"] = env_id
+ new_kwargs["env_num"] = env_num
+ if "envpool" in new_kwargs:
+ # for now envpool doesn't support any render mode
+ # and envpool doesn't store the id anywhere
+ new_kwargs.pop("envpool")
+ env = make(
+ id,
+ **new_kwargs,
+ )
+ env.unwrapped.spec.id = id
+
+ if wrappers is not None:
+ if callable(wrappers):
+ if issubclass(wrappers, BaseWrapper):
+ env = wrappers(env, cfg=cfg)
+ else:
+ env = wrappers(env)
+ elif isinstance(wrappers, Iterable) and all(
+ [callable(w) for w in wrappers]
+ ):
+ for wrapper in wrappers:
+ if (
+ issubclass(wrapper, BaseWrapper)
+ and "cfg" in inspect.signature(wrapper.__init__).parameters
+ ):
+ env = wrapper(env, cfg=cfg)
+ else:
+ env = wrapper(env)
+ else:
+ raise NotImplementedError
+
+ return env
+
+ return _make_env
+
+ env_fns = [create_env(env_id, env_num, need_env_id) for env_id in range(env_num)]
+ return env_fns
+
+
+def make_envpool_envs(
+ id: str,
+ env_num: int = 1,
+ **kwargs,
+):
+ assert "env_type" in kwargs
+ assert kwargs.get("env_type") in ["gym", "dm", "gymnasium"]
+ kwargs["envpool"] = True
+
+ if "env_wrappers" in kwargs:
+ env_wrappers = kwargs.pop("env_wrappers")
+ else:
+ env_wrappers = []
+ env_fns = build_envs(
+ make=envpool.make,
+ id=id,
+ env_num=env_num,
+ wrappers=env_wrappers,
+ **kwargs,
+ )
+ return env_fns
+
+
+def make(
+ id: str,
+ env_num: int = 1,
+ asynchronous: bool = False,
+ add_monitor: bool = True,
+ render_mode: Optional[str] = None,
+ auto_reset: bool = True,
+ **kwargs,
+):
+ cfg = kwargs.get("cfg", None)
+ if id in envpool.registration.list_all_envs():
+ env_fns = make_envpool_envs(
+ id=id.split(":")[-1],
+ env_num=env_num,
+ **kwargs,
+ )
+ if asynchronous:
+ env = AsyncVectorEnv(
+ env_fns, render_mode=render_mode, auto_reset=auto_reset
+ )
+ else:
+ env = SyncVectorEnv(env_fns, render_mode=render_mode, auto_reset=auto_reset)
+
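+ # Wrap the envpool vector env with OpenRL's reward and monitor layers so it behaves like a native OpenRL vector env.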
+ reward_class = cfg.reward_class if cfg else None
+ reward_class = RewardFactory.get_reward_class(reward_class, env)
+
+ env = RewardWrapper(env, reward_class)
+
+ if add_monitor:
+ vec_info_class = cfg.vec_info_class if cfg else None
+ vec_info_class = VecInfoFactory.get_vec_info_class(vec_info_class, env)
+ env = VecMonitorWrapper(vec_info_class, env)
+
+ return env
+ else:
+ raise NotImplementedError(f"env {id} is not supported")
diff --git a/examples/envpool/train_ppo.py b/examples/envpool/train_ppo.py
new file mode 100644
index 00000000..b6550b96
--- /dev/null
+++ b/examples/envpool/train_ppo.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import numpy as np
+from make_env import make
+
+from examples.envpool.envpool_wrappers import VecAdapter, VecMonitor
+from openrl.configs.config import create_config_parser
+from openrl.modules.common import PPONet as Net
+from openrl.runners.common import PPOAgent as Agent
+
+
+def train():
+ # create the neural network
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args()
+
+ # create environment, set environment parallelism to 9
+ env = make(
+ "CartPole-v1",
+ render_mode=None,
+ env_num=9,
+ asynchronous=False,
+ env_wrappers=[VecAdapter, VecMonitor],
+ env_type="gym",
+ )
+
+ net = Net(
+ env,
+ cfg=cfg,
+ )
+ # initialize the trainer
+ agent = Agent(net, use_wandb=False, project_name="CartPole-v1")
+ # start training, set total number of training steps to 20000
+ agent.train(total_time_steps=20000)
+
+ env.close()
+ return agent
+
+
+def evaluation(agent):
+ # begin to test
+ # Create an environment for testing and set the number of environments to interact with to 9.
+ # Rendering is disabled here; set render_mode to "group_human" to watch the evaluation.
+ render_mode = None
+ env = make(
+ "CartPole-v1",
+ env_wrappers=[VecAdapter, VecMonitor],
+ render_mode=render_mode,
+ env_num=9,
+ asynchronous=True,
+ env_type="gym",
+ )
+ # The trained agent sets up the interactive environment it needs.
+ agent.set_env(env)
+ # Initialize the environment and get initial observations and environmental information.
+ obs, info = env.reset()
+ done = False
+ step = 0
+ total_step, total_reward = 0, 0
+ while not np.any(done):
+ # Based on environmental observation input, predict next action.
+ action, _ = agent.act(obs, deterministic=True)
+ obs, r, done, info = env.step(action)
+ step += 1
+ total_step += 1
+ total_reward += np.mean(r)
+ if step % 50 == 0:
+ print(f"{step}: reward:{np.mean(r)}")
+ env.close()
+ print("total step:", total_step)
+ print("total reward:", total_reward)
+
+
+if __name__ == "__main__":
+ agent = train()
+ evaluation(agent)
diff --git a/examples/gail/train_gail.py b/examples/gail/train_gail.py
index abe73039..4e227be9 100644
--- a/examples/gail/train_gail.py
+++ b/examples/gail/train_gail.py
@@ -1,4 +1,5 @@
""""""
+
import numpy as np
from openrl.configs.config import create_config_parser
diff --git a/examples/gridworld/train_dqn.py b/examples/gridworld/train_dqn.py
index 2b859784..900a1287 100644
--- a/examples/gridworld/train_dqn.py
+++ b/examples/gridworld/train_dqn.py
@@ -1,4 +1,5 @@
""""""
+
import numpy as np
from openrl.configs.config import create_config_parser
diff --git a/examples/gridworld/train_ppo.py b/examples/gridworld/train_ppo.py
index 683e9579..71f59bcb 100644
--- a/examples/gridworld/train_ppo.py
+++ b/examples/gridworld/train_ppo.py
@@ -1,4 +1,5 @@
""""""
+
import numpy as np
from openrl.configs.config import create_config_parser
diff --git a/examples/nlp/README.md b/examples/nlp/README.md
index 6bcbb7c0..2fb61de9 100644
--- a/examples/nlp/README.md
+++ b/examples/nlp/README.md
@@ -6,6 +6,14 @@ Users can train the dialog task via:
python train_ppo.py --config nlp_ppo.yaml
```
+Users can train the dialog task with DeepSpeed via:
+
+```shell
+deepspeed train_ppo.py --config nlp_ppo_ds.yaml
+```
+
After the training, users can chat with the agent via:
```shell
diff --git a/examples/nlp/ds_config.json b/examples/nlp/ds_config.json
new file mode 100644
index 00000000..3de0eb2d
--- /dev/null
+++ b/examples/nlp/ds_config.json
@@ -0,0 +1,9 @@
+{
+ "train_batch_size": 32,
+ "train_micro_batch_size_per_gpu": 16,
+ "steps_per_print": 10,
+ "zero_optimization": {
+ "stage": 2
+ },
+ "fp16": {"enabled": false, "loss_scale_window": 100}
+}
\ No newline at end of file
diff --git a/examples/nlp/eval_ds_config.json b/examples/nlp/eval_ds_config.json
new file mode 100644
index 00000000..58c08252
--- /dev/null
+++ b/examples/nlp/eval_ds_config.json
@@ -0,0 +1,10 @@
+{
+ "train_batch_size": 32,
+ "train_micro_batch_size_per_gpu": 16,
+ "steps_per_print": 10,
+ "zero_optimization": {
+ "stage": 0,
+ "offload_param": {"device": "cpu"}
+},
+ "fp16": {"enabled": false}
+}
\ No newline at end of file
diff --git a/examples/nlp/nlp_ppo.yaml b/examples/nlp/nlp_ppo.yaml
index 47da5280..918a75b8 100644
--- a/examples/nlp/nlp_ppo.yaml
+++ b/examples/nlp/nlp_ppo.yaml
@@ -1,19 +1,16 @@
seed: 0
-lr: 1e-6
-critic_lr: 1e-6
+lr: 1e-7
+critic_lr: 1e-7
run_dir: ./run_results/
log_interval: 1
-use_recurrent_policy: true
use_valuenorm: true
use_adv_normalize: true
wandb_entity: "openrl-lab"
ppo_epoch: 5
episode_length: 128
num_mini_batch: 20
-use_share_model: true
-use_amp: true
+
hidden_size: 1
-data_chunk_length: 1
model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog
env:
@@ -25,8 +22,8 @@ vec_info_class:
id: "NLPVecInfo"
reward_class:
id: "NLPReward"
- args: {
- "intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier",
+ args: {
"ref_model": "rajkumarrrk/gpt2-fine-tuned-on-daily-dialog",
+ "intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier",
}
\ No newline at end of file
diff --git a/examples/nlp/nlp_ppo_ds.yaml b/examples/nlp/nlp_ppo_ds.yaml
new file mode 100644
index 00000000..88dac18c
--- /dev/null
+++ b/examples/nlp/nlp_ppo_ds.yaml
@@ -0,0 +1,37 @@
+seed: 0
+lr: 1e-7
+critic_lr: 1e-7
+run_dir: ./run_results/
+log_interval: 1
+use_valuenorm: true
+use_adv_normalize: true
+wandb_entity: "openrl-lab"
+ppo_epoch: 5
+episode_length: 128
+num_mini_batch: 20
+
+hidden_size: 1
+
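+# DeepSpeed settings: ds_config.json is used for training, while the reward models below use eval_ds_config.json.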
+use_deepspeed: true
+use_fp16: false
+use_offload: false
+deepspeed_config: ds_config.json
+
+model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog
+env:
+ args: {
+ 'tokenizer_path': 'gpt2',
+ 'data_path': 'daily_dialog',
+ }
+vec_info_class:
+ id: "NLPVecInfo"
+reward_class:
+ id: "NLPReward"
+ args: {
+ "use_deepspeed": true,
+ "ref_ds_config": "eval_ds_config.json",
+ "ref_model": "rajkumarrrk/gpt2-fine-tuned-on-daily-dialog",
+ "intent_ds_config": "eval_ds_config.json",
+ "intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier",
+ }
+
\ No newline at end of file
diff --git a/examples/nlp/train_ppo.py b/examples/nlp/train_ppo.py
index f549c122..4fefcf52 100644
--- a/examples/nlp/train_ppo.py
+++ b/examples/nlp/train_ppo.py
@@ -1,19 +1,25 @@
""""""
+
from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
-from openrl.modules.networks.policy_value_network_gpt import (
- PolicyValueNetworkGPT as PolicyValueNetwork,
-)
+from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork
+from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork
from openrl.runners.common import PPOAgent as Agent
def train():
# create environment
cfg_parser = create_config_parser()
+ try:
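+ # DeepSpeed is optional: when it is installed, register its command-line arguments on top of OpenRL's config parser.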
+ import deepspeed
+
+ cfg_parser = deepspeed.add_config_arguments(cfg_parser)
+ except ImportError:
+ print("DeepSpeed is not installed; running the NLP task without DeepSpeed.")
cfg = cfg_parser.parse_args()
- env_num = 10
+ env_num = 5
env = make(
"daily_dialog",
env_num=env_num,
@@ -22,7 +28,7 @@ def train():
)
# create the neural network
- model_dict = {"model": PolicyValueNetwork}
+ model_dict = {"policy": PolicyNetwork, "critic": ValueNetwork}
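+ # The policy and the value function are now separate GPT-based networks instead of the shared PolicyValueNetworkGPT, so each can be trained as its own module (e.g. under DeepSpeed).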
net = Net(env, device="cuda", cfg=cfg, model_dict=model_dict)
# initialize the trainer
diff --git a/examples/retro/train_retro.py b/examples/retro/train_retro.py
index ad13749a..0668b620 100644
--- a/examples/retro/train_retro.py
+++ b/examples/retro/train_retro.py
@@ -1,4 +1,5 @@
""""""
+
import numpy as np
from custom_registration import make
diff --git a/examples/sac/train_ddpg.py b/examples/sac/train_ddpg.py
index 484a1f6d..5bc2bab8 100644
--- a/examples/sac/train_ddpg.py
+++ b/examples/sac/train_ddpg.py
@@ -1,4 +1,5 @@
""""""
+
import numpy as np
from openrl.configs.config import create_config_parser
diff --git a/examples/sac/train_sac_beta.py b/examples/sac/train_sac_beta.py
index 9fa905a8..bc40c1dc 100644
--- a/examples/sac/train_sac_beta.py
+++ b/examples/sac/train_sac_beta.py
@@ -1,4 +1,5 @@
""""""
+
import numpy as np
from openrl.configs.config import create_config_parser
diff --git a/examples/selfplay/selfplay.yaml b/examples/selfplay/selfplay.yaml
index 7a7c1bbe..8a05611d 100644
--- a/examples/selfplay/selfplay.yaml
+++ b/examples/selfplay/selfplay.yaml
@@ -1,6 +1,6 @@
globals:
selfplay_api_host: 127.0.0.1
- selfplay_api_port: 10086
+ selfplay_api_port: 13486
seed: 0
selfplay_api:
diff --git a/examples/smac/custom_vecinfo.py b/examples/smac/custom_vecinfo.py
index 52a2b5b2..ba39f6e1 100644
--- a/examples/smac/custom_vecinfo.py
+++ b/examples/smac/custom_vecinfo.py
@@ -41,10 +41,10 @@ def statistics(self, buffer: Any) -> Dict[str, Any]:
assert (
"game_state" in singe_env_info["final_info"].keys()
), "game_state must be in info"
- assert singe_env_info["final_info"]["game_state"] in [
- "win",
- "lose",
- ], "game_state in the final_info must be win or lose"
+ # assert singe_env_info["final_info"]["game_state"] in [
+ # "win",
+ # "lose",
+ # ], "game_state in the final_info must be win or lose"
self.win_history.append(
singe_env_info["final_info"]["game_state"] == "win"
)
diff --git a/examples/smac/train_ppo.py b/examples/smac/train_ppo.py
index 32f8acff..4c03d295 100644
--- a/examples/smac/train_ppo.py
+++ b/examples/smac/train_ppo.py
@@ -25,7 +25,8 @@ def train():
# create environment
env_num = 8
env = make(
- "2s_vs_1sc",
+ "3m",
+ # "2s_vs_1sc",
env_num=env_num,
asynchronous=True,
cfg=cfg,
diff --git a/examples/smacv2/custom_vecinfo.py b/examples/smacv2/custom_vecinfo.py
index 6dd90d00..48fc210d 100644
--- a/examples/smacv2/custom_vecinfo.py
+++ b/examples/smacv2/custom_vecinfo.py
@@ -33,21 +33,21 @@ def __init__(self, *args, **kwargs):
def statistics(self, buffer: Any) -> Dict[str, Any]:
info_dict = super().statistics(buffer)
- """for step_info in self.infos:
+ for step_info in self.infos:
for singe_env_info in step_info:
assert isinstance(singe_env_info, dict), "singe_env_info must be dict"
if "final_info" in singe_env_info.keys():
assert (
"game_state" in singe_env_info["final_info"].keys()
- ), "game_state must be in info"
- assert singe_env_info["final_info"]["game_state"] in [
- "win",
- "lose",
- ], "game_state in the final_info must be win or lose"
+ ), "game_state must be in info"
+ # assert singe_env_info["final_info"]["game_state"] in [
+ # "win",
+ # "lose",
+ # ], "game_state in the final_info must be win or lose"
self.win_history.append(
singe_env_info["final_info"]["game_state"] == "win"
- )"""
+ )
if len(self.win_history) > 0:
info_dict["win_rate"] = np.mean(self.win_history)
diff --git a/examples/toy_env/train_ppo.py b/examples/toy_env/train_ppo.py
index 49cb0c9f..6410b52a 100644
--- a/examples/toy_env/train_ppo.py
+++ b/examples/toy_env/train_ppo.py
@@ -1,4 +1,5 @@
""""""
+
from train_and_eval import evaluation, train
from openrl.modules.common import PPONet as Net
diff --git a/openrl/__init__.py b/openrl/__init__.py
index 00bcaacf..2ea67943 100644
--- a/openrl/__init__.py
+++ b/openrl/__init__.py
@@ -1,5 +1,5 @@
__TITLE__ = "openrl"
-__VERSION__ = "v0.1.7"
+__VERSION__ = "v0.2.0"
__DESCRIPTION__ = "Distributed Deep RL Framework"
__AUTHOR__ = "OpenRL Contributors"
__EMAIL__ = "huangshiyu@4paradigm.com"
diff --git a/openrl/algorithms/dqn.py b/openrl/algorithms/dqn.py
index bbca547b..ebd8d727 100644
--- a/openrl/algorithms/dqn.py
+++ b/openrl/algorithms/dqn.py
@@ -167,7 +167,9 @@ def prepare_loss(
)
q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
- q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # 均方误差损失函数
+ q_loss = torch.mean(
+ F.mse_loss(q_values, q_targets.detach())
+ ) # mean squared error loss
loss_list.append(q_loss)
diff --git a/openrl/algorithms/ppo.py b/openrl/algorithms/ppo.py
index 1c226645..18b5f2c0 100644
--- a/openrl/algorithms/ppo.py
+++ b/openrl/algorithms/ppo.py
@@ -41,10 +41,12 @@ def __init__(
self.use_joint_action_loss = cfg.use_joint_action_loss
super(PPOAlgorithm, self).__init__(cfg, init_module, agent_num, device)
self.train_list = [self.train_ppo]
+ self.use_deepspeed = cfg.use_deepspeed
def ppo_update(self, sample, turn_on=True):
for optimizer in self.algo_module.optimizers.values():
- optimizer.zero_grad()
+ if not self.use_deepspeed:
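+ # DeepSpeed engines reset gradients themselves during step(), so manual zero_grad is only needed for the plain optimizers.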
+ optimizer.zero_grad()
(
critic_obs_batch,
@@ -108,8 +110,18 @@ def ppo_update(self, sample, turn_on=True):
active_masks_batch,
turn_on,
)
- for loss in loss_list:
- loss.backward()
+ if self.use_deepspeed:
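+ # With DeepSpeed, backward must go through the engine so that loss scaling and gradient partitioning are handled correctly.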
+ if self._use_share_model:
+ for loss in loss_list:
+ self.algo_module.models["model"].backward(loss)
+ else:
+ actor_loss = loss_list[0]
+ critic_loss = loss_list[1]
+ self.algo_module.models["policy"].backward(actor_loss)
+ self.algo_module.models["critic"].backward(critic_loss)
+ else:
+ for loss in loss_list:
+ loss.backward()
# else:
if self._use_share_model:
@@ -141,8 +153,15 @@ def ppo_update(self, sample, turn_on=True):
self.algo_module.scaler.update()
else:
- for optimizer in self.algo_module.optimizers.values():
- optimizer.step()
+ if self.use_deepspeed:
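+ # Step each DeepSpeed-managed optimizer explicitly instead of iterating over the optimizer dict.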
+ if self._use_share_model:
+ self.algo_module.optimizers["model"].step()
+ else:
+ self.algo_module.optimizers["policy"].step()
+ self.algo_module.optimizers["critic"].step()
+ else:
+ for optimizer in self.algo_module.optimizers.values():
+ optimizer.step()
if self.world_size > 1:
torch.cuda.synchronize()
@@ -168,7 +187,7 @@ def cal_value_loss(
-self.clip_param, self.clip_param
)
- if self._use_popart or self._use_valuenorm:
+ if (self._use_popart or self._use_valuenorm) and value_normalizer is not None:
value_normalizer.update(return_batch)
error_clipped = (
value_normalizer.normalize(return_batch) - value_pred_clipped
@@ -371,9 +390,12 @@ def train_ppo(self, buffer, turn_on):
].module.value_normalizer
else:
value_normalizer = self.algo_module.get_critic_value_normalizer()
- advantages = buffer.returns[:-1] - value_normalizer.denormalize(
- buffer.value_preds[:-1]
- )
+ if value_normalizer is not None:
+ advantages = buffer.returns[:-1] - value_normalizer.denormalize(
+ buffer.value_preds[:-1]
+ )
+ else:
+ advantages = buffer.returns[:-1] - buffer.value_preds[:-1]
else:
advantages = buffer.returns[:-1] - buffer.value_preds[:-1]
diff --git a/openrl/algorithms/vdn.py b/openrl/algorithms/vdn.py
index f1215c03..83bdb5ed 100644
--- a/openrl/algorithms/vdn.py
+++ b/openrl/algorithms/vdn.py
@@ -211,7 +211,9 @@ def prepare_loss(
rewards_batch = rewards_batch.reshape(-1, self.n_agent, 1)
rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
- q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # 均方误差损失函数
+ q_loss = torch.mean(
+ F.mse_loss(q_values, q_targets.detach())
+ ) # mean squared error loss
loss_list.append(q_loss)
return loss_list
diff --git a/openrl/arena/__init__.py b/openrl/arena/__init__.py
index 4bea924d..cb154a9f 100644
--- a/openrl/arena/__init__.py
+++ b/openrl/arena/__init__.py
@@ -30,9 +30,11 @@ def make_arena(
**kwargs,
):
if custom_build_env is None:
+ from openrl.envs import PettingZoo
+
if (
env_id in pettingzoo_all_envs
- or env_id in openrl.envs.PettingZoo.registration.pettingzoo_env_dict.keys()
+ or env_id in PettingZoo.registration.pettingzoo_env_dict.keys()
):
from openrl.envs.PettingZoo import make_PettingZoo_env
diff --git a/openrl/arena/agents/random_agent.py b/openrl/arena/agents/random_agent.py
new file mode 100644
index 00000000..d09e5e15
--- /dev/null
+++ b/openrl/arena/agents/random_agent.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+from openrl.arena.agents.base_agent import BaseAgent
+from openrl.selfplay.opponents.base_opponent import BaseOpponent
+from openrl.selfplay.opponents.random_opponent import RandomOpponent
+from openrl.selfplay.opponents.utils import load_opponent_from_path
+
+
+class RandomAgent(BaseAgent):
+ def __init__(self):
+ super().__init__()
+
+ def _new_agent(self) -> BaseOpponent:
+ return RandomOpponent()
diff --git a/openrl/arena/games/two_player_game.py b/openrl/arena/games/two_player_game.py
index 7a1b4e0e..40585393 100644
--- a/openrl/arena/games/two_player_game.py
+++ b/openrl/arena/games/two_player_game.py
@@ -31,9 +31,10 @@ def default_dispatch_func(
players: List[str],
agent_names: List[str],
) -> Dict[str, str]:
- assert len(players) == len(
- agent_names
- ), "The number of players must be equal to the number of agents."
+ assert len(players) == len(agent_names), (
+ f"The number of players {len(players)} must be equal to the number of"
+ f" agents: {len(agent_names)}."
+ )
assert len(players) == 2, "The number of players must be equal to 2."
np_random.shuffle(agent_names)
return dict(zip(players, agent_names))
@@ -49,20 +50,21 @@ def _run(self, env_fn: Callable, agents: List[BaseAgent]):
for player, agent in player2agent.items():
agent.reset(env, player)
result = {}
+ truncation_dict = {}
while True:
termination = False
info = {}
for player_name in env.agent_iter():
observation, reward, termination, truncation, info = env.last()
-
- if termination:
+ truncation_dict[player_name] = truncation
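+ # Track truncation per player: the game ends on termination or once every player has been truncated.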
+ if termination or all(truncation_dict.values()):
break
action = player2agent[player_name].act(
player_name, observation, reward, termination, truncation, info
)
env.step(action)
- if termination:
+ if termination or all(truncation_dict.values()):
assert "winners" in info, "The game is terminated but no winners."
assert "losers" in info, "The game is terminated but no losers."
diff --git a/openrl/buffers/offpolicy_replay_data.py b/openrl/buffers/offpolicy_replay_data.py
index 4d62d53f..31e52e85 100644
--- a/openrl/buffers/offpolicy_replay_data.py
+++ b/openrl/buffers/offpolicy_replay_data.py
@@ -97,52 +97,52 @@ def __init__(
)
self.first_insert_flag = True
- def dict_insert(self, data):
- if self._mixed_obs:
- for key in self.critic_obs.keys():
- self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy()
- for key in self.policy_obs.keys():
- self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy()
- for key in self.next_policy_obs.keys():
- self.next_policy_obs[key][self.step + 1] = data["next_policy_obs"][
- key
- ].copy()
- for key in self.next_critic_obs.keys():
- self.next_critic_obs[key][self.step + 1] = data["next_critic_obs"][
- key
- ].copy()
- else:
- self.critic_obs[self.step + 1] = data["critic_obs"].copy()
- self.policy_obs[self.step + 1] = data["policy_obs"].copy()
- self.next_policy_obs[self.step + 1] = data["next_policy_obs"].copy()
- self.next_critic_obs[self.step + 1] = data["next_critic_obs"].copy()
-
- if "rnn_states" in data:
- self.rnn_states[self.step + 1] = data["rnn_states"].copy()
- if "rnn_states_critic" in data:
- self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy()
- if "actions" in data:
- self.actions[self.step + 1] = data["actions"].copy()
- if "action_log_probs" in data:
- self.action_log_probs[self.step] = data["action_log_probs"].copy()
-
- if "value_preds" in data:
- self.value_preds[self.step] = data["value_preds"].copy()
- if "rewards" in data:
- self.rewards[self.step + 1] = data["rewards"].copy()
- if "masks" in data:
- self.masks[self.step + 1] = data["masks"].copy()
-
- if "bad_masks" in data:
- self.bad_masks[self.step + 1] = data["bad_masks"].copy()
- if "active_masks" in data:
- self.active_masks[self.step + 1] = data["active_masks"].copy()
- if "action_masks" in data:
- self.action_masks[self.step + 1] = data["action_masks"].copy()
-
- if (self.step + 1) % self.episode_length != 0:
- self.first_insert_flag = False
- self.step = (self.step + 1) % self.episode_length
+ # def dict_insert(self, data):
+ # if self._mixed_obs:
+ # for key in self.critic_obs.keys():
+ # self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy()
+ # for key in self.policy_obs.keys():
+ # self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy()
+ # for key in self.next_policy_obs.keys():
+ # self.next_policy_obs[key][self.step + 1] = data["next_policy_obs"][
+ # key
+ # ].copy()
+ # for key in self.next_critic_obs.keys():
+ # self.next_critic_obs[key][self.step + 1] = data["next_critic_obs"][
+ # key
+ # ].copy()
+ # else:
+ # self.critic_obs[self.step + 1] = data["critic_obs"].copy()
+ # self.policy_obs[self.step + 1] = data["policy_obs"].copy()
+ # self.next_policy_obs[self.step + 1] = data["next_policy_obs"].copy()
+ # self.next_critic_obs[self.step + 1] = data["next_critic_obs"].copy()
+ #
+ # if "rnn_states" in data:
+ # self.rnn_states[self.step + 1] = data["rnn_states"].copy()
+ # if "rnn_states_critic" in data:
+ # self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy()
+ # if "actions" in data:
+ # self.actions[self.step + 1] = data["actions"].copy()
+ # if "action_log_probs" in data:
+ # self.action_log_probs[self.step] = data["action_log_probs"].copy()
+ #
+ # if "value_preds" in data:
+ # self.value_preds[self.step] = data["value_preds"].copy()
+ # if "rewards" in data:
+ # self.rewards[self.step + 1] = data["rewards"].copy()
+ # if "masks" in data:
+ # self.masks[self.step + 1] = data["masks"].copy()
+ #
+ # if "bad_masks" in data:
+ # self.bad_masks[self.step + 1] = data["bad_masks"].copy()
+ # if "active_masks" in data:
+ # self.active_masks[self.step + 1] = data["active_masks"].copy()
+ # if "action_masks" in data:
+ # self.action_masks[self.step + 1] = data["action_masks"].copy()
+ #
+ # if (self.step + 1) % self.episode_length != 0:
+ # self.first_insert_flag = False
+ # self.step = (self.step + 1) % self.episode_length
def init_buffer(self, raw_obs, action_masks=None):
critic_obs = get_critic_obs(raw_obs)
diff --git a/openrl/buffers/replay_data.py b/openrl/buffers/replay_data.py
index 40a4b383..a8f4c1b7 100644
--- a/openrl/buffers/replay_data.py
+++ b/openrl/buffers/replay_data.py
@@ -198,49 +198,49 @@ def get_batch_data(
else:
return np.concatenate(data[step])
- def all_batch_data(self, data_name: str, min=None, max=None):
- assert hasattr(self, data_name)
- data = getattr(self, data_name)
-
- if isinstance(data, ObsData):
- return data.all_batch(min, max)
- else:
- return data[min:max].reshape((-1, *data.shape[3:]))
-
- def dict_insert(self, data):
- if self._mixed_obs:
- for key in self.critic_obs.keys():
- self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy()
- for key in self.policy_obs.keys():
- self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy()
- else:
- self.critic_obs[self.step + 1] = data["critic_obs"].copy()
- self.policy_obs[self.step + 1] = data["policy_obs"].copy()
-
- if "rnn_states" in data:
- self.rnn_states[self.step + 1] = data["rnn_states"].copy()
- if "rnn_states_critic" in data:
- self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy()
- if "actions" in data:
- self.actions[self.step] = data["actions"].copy()
- if "action_log_probs" in data:
- self.action_log_probs[self.step] = data["action_log_probs"].copy()
-
- if "value_preds" in data:
- self.value_preds[self.step] = data["value_preds"].copy()
- if "rewards" in data:
- self.rewards[self.step] = data["rewards"].copy()
- if "masks" in data:
- self.masks[self.step + 1] = data["masks"].copy()
-
- if "bad_masks" in data:
- self.bad_masks[self.step + 1] = data["bad_masks"].copy()
- if "active_masks" in data:
- self.active_masks[self.step + 1] = data["active_masks"].copy()
- if "action_masks" in data:
- self.action_masks[self.step + 1] = data["action_masks"].copy()
-
- self.step = (self.step + 1) % self.episode_length
+ # def all_batch_data(self, data_name: str, min=None, max=None):
+ # assert hasattr(self, data_name)
+ # data = getattr(self, data_name)
+ #
+ # if isinstance(data, ObsData):
+ # return data.all_batch(min, max)
+ # else:
+ # return data[min:max].reshape((-1, *data.shape[3:]))
+
+ # def dict_insert(self, data):
+ # if self._mixed_obs:
+ # for key in self.critic_obs.keys():
+ # self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy()
+ # for key in self.policy_obs.keys():
+ # self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy()
+ # else:
+ # self.critic_obs[self.step + 1] = data["critic_obs"].copy()
+ # self.policy_obs[self.step + 1] = data["policy_obs"].copy()
+ #
+ # if "rnn_states" in data:
+ # self.rnn_states[self.step + 1] = data["rnn_states"].copy()
+ # if "rnn_states_critic" in data:
+ # self.rnn_states_critic[self.step + 1] = data["rnn_states_critic"].copy()
+ # if "actions" in data:
+ # self.actions[self.step] = data["actions"].copy()
+ # if "action_log_probs" in data:
+ # self.action_log_probs[self.step] = data["action_log_probs"].copy()
+ #
+ # if "value_preds" in data:
+ # self.value_preds[self.step] = data["value_preds"].copy()
+ # if "rewards" in data:
+ # self.rewards[self.step] = data["rewards"].copy()
+ # if "masks" in data:
+ # self.masks[self.step + 1] = data["masks"].copy()
+ #
+ # if "bad_masks" in data:
+ # self.bad_masks[self.step + 1] = data["bad_masks"].copy()
+ # if "active_masks" in data:
+ # self.active_masks[self.step + 1] = data["active_masks"].copy()
+ # if "action_masks" in data:
+ # self.action_masks[self.step + 1] = data["action_masks"].copy()
+ #
+ # self.step = (self.step + 1) % self.episode_length
def insert(
self,
@@ -323,7 +323,9 @@ def compute_returns(self, next_value, value_normalizer=None):
self.value_preds[-1] = next_value
gae = 0
for step in reversed(range(self.rewards.shape[0])):
- if self._use_popart or self._use_valuenorm:
+ if (
+ self._use_popart or self._use_valuenorm
+ ) and value_normalizer is not None:
# step + 1
delta = (
self.rewards[step]
@@ -357,7 +359,9 @@ def compute_returns(self, next_value, value_normalizer=None):
else:
self.returns[-1] = next_value
for step in reversed(range(self.rewards.shape[0])):
- if self._use_popart or self._use_valuenorm:
+ if (
+ self._use_popart or self._use_valuenorm
+ ) and value_normalizer is not None:
self.returns[step] = (
self.returns[step + 1] * self.gamma * self.masks[step + 1]
+ self.rewards[step]
@@ -947,119 +951,119 @@ def naive_recurrent_generator(self, advantages, num_mini_batch):
yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch
- def recurrent_generator_v2(
- self, advantages, num_mini_batch=None, mini_batch_size=None
- ):
- """
- Yield training data for MLP policies.
- :param advantages: (np.ndarray) advantage estimates.
- :param num_mini_batch: (int) number of minibatches to split the batch into.
- :param mini_batch_size: (int) number of samples in each minibatch.
- """
- episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3]
- batch_size = n_rollout_threads * episode_length
-
- if mini_batch_size is None:
- assert (
- batch_size >= num_mini_batch
- ), (
- "PPO requires the number of processes ({}) "
- "* number of steps ({}) = {} "
- "to be greater than or equal to the number of PPO mini batches ({})."
- "".format(
- n_rollout_threads,
- episode_length,
- n_rollout_threads * episode_length,
- num_mini_batch,
- )
- )
- mini_batch_size = batch_size // num_mini_batch
-
- rand = torch.randperm(batch_size).numpy()
- sampler = [
- rand[i * mini_batch_size : (i + 1) * mini_batch_size]
- for i in range(num_mini_batch)
- ]
-
- # keep (num_agent, dim)
- critic_obs = self.critic_obs[:-1].reshape(-1, *self.critic_obs.shape[2:])
-
- policy_obs = self.policy_obs[:-1].reshape(-1, *self.policy_obs.shape[2:])
-
- rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[2:])
-
- rnn_states_critic = self.rnn_states_critic[:-1].reshape(
- -1, *self.rnn_states_critic.shape[2:]
- )
-
- actions = self.actions.reshape(-1, *self.actions.shape[2:])
-
- if self.action_masks is not None:
- action_masks = self.action_masks[:-1].reshape(
- -1, *self.action_masks.shape[2:]
- )
-
- value_preds = self.value_preds[:-1].reshape(-1, *self.value_preds.shape[2:])
-
- returns = self.returns[:-1].reshape(-1, *self.returns.shape[2:])
-
- masks = self.masks[:-1].reshape(-1, *self.masks.shape[2:])
-
- active_masks = self.active_masks[:-1].reshape(-1, *self.active_masks.shape[2:])
-
- action_log_probs = self.action_log_probs.reshape(
- -1, *self.action_log_probs.shape[2:]
- )
-
- advantages = advantages.reshape(-1, *advantages.shape[2:])
-
- shuffle = False
- if shuffle:
- rows, cols = _shuffle_agent_grid(batch_size, num_agents)
-
- if self.action_masks is not None:
- action_masks = action_masks[rows, cols]
- critic_obs = critic_obs[rows, cols]
- policy_obs = policy_obs[rows, cols]
- rnn_states = rnn_states[rows, cols]
- rnn_states_critic = rnn_states_critic[rows, cols]
- actions = actions[rows, cols]
- value_preds = value_preds[rows, cols]
- returns = returns[rows, cols]
- masks = masks[rows, cols]
- active_masks = active_masks[rows, cols]
- action_log_probs = action_log_probs[rows, cols]
- advantages = advantages[rows, cols]
-
- for indices in sampler:
- # [L,T,N,Dim]-->[L*T,N,Dim]-->[index,N,Dim]-->[index*N, Dim]
- critic_obs_batch = critic_obs[indices].reshape(-1, *critic_obs.shape[2:])
- policy_obs_batch = policy_obs[indices].reshape(-1, *policy_obs.shape[2:])
- rnn_states_batch = rnn_states[indices].reshape(-1, *rnn_states.shape[2:])
- rnn_states_critic_batch = rnn_states_critic[indices].reshape(
- -1, *rnn_states_critic.shape[2:]
- )
- actions_batch = actions[indices].reshape(-1, *actions.shape[2:])
- if self.action_masks is not None:
- action_masks_batch = action_masks[indices].reshape(
- -1, *action_masks.shape[2:]
- )
- else:
- action_masks_batch = None
- value_preds_batch = value_preds[indices].reshape(-1, *value_preds.shape[2:])
- return_batch = returns[indices].reshape(-1, *returns.shape[2:])
- masks_batch = masks[indices].reshape(-1, *masks.shape[2:])
- active_masks_batch = active_masks[indices].reshape(
- -1, *active_masks.shape[2:]
- )
- old_action_log_probs_batch = action_log_probs[indices].reshape(
- -1, *action_log_probs.shape[2:]
- )
- if advantages is None:
- adv_targ = None
- else:
- adv_targ = advantages[indices].reshape(-1, *advantages.shape[2:])
- yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch
+ # def recurrent_generator_v2(
+ # self, advantages, num_mini_batch=None, mini_batch_size=None
+ # ):
+ # """
+ # Yield training data for MLP policies.
+ # :param advantages: (np.ndarray) advantage estimates.
+ # :param num_mini_batch: (int) number of minibatches to split the batch into.
+ # :param mini_batch_size: (int) number of samples in each minibatch.
+ # """
+ # episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3]
+ # batch_size = n_rollout_threads * episode_length
+ #
+ # if mini_batch_size is None:
+ # assert (
+ # batch_size >= num_mini_batch
+ # ), (
+ # "PPO requires the number of processes ({}) "
+ # "* number of steps ({}) = {} "
+ # "to be greater than or equal to the number of PPO mini batches ({})."
+ # "".format(
+ # n_rollout_threads,
+ # episode_length,
+ # n_rollout_threads * episode_length,
+ # num_mini_batch,
+ # )
+ # )
+ # mini_batch_size = batch_size // num_mini_batch
+ #
+ # rand = torch.randperm(batch_size).numpy()
+ # sampler = [
+ # rand[i * mini_batch_size : (i + 1) * mini_batch_size]
+ # for i in range(num_mini_batch)
+ # ]
+ #
+ # # keep (num_agent, dim)
+ # critic_obs = self.critic_obs[:-1].reshape(-1, *self.critic_obs.shape[2:])
+ #
+ # policy_obs = self.policy_obs[:-1].reshape(-1, *self.policy_obs.shape[2:])
+ #
+ # rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[2:])
+ #
+ # rnn_states_critic = self.rnn_states_critic[:-1].reshape(
+ # -1, *self.rnn_states_critic.shape[2:]
+ # )
+ #
+ # actions = self.actions.reshape(-1, *self.actions.shape[2:])
+ #
+ # if self.action_masks is not None:
+ # action_masks = self.action_masks[:-1].reshape(
+ # -1, *self.action_masks.shape[2:]
+ # )
+ #
+ # value_preds = self.value_preds[:-1].reshape(-1, *self.value_preds.shape[2:])
+ #
+ # returns = self.returns[:-1].reshape(-1, *self.returns.shape[2:])
+ #
+ # masks = self.masks[:-1].reshape(-1, *self.masks.shape[2:])
+ #
+ # active_masks = self.active_masks[:-1].reshape(-1, *self.active_masks.shape[2:])
+ #
+ # action_log_probs = self.action_log_probs.reshape(
+ # -1, *self.action_log_probs.shape[2:]
+ # )
+ #
+ # advantages = advantages.reshape(-1, *advantages.shape[2:])
+ #
+ # shuffle = False
+ # if shuffle:
+ # rows, cols = _shuffle_agent_grid(batch_size, num_agents)
+ #
+ # if self.action_masks is not None:
+ # action_masks = action_masks[rows, cols]
+ # critic_obs = critic_obs[rows, cols]
+ # policy_obs = policy_obs[rows, cols]
+ # rnn_states = rnn_states[rows, cols]
+ # rnn_states_critic = rnn_states_critic[rows, cols]
+ # actions = actions[rows, cols]
+ # value_preds = value_preds[rows, cols]
+ # returns = returns[rows, cols]
+ # masks = masks[rows, cols]
+ # active_masks = active_masks[rows, cols]
+ # action_log_probs = action_log_probs[rows, cols]
+ # advantages = advantages[rows, cols]
+ #
+ # for indices in sampler:
+ # # [L,T,N,Dim]-->[L*T,N,Dim]-->[index,N,Dim]-->[index*N, Dim]
+ # critic_obs_batch = critic_obs[indices].reshape(-1, *critic_obs.shape[2:])
+ # policy_obs_batch = policy_obs[indices].reshape(-1, *policy_obs.shape[2:])
+ # rnn_states_batch = rnn_states[indices].reshape(-1, *rnn_states.shape[2:])
+ # rnn_states_critic_batch = rnn_states_critic[indices].reshape(
+ # -1, *rnn_states_critic.shape[2:]
+ # )
+ # actions_batch = actions[indices].reshape(-1, *actions.shape[2:])
+ # if self.action_masks is not None:
+ # action_masks_batch = action_masks[indices].reshape(
+ # -1, *action_masks.shape[2:]
+ # )
+ # else:
+ # action_masks_batch = None
+ # value_preds_batch = value_preds[indices].reshape(-1, *value_preds.shape[2:])
+ # return_batch = returns[indices].reshape(-1, *returns.shape[2:])
+ # masks_batch = masks[indices].reshape(-1, *masks.shape[2:])
+ # active_masks_batch = active_masks[indices].reshape(
+ # -1, *active_masks.shape[2:]
+ # )
+ # old_action_log_probs_batch = action_log_probs[indices].reshape(
+ # -1, *action_log_probs.shape[2:]
+ # )
+ # if advantages is None:
+ # adv_targ = None
+ # else:
+ # adv_targ = advantages[indices].reshape(-1, *advantages.shape[2:])
+ # yield critic_obs_batch, policy_obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, action_masks_batch
def recurrent_generator(self, advantages, num_mini_batch, data_chunk_length):
episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3]
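A minimal numeric sketch of the guard added in compute_returns above: when PopArt/value normalization is enabled but no normalizer object is passed in, the raw value predictions are used instead of denormalized ones (the normalizer here is hypothetical).

import numpy as np

use_valuenorm = True          # corresponds to self._use_valuenorm
value_normalizer = None       # caller did not supply a normalizer
value_pred = np.array([0.5])

if use_valuenorm and value_normalizer is not None:
    value = value_normalizer.denormalize(value_pred)  # PopArt / ValueNorm path
else:
    value = value_pred                                # raw value predictions
print(value)  # [0.5]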
diff --git a/openrl/configs/config.py b/openrl/configs/config.py
index 2a616fe6..49bdae79 100644
--- a/openrl/configs/config.py
+++ b/openrl/configs/config.py
@@ -498,13 +498,14 @@ def create_config_parser():
)
parser.add_argument(
"--use_popart",
- action="store_true",
default=False,
+ type=bool,
help="by default False, use PopArt to normalize rewards.",
)
parser.add_argument(
"--dual_clip_ppo",
default=False,
+ type=bool,
help="by default False, use dual-clip ppo.",
)
parser.add_argument(
@@ -618,7 +619,7 @@ def create_config_parser():
)
parser.add_argument(
"--use_average_pool",
- action="store_false",
+ type=bool,
default=True,
help="by default True, use average pooling for attn model.",
)
@@ -730,8 +731,8 @@ def create_config_parser():
)
parser.add_argument(
"--use_gae",
- action="store_false",
default=True,
+ type=bool,
help="use generalized advantage estimation",
)
parser.add_argument(
@@ -748,8 +749,8 @@ def create_config_parser():
)
parser.add_argument(
"--use_proper_time_limits",
- action="store_true",
default=False,
+ type=bool,
help="compute returns taking into account time limits",
)
parser.add_argument(
@@ -1234,5 +1235,29 @@ def create_config_parser():
type=float,
help="newest_weight",
)
+ parser.add_argument(
+ "--use_deepspeed",
+ default=False,
+ type=bool,
+ help="whether to use deepspeed",
+ )
+ parser.add_argument(
+ "--local_rank",
+ default=-1,
+ type=int,
+ help="local_rank",
+ )
+ parser.add_argument(
+ "--use_offload",
+ default=False,
+ type=bool,
+ help="whether to use offload (deepspeed)",
+ )
+ parser.add_argument(
+ "--use_fp16",
+ default=False,
+ type=bool,
+ help="whether to use fp16 (deepspeed)",
+ )
return parser
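A minimal usage sketch of the new DeepSpeed-related options, assuming openrl with this change is installed; it only checks that the flags exist with their documented defaults.

from openrl.configs.config import create_config_parser

cfg = create_config_parser().parse_args([])
print(cfg.use_deepspeed, cfg.local_rank, cfg.use_offload, cfg.use_fp16)
# expected defaults: False -1 False False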
diff --git a/openrl/configs/utils.py b/openrl/configs/utils.py
index 53e1f4d2..a8420767 100644
--- a/openrl/configs/utils.py
+++ b/openrl/configs/utils.py
@@ -16,7 +16,7 @@
""""""
-
+import os
import re
import tempfile
@@ -83,9 +83,19 @@ def __call__(self, parser, cfg, values, option_string=None):
# Load the rendered content as a dictionary
data = yaml.safe_load(rendered_content)
- # Write the result to a temporary file
- with tempfile.NamedTemporaryFile("w", delete=True, suffix=".yaml") as temp_file:
+ # Write the result to a temporary file. This does not work on Windows.
+ # with tempfile.NamedTemporaryFile("w", delete=True, suffix=".yaml") as temp_file:
+ # yaml.dump(data, temp_file)
+ # temp_file.seek(0) # Move to the beginning of the file
+ # # Use the default behavior of ActionConfigFile to handle the temporary file
+ # super().__call__(parser, cfg, temp_file.name, option_string)
+
+ # Write the result to a temporary file. This works on all platforms.
+ temp_fd, temp_filename = tempfile.mkstemp(suffix=".yaml")
+ with os.fdopen(temp_fd, "w") as temp_file:
yaml.dump(data, temp_file)
- temp_file.seek(0) # Move to the beginning of the file
+ try:
# Use the default behavior of ActionConfigFile to handle the temporary file
- super().__call__(parser, cfg, temp_file.name, option_string)
+ super().__call__(parser, cfg, temp_filename, option_string)
+ finally:
+ os.remove(temp_filename)
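A standalone sketch of the cross-platform temporary-file pattern adopted above: mkstemp plus an explicit remove, instead of NamedTemporaryFile with delete=True, which cannot be reopened by name on Windows while still open. The config dict is hypothetical.

import os
import tempfile

import yaml

data = {"seed": 0}  # hypothetical rendered config content

temp_fd, temp_filename = tempfile.mkstemp(suffix=".yaml")
with os.fdopen(temp_fd, "w") as temp_file:
    yaml.dump(data, temp_file)
try:
    with open(temp_filename) as f:  # any consumer that needs a real file path
        print(yaml.safe_load(f))
finally:
    os.remove(temp_filename)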
diff --git a/openrl/envs/__init__.py b/openrl/envs/__init__.py
index a2eb835f..d12c493a 100644
--- a/openrl/envs/__init__.py
+++ b/openrl/envs/__init__.py
@@ -16,12 +16,9 @@
toy_all_envs = [
"BitFlippingEnv",
- "FakeImageEnv",
"IdentityEnv",
"IdentityEnvcontinuous",
"IdentityEnvBox",
- "IdentityEnvMultiBinary",
- "IdentityEnvMultiDiscrete",
"SimpleMultiObsEnv",
"SimpleMultiObsEnv",
]
diff --git a/openrl/envs/common/build_envs.py b/openrl/envs/common/build_envs.py
index 94c34019..76f4b35b 100644
--- a/openrl/envs/common/build_envs.py
+++ b/openrl/envs/common/build_envs.py
@@ -2,6 +2,7 @@
import inspect
from typing import Callable, Iterable, List, Optional, Union
+import gymnasium as gym
from gymnasium import Env
from openrl.envs.wrappers.base_wrapper import BaseWrapper
@@ -33,6 +34,8 @@ def _make_env() -> Env:
if need_env_id:
new_kwargs["env_id"] = env_id
new_kwargs["env_num"] = env_num
+ if id.startswith("ALE/") or id in gym.envs.registry.keys():
+ new_kwargs.pop("cfg", None)
env = make(
id,
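A minimal sketch of the guard added above: OpenRL-specific keyword arguments such as cfg are dropped before delegating to gymnasium's own registry for natively registered ids (the kwargs here are hypothetical).

import gymnasium as gym

env_id = "CartPole-v1"
new_kwargs = {"cfg": None, "render_mode": None}  # hypothetical kwargs

if env_id.startswith("ALE/") or env_id in gym.envs.registry.keys():
    new_kwargs.pop("cfg", None)
print(sorted(new_kwargs))  # ['render_mode']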
diff --git a/openrl/envs/common/registration.py b/openrl/envs/common/registration.py
index 099f2b39..5d1ed645 100644
--- a/openrl/envs/common/registration.py
+++ b/openrl/envs/common/registration.py
@@ -20,6 +20,7 @@
import gymnasium as gym
import openrl
+from openrl.envs.PettingZoo.registration import pettingzoo_env_dict
from openrl.envs.vec_env import (
AsyncVectorEnv,
BaseVecEnv,
@@ -107,7 +108,6 @@ def make(
id=id,
env_num=env_num,
render_mode=convert_render_mode,
- cfg=cfg,
**kwargs,
)
elif id in openrl.envs.toy_all_envs:
@@ -149,10 +149,7 @@ def make(
render_mode=convert_render_mode,
**kwargs,
)
- elif (
- id in openrl.envs.pettingzoo_all_envs
- or id in openrl.envs.PettingZoo.registration.pettingzoo_env_dict.keys()
- ):
+ elif id in openrl.envs.pettingzoo_all_envs or id in pettingzoo_env_dict.keys():
from openrl.envs.PettingZoo import make_PettingZoo_envs
env_fns = make_PettingZoo_envs(
diff --git a/openrl/envs/mpe/rendering.py b/openrl/envs/mpe/rendering.py
index ab1a47db..6dae5d66 100644
--- a/openrl/envs/mpe/rendering.py
+++ b/openrl/envs/mpe/rendering.py
@@ -1,6 +1,7 @@
"""
2D rendering framework
"""
+
from __future__ import division
import os
@@ -26,15 +27,14 @@
try:
from pyglet.gl import *
+
except ImportError:
print(
"Error occured while running `from pyglet.gl import *`",
- (
- "HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
- " install python-opengl'. If you're running on a server, you may need a"
- " virtual frame buffer; something like this should work: 'xvfb-run -s"
- ' "-screen 0 1400x900x24" python \''
- ),
+ "HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
+ " install python-opengl'. If you're running on a server, you may need a"
+ " virtual frame buffer; something like this should work: 'xvfb-run -s"
+ ' "-screen 0 1400x900x24" python \'',
)
import math
@@ -320,28 +320,6 @@ def make_polyline(v):
return PolyLine(v, False)
-def make_capsule(length, width):
- left, r, t, b = 0, length, width / 2, -width / 2
- box = make_polygon([(left, b), (left, t), (r, t), (r, b)])
- circ0 = make_circle(width / 2)
- circ1 = make_circle(width / 2)
- circ1.add_attr(Transform(translation=(length, 0)))
- geom = Compound([box, circ0, circ1])
- return geom
-
-
-class Compound(Geom):
- def __init__(self, gs):
- Geom.__init__(self)
- self.gs = gs
- for g in self.gs:
- g.attrs = [a for a in g.attrs if not isinstance(a, Color)]
-
- def render1(self):
- for g in self.gs:
- g.render()
-
-
class PolyLine(Geom):
def __init__(self, v, close):
Geom.__init__(self)
@@ -373,59 +351,3 @@ def render1(self):
glVertex2f(*self.start)
glVertex2f(*self.end)
glEnd()
-
-
-class Image(Geom):
- def __init__(self, fname, width, height):
- Geom.__init__(self)
- self.width = width
- self.height = height
- img = pyglet.image.load(fname)
- self.img = img
- self.flip = False
-
- def render1(self):
- self.img.blit(
- -self.width / 2, -self.height / 2, width=self.width, height=self.height
- )
-
-
-# ================================================================
-
-
-class SimpleImageViewer(object):
- def __init__(self, display=None):
- self.window = None
- self.isopen = False
- self.display = display
-
- def imshow(self, arr):
- if self.window is None:
- height, width, channels = arr.shape
- self.window = pyglet.window.Window(
- width=width, height=height, display=self.display
- )
- self.width = width
- self.height = height
- self.isopen = True
- assert arr.shape == (
- self.height,
- self.width,
- 3,
- ), "You passed in an image with the wrong number shape"
- image = pyglet.image.ImageData(
- self.width, self.height, "RGB", arr.tobytes(), pitch=self.width * -3
- )
- self.window.clear()
- self.window.switch_to()
- self.window.dispatch_events()
- image.blit(0, 0)
- self.window.flip()
-
- def close(self):
- if self.isopen:
- self.window.close()
- self.isopen = False
-
- def __del__(self):
- self.close()
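A minimal sketch of the failure mode the hint above refers to: on a headless machine the pyglet GL import (or later window creation) can fail, which is why the CI workflow wraps the test run in xvfb-run. The except clause is deliberately broad.

try:
    from pyglet.gl import *  # noqa: F401,F403
    print("OpenGL bindings available")
except Exception as exc:  # e.g. missing libGLU or no X display on a server
    print(f"OpenGL unavailable: {exc}")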
diff --git a/openrl/envs/nlp/daily_dialog_env.py b/openrl/envs/nlp/daily_dialog_env.py
index 0c7a6ff7..d197a232 100644
--- a/openrl/envs/nlp/daily_dialog_env.py
+++ b/openrl/envs/nlp/daily_dialog_env.py
@@ -36,11 +36,24 @@ def __init__(
prompt_truncation_side (str): truncation side for prompt text (Defaults to "left")
"""
- self.debug = cfg.env.args["data_path"] is None
+ self.debug = (
+ cfg.env.args["data_path"] is None or cfg.env.args["data_path"] == "None"
+ )
self.env_name = "daily_dialog"
tokenizer_name = cfg.env.args["tokenizer_path"]
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
+ if tokenizer_name == "builtin_BPE":
+ from tokenizers import Tokenizer, models
+
+ self.tokenizer = Tokenizer(models.BPE())
+
+ self.tokenizer.pad_token = ""
+ self.tokenizer.eos_token = ""
+ self.tokenizer.vocab_size = 2
+ self.tokenizer.name_or_path = "builtin_BPE"
+
+ else:
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side = "left"
@@ -100,7 +113,7 @@ def __init__(
self.__time_step = None
self.reward_function = None
- def set_reward(self, reward_fn):
+ def set_reward(self, reward_fn=None):
self.reward_function = reward_fn
def step_word(self, word: str) -> Tuple[Dict[str, torch.tensor], int, bool, dict]:
diff --git a/openrl/envs/nlp/rewards/intent.py b/openrl/envs/nlp/rewards/intent.py
index 397ea810..bc4da36c 100644
--- a/openrl/envs/nlp/rewards/intent.py
+++ b/openrl/envs/nlp/rewards/intent.py
@@ -9,25 +9,93 @@
from openrl.supports.opengpu.manager import LocalGPUManager
+def get_default_ds_config(offload=True, stage=0, fp16=True):
+ device = "cpu" if offload else "none"
+ zero_opt_dict = {
+ "stage": stage,
+ "offload_param": {"device": device},
+ }
+ return {
+ "train_batch_size": 16,
+ "train_micro_batch_size_per_gpu": 16,
+ "steps_per_print": 10,
+ "zero_optimization": zero_opt_dict,
+ "fp16": {"enabled": fp16},
+ }
+
+
class Intent:
- def __init__(self, intent_model: str, intent_coeff: float = 1.0) -> None:
+ def __init__(
+ self,
+ intent_model: str,
+ intent_coeff: float = 1.0,
+ use_deepspeed: bool = True,
+ ds_config: str = "default",
+ ) -> None:
super().__init__()
self._intent_coeff = intent_coeff
+ self.use_deepspeed = use_deepspeed
+ self.use_half = False
+ self.use_data_parallel = not use_deepspeed # default to use data parallel
+ self.use_model_parallel = False
- model_path = data_abs_path(intent_model)
- self._tokenizer = AutoTokenizer.from_pretrained(intent_model)
- self._model = AutoModelForSequenceClassification.from_pretrained(model_path)
+ if intent_model == "builtin_intent":
+ self._device = "cpu"
+ self.use_data_parallel = False
+
+ from transformers import GPT2Config, GPT2LMHeadModel
+
+ class TestTokenizer:
+ def __call__(
+ self,
+ input_texts,
+ return_tensors="pt",
+ truncation=True,
+ padding=True,
+ max_length=None,
+ ):
+ class EncodedOutput:
+ def __init__(self, input_ids, attention_mask):
+ self.input_ids = input_ids
+ self.attention_mask = attention_mask
+
+ input_ids = torch.zeros((32), dtype=torch.long)
+ attention_masks = torch.zeros((32), dtype=torch.long)
+ return EncodedOutput(input_ids, attention_masks)
+
+ self._tokenizer = TestTokenizer()
+ config = GPT2Config()
+ self._model = GPT2LMHeadModel(config)
- if torch.cuda.is_available():
- manager = LocalGPUManager()
- manager.log_info()
- self._device = f"cuda:{manager.get_gpu()}"
else:
- self._device = "cpu"
- print("Intent Model choose to use device:{}".format(self._device))
+ self._device = "cuda"
+ model_path = data_abs_path(intent_model)
+ self._tokenizer = AutoTokenizer.from_pretrained(intent_model)
+ self._model = AutoModelForSequenceClassification.from_pretrained(model_path)
+
+ if self.use_deepspeed:
+ import deepspeed
- self._model = self._model.to(self._device)
+ if ds_config == "default":
+ ds_config = get_default_ds_config()
+ else:
+ import json
+
+ with open(ds_config) as file:
+ ds_config = json.load(file)
+
+ self._model = self._model.to(self._device)
+ self._model, *_ = deepspeed.initialize(model=self._model, config=ds_config)
+ self.use_fp16 = ds_config["fp16"]["enabled"]
+ else:
+ if self.use_model_parallel:
+ self._model.parallelize()
+ elif self.use_data_parallel:
+ if self.use_half:
+ self._model = self._model.half()
+ self._model = torch.nn.DataParallel(self._model)
+ self._model = self._model.to(self._device)
def __call__(
self,
@@ -58,11 +126,19 @@ def get_input_for_classifier(prompt, generated_text):
input_texts, return_tensors="pt", truncation=True, padding=True
)
+ if self.use_half:
+ encoded.input_ids = encoded.input_ids.int()
+ encoded.attention_mask = encoded.attention_mask.int()
+ else:
+ encoded.input_ids = encoded.input_ids.long()
+ encoded.attention_mask = encoded.attention_mask.long()
+
with torch.no_grad():
outputs = self._model(
input_ids=encoded.input_ids.to(self._device),
attention_mask=encoded.attention_mask.to(self._device),
)
+
pred_labels = torch.argmax(outputs.logits, dim=1).tolist()
score = (np.array(pred_labels) == np.array(target_intents)) * 1.0
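A usage sketch of the helper introduced above; it only prints the shape of the generated DeepSpeed config and needs no GPU (assumes openrl with this change is importable).

from openrl.envs.nlp.rewards.intent import get_default_ds_config

print(get_default_ds_config(offload=False, stage=0, fp16=False))
# {'train_batch_size': 16, 'train_micro_batch_size_per_gpu': 16,
#  'steps_per_print': 10,
#  'zero_optimization': {'stage': 0, 'offload_param': {'device': 'none'}},
#  'fp16': {'enabled': False}}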
diff --git a/openrl/envs/nlp/rewards/kl_penalty.py b/openrl/envs/nlp/rewards/kl_penalty.py
index 039d82d5..c98c6bfb 100644
--- a/openrl/envs/nlp/rewards/kl_penalty.py
+++ b/openrl/envs/nlp/rewards/kl_penalty.py
@@ -10,24 +10,79 @@
from openrl.envs.nlp.utils.distribution import CategoricalDistribution
+def get_default_ds_config(offload=True, stage=0, fp16=True):
+ device = "cpu" if offload else "none"
+ zero_opt_dict = {
+ "stage": stage,
+ "offload_param": {"device": device},
+ }
+ return {
+ "train_batch_size": 16,
+ "train_micro_batch_size_per_gpu": 16,
+ "steps_per_print": 10,
+ "zero_optimization": zero_opt_dict,
+ "fp16": {"enabled": fp16},
+ }
+
+
class KLPenalty(nn.Module):
def __init__(
self,
action_space: gym.Space,
ref_model: str,
apply_model_parallel: bool = True,
+ use_deepspeed: bool = True,
+ ds_config: str = "default",
):
super().__init__()
+ self.device = "cuda"
+ self.use_deepspeed = use_deepspeed
+ self.use_half = False
+ self.use_data_parallel = not use_deepspeed
+ self.use_model_parallel = False
+ assert not (self.use_deepspeed and self.use_data_parallel)
+ assert not (self.use_deepspeed and self.use_model_parallel)
+ assert not (self.use_data_parallel and self.use_model_parallel)
+
# reference model
- self._apply_model_parallel = apply_model_parallel
- self._ref_net = AutoModelForCausalLM.from_pretrained(ref_model)
+ if ref_model == "builtin_ref":
+ self.device = "cpu"
+ self.use_data_parallel = False
+
+ from transformers import GPT2Config, GPT2LMHeadModel
+
+ config = GPT2Config()
+ self._ref_net = GPT2LMHeadModel(config)
+ else:
+ self._ref_net = AutoModelForCausalLM.from_pretrained(ref_model)
self._ref_net = self._ref_net.eval()
- if torch.cuda.is_available():
- if self._apply_model_parallel and self._ref_net.is_parallelizable:
+ if self.use_deepspeed:
+ import deepspeed
+
+ if ds_config == "default":
+ self.use_fp16 = True
+ ds_config = get_default_ds_config()
+ else:
+ import json
+
+ with open(ds_config) as file:
+ ds_config = json.load(file)
+ if "fp16" in ds_config:
+ self.use_fp16 = ds_config["fp16"]["enabled"]
+ else:
+ self.use_fp16 = False
+
+ self._ref_engine, *_ = deepspeed.initialize(model=self, config=ds_config)
+ else:
+ if self.use_model_parallel:
self._ref_net.parallelize()
- else: # else defaults to data parallel
- self._ref_net = torch.nn.DataParallel(self._ref_net)
+ elif self.use_data_parallel: # else defaults to data parallel
+ if self.use_half:
+ self._ref_net = self._ref_net.half()
+ else:
+ self._ref_net = torch.nn.DataParallel(self._ref_net)
+ self._ref_net = self._ref_net.to(self.device)
# alpha adjustment
self._alpha = 0.2
@@ -61,20 +116,31 @@ def __call__(
past_model_kwargs = {
"attention_mask": attention_mask,
}
-
model_inputs = self._prepare_inputs_for_model(
self._ref_net, input_ids, past_model_kwargs
)
+ if self.use_half:
+ for key in ["input_ids", "position_ids", "attention_mask"]:
+ if key in model_inputs:
+ model_inputs[key] = model_inputs[key].int()
+ else:
+ for key in ["input_ids", "position_ids", "attention_mask"]:
+ if key in model_inputs:
+ model_inputs[key] = model_inputs[key].long()
+
with torch.no_grad():
output = self._ref_net(output_hidden_states=True, **model_inputs)
output["past_key_values"] = None
next_token_logits = output.logits[:, -1, :]
+ if self.use_deepspeed and self.use_fp16:
+ next_token_logits = next_token_logits.double()
dist = self._action_dist.proba_distribution(action_logits=next_token_logits)
action_input = actions.to(next_token_logits.device)
ref_log_prob = dist.log_prob(action_input)
ref_log_prob = ref_log_prob.reshape(action_log_probs.shape)
+
kl_div = action_log_probs.copy() - ref_log_prob.detach().cpu().numpy()
rew = -self._alpha * kl_div
infos = []
@@ -97,7 +163,7 @@ def _prepare_inputs_for_model(
input_ids, **model_kwargs
)
- if self._apply_model_parallel and unwrap_model(model).is_parallelizable:
+ if self.use_model_parallel:
# if model is in parallel mode, move the tensors to the first device
model_inputs = {
key: (
@@ -108,4 +174,15 @@ def _prepare_inputs_for_model(
)
for key, value in model_inputs.items()
}
+ elif self.use_data_parallel:
+ model_inputs = {
+ key: value.to(self.device) if isinstance(value, torch.Tensor) else value
+ for key, value in model_inputs.items()
+ }
+ elif self.use_deepspeed:
+ model_inputs = {
+ key: value.to("cuda") if isinstance(value, torch.Tensor) else value
+ for key, value in model_inputs.items()
+ }
+
return model_inputs
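A minimal numeric sketch of the penalty computed in __call__ above: reward = -alpha * (log pi(a|s) - log pi_ref(a|s)) per generated token, with alpha = 0.2 as set in the class; the log-probability values are illustrative.

import numpy as np

alpha = 0.2
action_log_probs = np.array([-1.2, -0.7, -2.3])  # from the trained policy
ref_log_probs = np.array([-1.0, -0.9, -2.0])     # from the frozen reference model

kl_div = action_log_probs - ref_log_probs
rew = -alpha * kl_div
print(rew)  # positive where the policy is less likely than the reference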
diff --git a/openrl/envs/nlp/rewards/meteor.py b/openrl/envs/nlp/rewards/meteor.py
index c9acd16f..5bd169ad 100644
--- a/openrl/envs/nlp/rewards/meteor.py
+++ b/openrl/envs/nlp/rewards/meteor.py
@@ -6,13 +6,21 @@
import openrl.envs.nlp as nlp
+class VirtualMetric:
+ def compute(self, predictions: Any, references: Any) -> Dict[str, float]:
+ return {"meteor": 0.0}
+
+
class Meteor:
- def __init__(self, meteor_coeff: int) -> None:
+ def __init__(self, meteor_coeff: int, test: bool = False) -> None:
super().__init__()
self._meteor_coeff = meteor_coeff
- self._metric = evaluate.load(
- str(Path(nlp.__file__).parent / "utils/metrics/meteor.py")
- )
+ if test:
+ self._metric = VirtualMetric()
+ else:
+ self._metric = evaluate.load(
+ str(Path(nlp.__file__).parent / "utils/metrics/meteor.py")
+ )
def __call__(
self,
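A usage sketch of the test path added above: with test=True the metric files are never loaded and VirtualMetric returns a zero METEOR score, which keeps unit tests lightweight (assumes openrl with this change is importable).

from openrl.envs.nlp.rewards.meteor import Meteor, VirtualMetric

meteor = Meteor(meteor_coeff=1, test=True)
assert isinstance(meteor._metric, VirtualMetric)
print(meteor._metric.compute(predictions=["hi"], references=[["hi"]]))  # {'meteor': 0.0}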
diff --git a/openrl/envs/snake/common.py b/openrl/envs/snake/common.py
deleted file mode 100644
index 6a67a0a3..00000000
--- a/openrl/envs/snake/common.py
+++ /dev/null
@@ -1,227 +0,0 @@
-import os
-import sys
-
-import numpy as np
-
-
-class HiddenPrints:
- def __enter__(self):
- self._original_stdout = sys.stdout
- sys.stdout = open(os.devnull, "w")
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- sys.stdout.close()
- sys.stdout = self._original_stdout
-
-
-class Board:
- def __init__(self, board_height, board_width, snakes, beans_positions, teams):
- # print('create board, beans_position: ', beans_positions)
- self.height = board_height
- self.width = board_width
- self.snakes = snakes
- self.snakes_count = len(snakes)
- self.beans_positions = beans_positions
- self.blank_sign = -self.snakes_count
- self.bean_sign = -self.snakes_count + 1
- self.board = np.zeros((board_height, board_width), dtype=int) + self.blank_sign
- self.open = dict()
- for key, snake in self.snakes.items():
- self.open[key] = [snake.head] # state 0 open list, heads, ready to spread
- # see [A* Pathfinding (E01: algorithm explanation)](https://www.youtube.com/watch?v=-L-WgKMFuhE)
- for x, y in snake.pos:
- self.board[x][y] = key # obstacles, e.g. 0, 1, 2, 3, 4, 5
- # for x, y in beans_positions:
- # self.board[x][y] = self.bean_sign # beans
-
- self.state = 0
- self.controversy = dict()
- self.teams = teams
-
- # print('initial board')
- # print(self.board)
-
- def step(self): # delay: prevent rear-end collision
- new_open = {key: [] for key in self.snakes.keys()}
- self.state += 1 # update state
- # if self.state > delay:
- # for key, snake in self.snakes.items(): # drop tail
- # if snake.len >= self.state:
- # self.board[snake.pos[-(self.state - delay)][0]][snake.pos[-(self.state - delay)][1]] \
- # = self.blank_sign
- for key, snake in self.snakes.items():
- if snake.len >= self.state:
- self.board[snake.pos[-self.state][0]][
- snake.pos[-self.state][1]
- ] = self.blank_sign # drop tail
- for key, value in self.open.items(): # value: e.g. [[8, 3], [6, 3], [7, 4]]
- others_tail_pos = [
- (
- self.snakes[_].pos[-self.state]
- if self.snakes[_].len >= self.state
- else []
- )
- for _ in set(range(self.snakes_count)) - {key}
- ]
- for x, y in value:
- # print('start to spread snake {} on grid ({}, {})'.format(key, x, y))
- for x_, y_ in [
- ((x + 1) % self.height, y), # down
- ((x - 1) % self.height, y), # up
- (x, (y + 1) % self.width), # right
- (x, (y - 1) % self.width),
- ]: # left
- sign = self.board[x_][y_]
- idx = (
- sign % self.snakes_count
- ) # which snake, e.g. 0, 1, 2, 3, 4, 5 / number of claims
- state = (
- sign // self.snakes_count
- ) # manhattan distance to snake who claim the point or its negative
- if sign == self.blank_sign: # grid in initial state
- if [x_, y_] in others_tail_pos:
- # print('do not spread other snakes tail, in case of rear-end collision')
- continue # do not spread other snakes' tail, in case of rear-end collision
- self.board[x_][y_] = self.state * self.snakes_count + key
- self.snakes[key].claimed_count += 1
- new_open[key].append([x_, y_])
-
- elif key != idx and self.state == state:
- # second claim, init controversy, change grid value from + to -
- # print(
- # '\tgird ({}, {}) in the same state claimed by different snakes '
- # 'with sign {}, idx {} and state {}'.format(
- # x_, y_, sign, idx, state))
- if (
- self.snakes[idx].len > self.snakes[key].len
- ): # shorter snake claim the controversial grid
- # print('\t\tsnake {} is shorter than snake {}'.format(key, idx))
- self.snakes[idx].claimed_count -= 1
- new_open[idx].remove([x_, y_])
- self.board[x_][y_] = self.state * self.snakes_count + key
- self.snakes[key].claimed_count += 1
- new_open[key].append([x_, y_])
- elif (
- self.snakes[idx].len == self.snakes[key].len
- ): # controversial claim
- # print(
- # '\t\tcontroversy! first claimed by snake {}, then claimed by snake {}'.format(idx, key))
- self.controversy[(x_, y_)] = {
- "state": self.state,
- "length": self.snakes[idx].len,
- "indexes": [idx, key],
- }
- # first claim by snake idx, then claim by snake key
- self.board[x_][y_] = -self.state * self.snakes_count + 1
- # if + 2, not enough for all snakes claim one grid!!
- self.snakes[
- idx
- ].claimed_count -= (
- 1 # controversy, no snake claim this grid!!
- )
- new_open[key].append([x_, y_])
- else: # (self.snakes[idx].len < self.snakes[key].len)
- pass # longer snake do not claim the controversial grid
-
- elif (
- (x_, y_) in self.controversy
- and key not in self.controversy[(x_, y_)]["indexes"]
- and self.state + state == 0
- ): # third claim or more
- # print('snake {} meets third or more claim in grid ({}, {})'.format(key, x_, y_))
- controversy = self.controversy[(x_, y_)]
- # pprint.pprint(controversy)
- if (
- controversy["length"] > self.snakes[key].len
- ): # shortest snake claim grid, do 4 things
- # print('\t\tsnake {} is shortest'.format(key))
- indexes_count = len(controversy["indexes"])
- for i in controversy["indexes"]:
- self.snakes[i].claimed_count -= (
- 1 / indexes_count
- ) # update claimed_count !
- new_open[i].remove([x_, y_])
- del self.controversy[(x_, y_)]
- self.board[x_][y_] = self.state * self.snakes_count + key
- self.snakes[key].claimed_count += 1
- new_open[key].append([x_, y_])
- elif (
- controversy["length"] == self.snakes[key].len
- ): # controversial claim
- # print('\t\tcontroversy! multi claimed by snake {}'.format(key))
- self.controversy[(x_, y_)]["indexes"].append(key)
- self.board[x_][y_] += 1
- new_open[key].append([x_, y_])
- else: # (controversy['length'] < self.snakes[key].len)
- pass # longer snake do not claim the controversial grid
- else:
- pass # do nothing with lower state grids
-
- self.open = new_open # update open
- # update controversial snakes' claimed_count (in fraction) in the end
- for _, d in self.controversy.items():
- controversial_snake_count = len(
- d["indexes"]
- ) # number of controversial snakes
- for idx in d["indexes"]:
- self.snakes[idx].claimed_count += 1 / controversial_snake_count
-
-
-class SnakePos:
- def __init__(self, snake_positions, board_height, board_width, beans_positions):
- self.pos = snake_positions # [[2, 9], [2, 8], [2, 7]]
- self.len = len(snake_positions) # >= 3
- self.head = snake_positions[0]
- self.beans_positions = beans_positions
- self.claimed_count = 0
-
- displace = [
- (self.head[0] - snake_positions[1][0]) % board_height,
- (self.head[1] - snake_positions[1][1]) % board_width,
- ]
- # print('creat snake, pos: ', self.pos, 'displace:', displace)
- if displace == [
- board_height - 1,
- 0,
- ]: # all action are ordered by left, up, right, relative to the body
- self.dir = 0 # up
- self.legal_action = [2, 0, 3]
- elif displace == [1, 0]:
- self.dir = 1 # down
- self.legal_action = [3, 1, 2]
- elif displace == [0, board_width - 1]:
- self.dir = 2 # left
- self.legal_action = [1, 2, 0]
- elif displace == [0, 1]:
- self.dir = 3 # right
- self.legal_action = [0, 3, 1]
- else:
- assert False, "snake positions error"
- positions = [
- [(self.head[0] - 1) % board_height, self.head[1]],
- [(self.head[0] + 1) % board_height, self.head[1]],
- [self.head[0], (self.head[1] - 1) % board_width],
- [self.head[0], (self.head[1] + 1) % board_width],
- ]
- self.legal_position = [positions[_] for _ in self.legal_action]
-
- def get_action(self, position):
- if position not in self.legal_position:
- assert False, "the start and end points do not match"
- idx = self.legal_position.index(position)
- return self.legal_action[idx] # 0, 1, 2, 3: up, down, left, right
-
- def step(self, legal_input):
- if legal_input in self.legal_position:
- position = legal_input
- elif legal_input in self.legal_action:
- idx = self.legal_action.index(legal_input)
- position = self.legal_position[idx]
- else:
- assert False, "illegal snake move"
- self.head = position
- self.pos.insert(0, position)
- if position in self.beans_positions: # eat a bean
- self.len += 1
- else: # do not eat a bean
- self.pos.pop()
diff --git a/openrl/envs/snake/snake.py b/openrl/envs/snake/snake.py
index 73e81229..4a5be6a5 100644
--- a/openrl/envs/snake/snake.py
+++ b/openrl/envs/snake/snake.py
@@ -674,7 +674,9 @@ class Snake:
def __init__(self, player_id, board_width, board_height, init_len):
self.actions = [-2, 2, -1, 1]
self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"}
- self.direction = random.choice(self.actions) # 方向[-2,2,-1,1]分别表示[上,下,左,右]
+ self.direction = random.choice(
+ self.actions
) # the directions [-2, 2, -1, 1] denote [up, down, left, right] respectively
self.board_width = board_width
self.board_height = board_height
x = random.randrange(0, board_height)
diff --git a/openrl/envs/snake/snake_3v3.py b/openrl/envs/snake/snake_3v3.py
deleted file mode 100644
index 78d787ef..00000000
--- a/openrl/envs/snake/snake_3v3.py
+++ /dev/null
@@ -1,854 +0,0 @@
-# -*- coding:utf-8 -*-
-# 作者:zruizhi
-# 创建时间: 2020/7/30 17:24 下午
-# 描述:
-import copy
-import itertools
-import random
-import time
-from itertools import count
-
-import numpy as np
-from gym import Env, spaces
-from PIL import Image, ImageDraw, ImageFont
-
-from .common import Board, HiddenPrints, SnakePos # TODO: Snake类的重名问题
-from .discrete import Discrete
-from .gridgame import GridGame
-from .observation import *
-
-
-class SnakeEatBeans(GridGame, GridObservation, DictObservation):
- def __init__(self, all_args, env_id):
- self.all_args = all_args
- conf = {
- "class_literal": "SnakeEatBeans",
- "n_player": 6,
- "board_width": 20,
- "board_height": 10,
- "channels": 15,
- "cell_range": 8,
- "n_beans": 5,
- "max_step": 200,
- "game_name": "snakes",
- "is_obs_continuous": False,
- "is_act_continuous": False,
- "agent_nums": [3, 3],
- "obs_type": ["dict", "dict"],
- "save_interval": 100,
- "save_path": "../../replay/snake_3v3/replay_{}.gif",
- }
- self.terminate_flg = False
- colors = conf.get("colors", [(255, 255, 255), (255, 140, 0)])
- super(SnakeEatBeans, self).__init__(conf, colors)
- # 0: 没有 1:食物 2-n_player+1:各玩家蛇身
- self.n_cell_type = self.n_player + 2
- self.step_cnt = 1
- self.n_beans = int(conf["n_beans"])
- # 方向[-2,2,-1,1]分别表示[上,下,左,右]
- self.actions = [-2, 2, -1, 1]
- self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"}
- self.snakes_position = {}
- self.players = []
- self.cur_bean_num = 0
- self.beans_position = []
- # 1<= init_len <= 3
- self.init_len = 3
- self.current_state = self.init_state()
- self.all_observes = self.get_all_observes()
- if self.n_player * self.init_len > self.board_height * self.board_width:
- raise Exception(
- "玩家数量过多:%d,超出board范围:%d,%d"
- % (self.n_player, self.board_width, self.board_height)
- )
-
- self.input_dimension = self.board_width * self.board_height
- self.action_dim = self.get_action_dim()
- self.channels = conf["channels"]
-
- self.num_agents = conf["agent_nums"][0]
- self.num_enemys = conf["agent_nums"][1]
-
- self.observation_space = [
- spaces.Box(
- low=-np.inf,
- high=-np.inf,
- shape=(self.channels, self.board_width, self.board_height),
- dtype=np.float32,
- )
- ]
- self.share_observation_space = []
- self.share_observation_space = [
- spaces.Box(
- low=-np.inf,
- high=+np.inf,
- shape=(self.channels, self.board_width, self.board_height),
- dtype=np.float32,
- )
- ]
- self.action_space = [Discrete(4) for _ in range(self.n_player)]
- self.save_interval = conf["save_interval"]
- self.save_path = conf["save_path"]
- self.episode = 0
- self.render = all_args.save_replay
- self.img_list = []
- self.env_id = env_id
-
- def seed(self, seed=None):
- if seed is None:
- np.random.seed(1)
- else:
- np.random.seed(seed)
-
- def check_win(self):
- flg = self.won.index(max(self.won)) + 2
- return flg
-
- def get_grid_observation(self, current_state, player_id, info_before):
- return current_state
-
- def get_dict_observation(self, current_state, player_id, info_before):
- key_info = {1: self.beans_position}
- for i in range(self.n_player):
- snake = self.players[i]
- key_info[snake.player_id] = snake.segments
- # key_info['state_map'] = current_state
- key_info["board_width"] = self.board_width
- key_info["board_height"] = self.board_height
- key_info["last_direction"] = (
- info_before.get("directions") if isinstance(info_before, dict) else None
- )
- key_info["controlled_snake_index"] = player_id
-
- return key_info
-
- def set_action_space(self):
- action_space = [[Discrete(4)] for _ in range(self.n_player)]
- return action_space
-
- def reset(self):
- self.step_cnt = 1
- self.snakes_position = (
- {}
- ) # 格式类似于{1: [[3, 1], [4, 3], [1, 2], [0, 6], [3, 3]], 2: [[3, 0], [3, 7], [3, 6]], 3: [[2, 7], [1, 7], [0, 7]]}
- self.players = []
- self.cur_bean_num = 0
- self.beans_position = []
- self.current_state = self.init_state()
- self.all_observes = self.get_all_observes()
- self.terminate_flg = False
- self.img_list = []
- self.episode += 1
-
- # available actions
- left_avail_actions = np.ones([self.num_agents, self.action_dim])
- right_avail_actions = np.ones([self.num_enemys, self.action_dim])
- avail_actions = np.concatenate([left_avail_actions, right_avail_actions], 0)
- # process obs
- board = []
- for i in range(self.n_player):
- board.append([self.get_board(self.all_observes[i])])
-
- board_ = np.concatenate(board)
- obs = []
- for raw_obs in self.all_observes:
- obs.append([self.raw2vec(raw_obs)])
- obs_ = np.concatenate(obs)
- obs_ = np.concatenate((obs_, board_), axis=1)
-
- share_obs = np.repeat(np.expand_dims(obs_[0], axis=0), 6, 0)
-
- return obs_, share_obs, avail_actions # obs:(n_player, 288)
-
- # return self.all_observes
-
- def step(self, joint_action):
- info_before = self.step_before_info()
- joint_action = np.expand_dims(joint_action, 1)
- all_observes, info_after = self.get_next_state(joint_action)
- done = self.is_terminal()
- reward = self.get_reward(joint_action)
- left_avail_actions = np.ones([self.num_agents, self.action_dim])
- right_avail_actions = np.ones([self.num_enemys, self.action_dim])
- avail_actions = np.concatenate([left_avail_actions, right_avail_actions], 0)
-
- board = []
- for i in range(self.n_player):
- board.append([self.get_board(all_observes[i])])
-
- board_ = np.concatenate(board)
-
- obs = []
-
- for raw_obs in all_observes:
- obs.append([self.raw2vec(raw_obs)]) # obs:[[(14, 20, 10)], [], ..., []]
-
- obs_ = np.concatenate(obs) # (n_player, channels, width, height)
- obs_ = np.concatenate((obs_, board_), axis=1)
-
- share_obs = np.repeat(np.expand_dims(obs_[0], axis=0), 6, 0)
-
- if done:
- reward = self.get_final_reward(reward)
-
- rewards = np.expand_dims(np.array(reward), axis=1)
-
- dones = [done] * self.n_player
- infos = [info_after] * self.n_player
-
- if self.render and self.episode % self.save_interval == 0 and self.env_id == 0:
- img = self.render_board()
- img_pil = Image.fromarray(img)
- self.img_list.append(img_pil)
-
- if done:
- self.img_list[0].save(
- self.save_path.format(self.episode),
- save_all=True,
- append_images=self.img_list[1:],
- duration=400,
- )
- print("save replay gif to" + self.save_path.format(self.episode))
-
- return obs_, share_obs, rewards, dones, infos, avail_actions
- # return all_observes, reward, done, info_before, info_after
-
- # obs: 0 空白 1 豆子 2 我方蛇头 3 我方蛇身 4-5 友方蛇头 6-7 友方蛇身 8-10 敌方蛇头 11-13 敌方蛇身
- def raw2vec(self, raw_obs):
- control_index = raw_obs["controlled_snake_index"]
- width = raw_obs["board_width"]
- height = raw_obs["board_height"]
- beans = raw_obs[1]
- pos = raw_obs[control_index]
-
- obs = np.zeros(width * height, dtype=int)
- head_h, head_w = pos[0]
- obs[head_h * width + head_w] = 2
-
- for bean in beans:
- h, w = bean
- obs[h * width + w] = 1
-
- for p in pos[1:]:
- h, w = p
- obs[h * width + w] = 3
-
- if control_index == 2:
- h1, w1 = raw_obs[3][0]
- h2, w2 = raw_obs[4][0]
- obs[h1 * width + w1] = 4
- obs[h2 * width + w2] = 5
- for p in raw_obs[3][1:]:
- h, w = p
- obs[h * width + w] = 6
- for p in raw_obs[4][1:]:
- h, w = p
- obs[h * width + w] = 7
- for i in range(self.num_agents + 2, self.n_player + 2):
- h, w = raw_obs[i][0]
- obs[h * width + w] = i + 3
- for p in raw_obs[i][1:]:
- h, w = p
- obs[h * width + w] = i + 6
- elif control_index == 3:
- h1, w1 = raw_obs[2][0]
- h2, w2 = raw_obs[4][0]
- obs[h1 * width + w1] = 4
- obs[h2 * width + w2] = 5
- for p in raw_obs[2][1:]:
- h, w = p
- obs[h * width + w] = 6
- for p in raw_obs[4][1:]:
- h, w = p
- obs[h * width + w] = 7
- for i in range(self.num_agents + 2, self.n_player + 2):
- h, w = raw_obs[i][0]
- obs[h * width + w] = i + 3
- for p in raw_obs[i][1:]:
- h, w = p
- obs[h * width + w] = i + 6
- elif control_index == 4:
- h1, w1 = raw_obs[2][0]
- h2, w2 = raw_obs[3][0]
- obs[h1 * width + w1] = 4
- obs[h2 * width + w2] = 5
- for p in raw_obs[2][1:]:
- h, w = p
- obs[h * width + w] = 6
- for p in raw_obs[3][1:]:
- h, w = p
- obs[h * width + w] = 7
- for i in range(self.num_agents + 2, self.n_player + 2):
- h, w = raw_obs[i][0]
- obs[h * width + w] = i + 3
- for p in raw_obs[i][1:]:
- h, w = p
- obs[h * width + w] = i + 6
- elif control_index == 5:
- h1, w1 = raw_obs[6][0]
- h2, w2 = raw_obs[7][0]
- obs[h1 * width + w1] = 4
- obs[h2 * width + w2] = 5
- for p in raw_obs[6][1:]:
- h, w = p
- obs[h * width + w] = 6
- for p in raw_obs[7][1:]:
- h, w = p
- obs[h * width + w] = 7
- for i in range(2, self.num_agents + 2):
- h, w = raw_obs[i][0]
- obs[h * width + w] = i + 6
- for p in raw_obs[i][1:]:
- h, w = p
- obs[h * width + w] = i + 9
- elif control_index == 6:
- h1, w1 = raw_obs[5][0]
- h2, w2 = raw_obs[7][0]
- obs[h1 * width + w1] = 4
- obs[h2 * width + w2] = 5
- for p in raw_obs[5][1:]:
- h, w = p
- obs[h * width + w] = 6
- for p in raw_obs[7][1:]:
- h, w = p
- obs[h * width + w] = 7
- for i in range(2, self.num_agents + 2):
- h, w = raw_obs[i][0]
- obs[h * width + w] = i + 6
- for p in raw_obs[i][1:]:
- h, w = p
- obs[h * width + w] = i + 9
- else:
- h1, w1 = raw_obs[5][0]
- h2, w2 = raw_obs[6][0]
- obs[h1 * width + w1] = 4
- obs[h2 * width + w2] = 5
- for p in raw_obs[5][1:]:
- h, w = p
- obs[h * width + w] = 6
- for p in raw_obs[6][1:]:
- h, w = p
- obs[h * width + w] = 7
- for i in range(2, self.num_agents + 2):
- h, w = raw_obs[i][0]
- obs[h * width + w] = i + 6
- for p in raw_obs[i][1:]:
- h, w = p
- obs[h * width + w] = i + 9
-
- obs_ = np.zeros(width * height * (self.channels - 1), dtype=int)
- for i in range(width * height):
- obs_[i * (self.channels - 1) + obs[i]] = (
- 1 # channels的最后一维是territory matrix, 此处不生成, 要减去
- )
- obs_ = obs_.reshape(
- height, width, (self.channels - 1)
- ) # (height, width, channels-1 )
- obs_ = obs_.transpose((2, 1, 0))
-
- return obs_
-
- def get_board(self, observation_list):
- observation_len = len(observation_list.keys())
- teams = None
- teams = [[0, 1, 2], [3, 4, 5]] # 3v3
- teams_count = len(teams)
- snakes_count = sum([len(_) for _ in teams])
-
- # read observation
- obs = observation_list.copy()
- board_height = obs["board_height"] # 10
- board_width = obs["board_width"] # 20
- # print("obs['controlled_snake_index'] is ", obs['controlled_snake_index'])
- ctrl_agent_index = obs["controlled_snake_index"] - 2 # 0, 1, 2, 3, 4, 5
- # last_directions = obs['last_direction'] # ['up', 'left', 'down', 'left', 'left', 'left']
- beans_positions = obs[1] # e.g.[[7, 15], [4, 14], [5, 12], [4, 12], [5, 7]]
- snakes = {
- key - 2: SnakePos(obs[key], board_height, board_width, beans_positions)
- for key in obs.keys() & {_ + 2 for _ in range(snakes_count)}
- } # &: intersection
- team_indexes = [_ for _ in teams if ctrl_agent_index in _][0]
-
- init_board = Board(board_height, board_width, snakes, beans_positions, teams)
- bd = copy.deepcopy(init_board)
-
- with HiddenPrints():
- while not all(
- _ == [] for _ in bd.open.values()
- ): # loop until all values in open are empty list
- bd.step()
-
- board = np.array(bd.board).transpose()
- board = np.expand_dims(board, axis=0)
- return board
-
- def init_state(self):
- for i in range(self.n_player):
- s = Snake(i + 2, self.board_width, self.board_height, self.init_len)
- s_len = 1
- while s_len < self.init_len:
- if s_len == 1 and i > 0:
- origin_hit = self.is_hit(s.headPos, self.snakes_position)
- else:
- origin_hit = 0
- cur_head = s.move_and_add(self.snakes_position)
- cur_hit = self.is_hit(cur_head, self.snakes_position) or self.is_hit(
- cur_head, {i: s.segments[1:]}
- )
- if origin_hit or cur_hit:
- x = random.randrange(0, self.board_height)
- y = random.randrange(0, self.board_width)
- s.headPos = [x, y]
- s.segments = [s.headPos]
- s.direction = random.choice(self.actions)
- s_len = 1
- else:
- s_len += 1
- self.snakes_position[s.player_id] = s.segments
- self.players.append(s)
-
- self.generate_beans()
- self.init_info = {
- "snakes_position": [
- list(v)
- for k, v in sorted(
- self.snakes_position.items(), key=lambda item: item[0]
- )
- ],
- "beans_position": list(self.beans_position),
- }
- directs = []
- for i in range(len(self.players)):
- s = self.players[i]
- directs.append(self.actions_name[s.direction])
- self.init_info["directions"] = directs
-
- return self.update_state()
-
- def update_state(self):
- next_state = [
- [[0] * self.cell_dim for _ in range(self.board_width)]
- for _ in range(self.board_height)
- ]
- for i in range(self.n_player):
- snake = self.players[i]
- for pos in snake.segments:
- next_state[pos[0]][pos[1]][0] = i + 2
-
- for pos in self.beans_position:
- next_state[pos[0]][pos[1]][0] = 1
-
- return next_state
-
- def step_before_info(self, info=""):
- directs = []
- for i in range(len(self.players)):
- s = self.players[i]
- directs.append(self.actions_name[s.direction])
- info = {"directions": directs}
-
- return info
-
- def is_hit(self, cur_head, snakes_position):
- is_hit = False
- for k, v in snakes_position.items():
- for pos in v:
- if cur_head == pos:
- is_hit = True
- # print("hit:", cur_head, snakes_position)
- break
- if is_hit:
- break
-
- return is_hit
-
- def generate_beans(self):
- all_valid_positions = set(
- itertools.product(range(0, self.board_height), range(0, self.board_width))
- )
- all_valid_positions = all_valid_positions - set(map(tuple, self.beans_position))
- for positions in self.snakes_position.values():
- all_valid_positions = all_valid_positions - set(map(tuple, positions))
-
- left_bean_num = self.n_beans - self.cur_bean_num
- all_valid_positions = np.array(list(all_valid_positions))
- left_valid_positions = len(all_valid_positions)
-
- new_bean_num = (
- left_bean_num
- if left_valid_positions > left_bean_num
- else left_valid_positions
- )
-
- if left_valid_positions > 0:
- new_bean_positions_idx = np.random.choice(
- left_valid_positions, size=new_bean_num, replace=False
- )
- new_bean_positions = all_valid_positions[new_bean_positions_idx]
- else:
- new_bean_positions = []
-
- for new_bean_pos in new_bean_positions:
- self.beans_position.append(list(new_bean_pos))
- self.cur_bean_num += 1
-
- def get_all_observes(self, before_info=""):
- self.all_observes = []
- for i in range(self.n_player):
- each_obs = self.get_dict_observation(self.current_state, i + 2, before_info)
- self.all_observes.append(each_obs)
-
- return self.all_observes
-
- def get_next_state(self, all_action):
- before_info = self.step_before_info()
- not_valid = self.is_not_valid_action(all_action)
- if not not_valid:
- # 各玩家行动
- # print("current_state", self.current_state)
- eat_snakes = [0] * self.n_player
- ally_reward = 0
- enemy_reward = 0
- for i in range(self.n_player): # 判断是否吃到了豆子
- snake = self.players[i]
- act = self.actions[np.argmax(all_action[i][0])]
- # print(snake.player_id, "此轮的动作为:", self.actions_name[act])
- snake.change_direction(act)
- snake.move_and_add(self.snakes_position) # 更新snake.segment
- if self.be_eaten(snake.headPos): # @yanxue
- snake.snake_reward = 1
- eat_snakes[i] = 1
- else:
- snake.snake_reward = 0
- snake.pop()
- # print(snake.player_id, snake.segments) # @yanxue
- snake_position = [[-1] * self.board_width for _ in range(self.board_height)]
- re_generatelist = [0] * self.n_player
- for i in range(self.n_player): # 判断是否相撞
- snake = self.players[i]
- segment = snake.segments
- for j in range(len(segment)):
- x = segment[j][0]
- y = segment[j][1]
- if snake_position[x][y] != -1:
- if j == 0: # 撞头
- re_generatelist[i] = 1
- compare_snake = self.players[snake_position[x][y]]
- if [x, y] == compare_snake.segments[0]: # 两头相撞won
- re_generatelist[snake_position[x][y]] = 1
- else:
- snake_position[x][y] = i
- for i in range(self.n_player):
- snake = self.players[i]
- if re_generatelist[i] == 1:
- if eat_snakes[i] == 1:
- snake.snake_reward = (
- self.init_len - len(snake.segments) + 1
- ) # 身体越长,惩罚越大
- else:
- snake.snake_reward = self.init_len - len(snake.segments)
- snake.segments = []
-
- for i in range(self.num_agents):
- ally_reward += self.players[i].snake_reward
- for i in range(self.num_enemys):
- enemy_reward += self.players[i + self.num_agents].snake_reward
- alpha = 0.8
- for i in range(self.num_agents):
- self.players[i].snake_reward = (
- self.players[i].snake_reward - enemy_reward / 3
- ) * alpha + ally_reward / 3 * (1 - alpha)
- for i in range(self.num_agents, self.n_player):
- self.players[i].snake_reward = (
- self.players[i].snake_reward - ally_reward / 3
- ) * alpha + enemy_reward / 3 * (1 - alpha)
-
- for i in range(self.n_player):
- snake = self.players[i]
- if re_generatelist[i] == 1:
- snake = self.clear_or_regenerate(snake)
- self.snakes_position[snake.player_id] = snake.segments
- snake.score = snake.get_score()
- # yanxue add
- # 更新状态
- self.generate_beans()
-
- next_state = self.update_state()
- self.current_state = next_state
- self.step_cnt += 1
-
- self.won = [0] * self.n_player
-
- for i in range(self.n_player):
- s = self.players[i]
- self.won[i] = s.score
- info_after = {}
- info_after["snakes_position"] = [
- list(v)
- for k, v in sorted(
- self.snakes_position.items(), key=lambda item: item[0]
- )
- ]
- info_after["beans_position"] = list(self.beans_position)
- info_after["hit"] = re_generatelist
- info_after["score"] = self.won
- self.all_observes = self.get_all_observes(before_info)
-
- return self.all_observes, info_after
-
- def clear_or_regenerate(self, snake):
- direct_x = [0, 1, -1, 0]
- direct_y = [1, 0, 0, -1]
- snake.segments = []
- snake.score = 0
- grid = self.get_render_data(self.update_state())
-
- def can_regenerate():
- for x in range(self.board_height):
- for y in range(self.board_width):
- if grid[x][y] == 0:
- q = []
- q.append([x, y])
- seg = []
- while q:
- cur = q.pop(0)
- if cur not in seg:
- seg.append(cur)
- for i in range(4):
- nx = (direct_x[i] + cur[0]) % self.board_height
- ny = (direct_y[i] + cur[1]) % self.board_width
- # if nx < 0 or nx >= self.board_height or ny < 0 or ny >= self.board_width:
- # continue
- if grid[nx][ny] == 0 and [nx, ny] not in q:
- grid[nx][ny] = 1
- q.append([nx, ny])
- if len(seg) == self.init_len:
- # print("regenerate")
- if len(seg) < 3:
- snake.direction = random.choice(self.actions)
- elif len(seg) == 3:
- mid = (
- [seg[1][0], seg[2][1]],
- [seg[2][0], seg[1][1]],
- )
- if seg[0] in mid:
- seg[0], seg[1] = seg[1], seg[0]
- snake.segments = seg
- snake.headPos = seg[0]
- if seg[0][0] == seg[1][0]:
- # 右
- if seg[0][1] > seg[1][1]:
- snake.direction = 1
- # 左
- else:
- snake.direction = -1
- elif seg[0][1] == seg[1][1]:
- # 下
- if seg[0][0] > seg[1][0]:
- snake.direction = 2
- # 上
- else:
- snake.direction = -2
- # print("re head", snake.headPos) # 输出重新生成的蛇
- # print("re snakes segments", snake.segments)
- return True
- # print("clear")
- return False
-
- flg = can_regenerate()
- if not flg:
- self.terminate_flg = True
- # print(self.terminate_flg)
- return snake
-
- def is_not_valid_action(self, all_action):
- not_valid = 0
- if len(all_action) != self.n_player:
- raise Exception("all action 维度不正确!", len(all_action))
-
- for i in range(self.n_player):
- if len(all_action[i][0]) != 4:
- raise Exception("玩家%d joint action维度不正确!" % i, all_action[i])
- return not_valid
-
- def get_reward(self, all_action):
- r = [0] * self.n_player
- for i in range(self.n_player):
- r[i] = self.players[i].snake_reward
- self.n_return[i] += r[i]
- # print("score:", self.won)
- return r
-
- def get_final_reward(self, reward):
- ally_reward = reward[0] + reward[1] + reward[2]
- enemy_reward = reward[3] + reward[4] + reward[5]
- if ally_reward > enemy_reward:
- reward[0] += 10
- reward[1] += 10
- reward[2] += 10
- reward[3] -= 10
- reward[4] -= 10
- reward[5] -= 10
- elif ally_reward < enemy_reward:
- reward[3] += 10
- reward[4] += 10
- reward[5] += 10
- reward[0] -= 10
- reward[1] -= 10
- reward[2] -= 10
- return reward
-
- def is_terminal(self):
- all_member = self.n_beans
- # all_member = len(self.beans_position)
- for s in self.players:
- all_member += len(s.segments)
- is_done = (
- self.step_cnt > self.max_step
- or all_member > self.board_height * self.board_width
- )
-
- return is_done or self.terminate_flg
-
- def encode(self, actions):
- joint_action = self.init_action_space()
- if len(actions) != self.n_player:
- raise Exception("action输入维度不正确!", len(actions))
- for i in range(self.n_player):
- joint_action[i][0][int(actions[i])] = 1
- return joint_action
-
- def get_terminal_actions(self):
- print("请输入%d个玩家的动作方向[0-3](上下左右),空格隔开:" % self.n_player)
- cur = input()
- actions = cur.split(" ")
- return self.encode(actions)
-
- def be_eaten(self, snake_pos):
- for bean in self.beans_position:
- if snake_pos[0] == bean[0] and snake_pos[1] == bean[1]:
- self.beans_position.remove(bean)
- self.cur_bean_num -= 1
- return True
- return False
-
- def get_action_dim(self):
- action_dim = 1
- for i in range(len(self.joint_action_space[0])):
- action_dim *= self.joint_action_space[0][i].n
-
- return action_dim
-
- def draw_board(self):
- cols = [chr(i) for i in range(65, 65 + self.board_width)]
- s = ", ".join(cols)
- print(" ", s)
- for i in range(self.board_height):
- # print(i)
- print(chr(i + 65), self.current_state[i])
-
- @staticmethod
- def _render_board(state, board, colors, unit, fix, extra_info):
- im = GridGame._render_board(state, board, colors, unit, fix)
- draw = ImageDraw.Draw(im)
- # fnt = ImageFont.truetype("Courier.dfont", 16)
- fnt = ImageFont.load_default()
- for i, pos in zip(count(1), extra_info):
- x, y = pos
- draw.text(
- ((y + 1 / 4) * unit, (x + 1 / 4) * unit),
- "#{}".format(i),
- font=fnt,
- fill=(0, 0, 0),
- )
-
- return im
-
- def render_board(self):
- extra_info = [tuple(x.headPos) for x in self.players]
- im_data = np.array(
- SnakeEatBeans._render_board(
- self.get_render_data(self.current_state),
- self.grid,
- self.colors,
- self.grid_unit,
- self.grid_unit_fix,
- extra_info,
- )
- )
- return im_data
-
- @staticmethod
- def parse_extra_info(data):
- # return eval(re.search(r'({.*})', data['info_after']).group(1)).values()
- # d = (eval(eval(data)['snakes_position']).values())
- if isinstance(data, str):
- d = eval(data)["snakes_position"]
- else:
- d = data["snakes_position"]
-
- return [i[0] for i in d]
-
-
-class Snake:
- def __init__(self, player_id, board_width, board_height, init_len):
- self.actions = [-2, 2, -1, 1]
- self.actions_name = {-2: "up", 2: "down", -1: "left", 1: "right"}
- self.direction = random.choice(self.actions) # directions [-2, 2, -1, 1] mean [up, down, left, right]
- self.board_width = board_width
- self.board_height = board_height
- x = random.randrange(0, board_height)
- y = random.randrange(0, board_width)
- self.segments = [[x, y]]
- self.headPos = self.segments[0]
- self.player_id = player_id
- self.score = 0
- self.snake_reward = 0
- self.init_len = init_len
-
- def get_score(self):
- return len(self.segments) - self.init_len
-
- def change_direction(self, act):
- if act + self.direction != 0:
- self.direction = act
- else:
- n_direct = random.choice(self.actions)
- while n_direct + self.direction == 0:
- n_direct = random.choice(self.actions)
- self.direction = n_direct
- # print("方向不合法,重新生成")
- # print("direction", self.actions_name[self.direction])
-
- # crossing the board boundary wraps around to the other side
- def update_position(self, position):
- position[0] %= self.board_height
- position[1] %= self.board_width
- return position
-
- def move_and_add(self, snakes_position):
- cur_head = list(self.headPos)
- # move the head coordinate according to the current direction
- # right
- if self.direction == 1:
- cur_head[1] += 1
- # left
- if self.direction == -1:
- cur_head[1] -= 1
- # up
- if self.direction == -2:
- cur_head[0] -= 1
- # down
- if self.direction == 2:
- cur_head[0] += 1
-
- cur_head = self.update_position(cur_head)
- # print("cur head", cur_head)
- # print("cur snakes positions", snakes_position)
-
- self.segments.insert(0, cur_head)
- self.headPos = self.segments[0]
- return cur_head
-
- def pop(self):
- self.segments.pop() # remove one cell from the tail
diff --git a/openrl/envs/toy_envs/__init__.py b/openrl/envs/toy_envs/__init__.py
index 4e6588ef..cf785cc5 100644
--- a/openrl/envs/toy_envs/__init__.py
+++ b/openrl/envs/toy_envs/__init__.py
@@ -18,25 +18,12 @@
from typing import Any
from openrl.envs.toy_envs.bit_flipping_env import BitFlippingEnv
-from openrl.envs.toy_envs.identity_env import (
- FakeImageEnv,
- IdentityEnv,
- IdentityEnvBox,
- IdentityEnvcontinuous,
- IdentityEnvMultiBinary,
- IdentityEnvMultiDiscrete,
-)
-from openrl.envs.toy_envs.multi_input_envs import SimpleMultiObsEnv
+from openrl.envs.toy_envs.identity_env import IdentityEnv, IdentityEnvcontinuous
__all__ = [
"BitFlippingEnv",
- "FakeImageEnv",
"IdentityEnv",
"IdentityEnvcontinuous",
- "IdentityEnvBox",
- "IdentityEnvMultiBinary",
- "IdentityEnvMultiDiscrete",
- "SimpleMultiObsEnv",
]
@@ -49,13 +36,8 @@
env_dict = {
"BitFlippingEnv": BitFlippingEnv,
- "FakeImageEnv": FakeImageEnv,
"IdentityEnv": IdentityEnv,
"IdentityEnvcontinuous": IdentityEnvcontinuous,
- "IdentityEnvBox": IdentityEnvBox,
- "IdentityEnvMultiBinary": IdentityEnvMultiBinary,
- "IdentityEnvMultiDiscrete": IdentityEnvMultiDiscrete,
- "SimpleMultiObsEnv": SimpleMultiObsEnv,
}
diff --git a/openrl/envs/toy_envs/bit_flipping_env.py b/openrl/envs/toy_envs/bit_flipping_env.py
index 0d77ebd4..5534ed37 100644
--- a/openrl/envs/toy_envs/bit_flipping_env.py
+++ b/openrl/envs/toy_envs/bit_flipping_env.py
@@ -5,8 +5,6 @@
from gymnasium import Env, spaces
from gymnasium.envs.registration import EnvSpec
-from openrl.utils.type_aliases import GymStepReturn
-
class BitFlippingEnv(Env):
"""
@@ -175,7 +173,7 @@ def reset(
self.state = self.obs_space.sample()
return self._get_obs(), {}
- def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
+ def step(self, action: Union[np.ndarray, int]):
if self.continuous:
self.state[action > 0] = 1 - self.state[action > 0]
else:
diff --git a/openrl/envs/toy_envs/identity_env.py b/openrl/envs/toy_envs/identity_env.py
index dd653626..c3867756 100644
--- a/openrl/envs/toy_envs/identity_env.py
+++ b/openrl/envs/toy_envs/identity_env.py
@@ -6,8 +6,6 @@
from gymnasium.envs.registration import EnvSpec
from gymnasium.utils import seeding
-from openrl.utils.type_aliases import GymStepReturn
-
T = TypeVar("T", int, np.ndarray)
@@ -30,6 +28,7 @@ def __init__(
``dim`` and ``space``.
:param ep_length: the length of each episode in time_steps
"""
+
if space is None:
if dim is None:
dim = 2
@@ -38,12 +37,14 @@ def __init__(
assert (
dim is None
), "arguments for both 'dim' and 'space' provided: at most one allowed"
+
self.dim = dim
self.observation_space = spaces.Discrete(1)
self.action_space = space
self.ep_length = ep_length
self.current_step = 0
self.num_resets = -1 # Becomes 0 after __init__ exits.
+ self.metadata.update({"name": IdentityEnv})
def reset(
self,
@@ -53,6 +54,8 @@ def reset(
) -> T:
if seed is not None:
self.seed(seed)
+ if self._np_random is None:
+ self.seed(0)
self.current_step = 0
self.num_resets += 1
self._choose_next_state()
@@ -67,6 +70,7 @@ def step(self, action: T) -> Tuple[T, float, bool, Dict[str, Any]]:
def _choose_next_state(self) -> None:
# self.state = [self.action_space.sample()]
+ assert self.dim is not None
self.state = [self._np_random.integers(0, self.dim)]
def _get_reward(self, action: T) -> float:
@@ -153,114 +157,3 @@ def _get_reward(self, action: T) -> float:
def render(self, mode: str = "human") -> None:
pass
-
-
-# Not Work Yet
-class IdentityEnvBox(IdentityEnv[np.ndarray]):
- def __init__(
- self,
- low: float = -1.0,
- high: float = 1.0,
- eps: float = 0.05,
- ep_length: int = 100,
- ):
- """
- Identity environment for testing purposes
-
- :param low: the lower bound of the box dim
- :param high: the upper bound of the box dim
- :param eps: the epsilon bound for correct value
- :param ep_length: the length of each episode in timesteps
- """
- space = spaces.Box(low=low, high=high, shape=(1,), dtype=np.float32)
- super().__init__(ep_length=ep_length, space=space)
- self.eps = eps
-
- def step(
- self, action: np.ndarray
- ) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
- reward = self._get_reward(action)
- self._choose_next_state()
- self.current_step += 1
- done = self.current_step >= self.ep_length
- return self.state, reward, done, {}
-
- def _get_reward(self, action: np.ndarray) -> float:
- return (
- 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0
- )
-
-
-# Not Work Yet
-class IdentityEnvMultiDiscrete(IdentityEnv[np.ndarray]):
- def __init__(self, dim: int = 1, ep_length: int = 100) -> None:
- """
- Identity environment for testing purposes
-
- :param dim: the size of the dimensions you want to learn
- :param ep_length: the length of each episode in timesteps
- """
- space = spaces.MultiDiscrete([dim, dim])
- super().__init__(ep_length=ep_length, space=space)
-
-
-# Not Work Yet
-class IdentityEnvMultiBinary(IdentityEnv[np.ndarray]):
- def __init__(self, dim: int = 1, ep_length: int = 100) -> None:
- """
- Identity environment for testing purposes
-
- :param dim: the size of the dimensions you want to learn
- :param ep_length: the length of each episode in timesteps
- """
- space = spaces.MultiBinary(dim)
- super().__init__(ep_length=ep_length, space=space)
-
-
-# Not Work Yet
-class FakeImageEnv(gym.Env):
- """
- Fake image environment for testing purposes, it mimics Atari games.
-
- :param action_dim: Number of discrete actions
- :param screen_height: Height of the image
- :param screen_width: Width of the image
- :param n_channels: Number of color channels
- :param discrete: Create discrete action space instead of continuous
- :param channel_first: Put channels on first axis instead of last
- """
-
- def __init__(
- self,
- action_dim: int = 6,
- screen_height: int = 84,
- screen_width: int = 84,
- n_channels: int = 1,
- discrete: bool = True,
- channel_first: bool = False,
- ) -> None:
- self.observation_shape = (screen_height, screen_width, n_channels)
- if channel_first:
- self.observation_shape = (n_channels, screen_height, screen_width)
- self.observation_space = spaces.Box(
- low=0, high=255, shape=self.observation_shape, dtype=np.uint8
- )
- if discrete:
- self.action_space = spaces.Discrete(action_dim)
- else:
- self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
- self.ep_length = 10
- self.current_step = 0
-
- def reset(self) -> np.ndarray:
- self.current_step = 0
- return self.observation_space.sample()
-
- def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
- reward = 0.0
- self.current_step += 1
- done = self.current_step >= self.ep_length
- return self.observation_space.sample(), reward, done, {}
-
- def render(self, mode: str = "human") -> None:
- pass
diff --git a/openrl/envs/toy_envs/multi_input_envs.py b/openrl/envs/toy_envs/multi_input_envs.py
deleted file mode 100644
index 952a5b04..00000000
--- a/openrl/envs/toy_envs/multi_input_envs.py
+++ /dev/null
@@ -1,187 +0,0 @@
-from typing import Dict, Union
-
-import gymnasium as gym
-import numpy as np
-from gymnasium import spaces
-
-from openrl.utils.type_aliases import GymStepReturn
-
-
-# Not Work Yet
-class SimpleMultiObsEnv(gym.Env):
- """
- Base class for GridWorld-based MultiObs Environments 4x4 grid world.
-
- .. code-block:: text
-
- ____________
- | 0 1 2 3|
- | 4|¯5¯¯6¯| 7|
- | 8|_9_10_|11|
- |12 13 14 15|
- ¯¯¯¯¯¯¯¯¯¯¯¯¯¯
-
- start is 0
- states 5, 6, 9, and 10 are blocked
- goal is 15
- actions are = [left, down, right, up]
-
- simple linear state env of 15 states but encoded with a vector and an image observation:
- each column is represented by a random vector and each row is
- represented by a random image, both sampled once at creation time.
-
- :param num_col: Number of columns in the grid
- :param num_row: Number of rows in the grid
- :param random_start: If true, agent starts in random position
- :param channel_last: If true, the image will be channel last, else it will be channel first
- """
-
- def __init__(
- self,
- num_col: int = 4,
- num_row: int = 4,
- random_start: bool = True,
- discrete_actions: bool = True,
- channel_last: bool = True,
- ):
- super().__init__()
-
- self.vector_size = 5
- if channel_last:
- self.img_size = [64, 64, 1]
- else:
- self.img_size = [1, 64, 64]
-
- self.random_start = random_start
- self.discrete_actions = discrete_actions
- if discrete_actions:
- self.action_space = spaces.Discrete(4)
- else:
- self.action_space = spaces.Box(0, 1, (4,))
-
- self.observation_space = spaces.Dict(
- spaces={
- "vec": spaces.Box(0, 1, (self.vector_size,), dtype=np.float64),
- "img": spaces.Box(0, 255, self.img_size, dtype=np.uint8),
- }
- )
- self.count = 0
- # Timeout
- self.max_count = 100
- self.log = ""
- self.state = 0
- self.action2str = ["left", "down", "right", "up"]
- self.init_possible_transitions()
-
- self.num_col = num_col
- self.state_mapping = []
- self.init_state_mapping(num_col, num_row)
-
- self.max_state = len(self.state_mapping) - 1
-
- def init_state_mapping(self, num_col: int, num_row: int) -> None:
- """
- Initializes the state_mapping array which holds the observation values for each state
-
- :param num_col: Number of columns.
- :param num_row: Number of rows.
- """
- # Each column is represented by a random vector
- col_vecs = np.random.random((num_col, self.vector_size))
- # Each row is represented by a random image
- row_imgs = np.random.randint(0, 255, (num_row, 64, 64), dtype=np.uint8)
-
- for i in range(num_col):
- for j in range(num_row):
- self.state_mapping.append(
- {"vec": col_vecs[i], "img": row_imgs[j].reshape(self.img_size)}
- )
-
- def get_state_mapping(self) -> Dict[str, np.ndarray]:
- """
- Uses the state to get the observation mapping.
-
- :return: observation dict {'vec': ..., 'img': ...}
- """
- return self.state_mapping[self.state]
-
- def init_possible_transitions(self) -> None:
- """
- Initializes the transitions of the environment
- The environment exploits the cardinal directions of the grid by noting that
- they correspond to simple addition and subtraction from the cell id within the grid
-
- - up => means moving up a row => means subtracting the length of a column
- - down => means moving down a row => means adding the length of a column
- - left => means moving left by one => means subtracting 1
- - right => means moving right by one => means adding 1
-
- Thus one only needs to specify in which states each action is possible
- in order to define the transitions of the environment
- """
- self.left_possible = [1, 2, 3, 13, 14, 15]
- self.down_possible = [0, 4, 8, 3, 7, 11]
- self.right_possible = [0, 1, 2, 12, 13, 14]
- self.up_possible = [4, 8, 12, 7, 11, 15]
-
- def step(self, action: Union[float, np.ndarray]) -> GymStepReturn:
- """
- Run one timestep of the environment's dynamics. When end of
- episode is reached, you are responsible for calling `reset()`
- to reset this environment's state.
- Accepts an action and returns a tuple (observation, reward, done, info).
-
- :param action:
- :return: tuple (observation, reward, done, info).
- """
- if not self.discrete_actions:
- action = np.argmax(action)
- else:
- action = int(action)
-
- self.count += 1
-
- prev_state = self.state
-
- reward = -0.1
- # define state transition
- if self.state in self.left_possible and action == 0: # left
- self.state -= 1
- elif self.state in self.down_possible and action == 1: # down
- self.state += self.num_col
- elif self.state in self.right_possible and action == 2: # right
- self.state += 1
- elif self.state in self.up_possible and action == 3: # up
- self.state -= self.num_col
-
- got_to_end = self.state == self.max_state
- reward = 1 if got_to_end else reward
- done = self.count > self.max_count or got_to_end
-
- self.log = (
- f"Went {self.action2str[action]} in state {prev_state}, got to state"
- f" {self.state}"
- )
-
- return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end}
-
- def render(self, mode: str = "human") -> None:
- """
- Prints the log of the environment.
-
- :param mode:
- """
- print(self.log)
-
- def reset(self) -> Dict[str, np.ndarray]:
- """
- Resets the environment state and step count and returns reset observation.
-
- :return: observation dict {'vec': ..., 'img': ...}
- """
- self.count = 0
- if not self.random_start:
- self.state = 0
- else:
- self.state = np.random.randint(0, self.max_state)
- return self.state_mapping[self.state]
diff --git a/openrl/envs/vec_env/async_venv.py b/openrl/envs/vec_env/async_venv.py
index 02d6fec2..141532ba 100644
--- a/openrl/envs/vec_env/async_venv.py
+++ b/openrl/envs/vec_env/async_venv.py
@@ -1,4 +1,5 @@
"""An async vector environment."""
+
import multiprocessing as mp
import sys
import time
@@ -233,10 +234,8 @@ def reset_send(
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
- (
- "Calling `reset_send` while waiting for a pending call to"
- f" `{self._state.value}` to complete"
- ),
+ "Calling `reset_send` while waiting for a pending call to"
+ f" `{self._state.value}` to complete",
self._state.value,
)
@@ -328,10 +327,8 @@ def step_send(self, actions: np.ndarray):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
- (
- "Calling `step_send` while waiting for a pending call to"
- f" `{self._state.value}` to complete."
- ),
+ "Calling `step_send` while waiting for a pending call to"
+ f" `{self._state.value}` to complete.",
self._state.value,
)
@@ -341,9 +338,7 @@ def step_send(self, actions: np.ndarray):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP
- def step_fetch(
- self, timeout: Optional[Union[int, float]] = None
- ) -> Union[
+ def step_fetch(self, timeout: Optional[Union[int, float]] = None) -> Union[
Tuple[Any, NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
]:
@@ -575,10 +570,8 @@ def call_send(self, name: str, *args, **kwargs):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
- (
- "Calling `call_send` while waiting "
- f"for a pending call to `{self._state.value}` to complete."
- ),
+ "Calling `call_send` while waiting "
+ f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)
@@ -635,10 +628,8 @@ def exec_func_send(self, func: Callable, indices, *args, **kwargs):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
- (
- "Calling `exec_func_send` while waiting "
- f"for a pending call to `{self._state.value}` to complete."
- ),
+ "Calling `exec_func_send` while waiting "
+ f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)
@@ -674,6 +665,7 @@ def exec_func_fetch(self, timeout: Union[int, float, None] = None) -> list:
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
+
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
@@ -715,10 +707,8 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]):
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
- (
- "Calling `set_attr` while waiting "
- f"for a pending call to `{self._state.value}` to complete."
- ),
+ "Calling `set_attr` while waiting "
+ f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)
@@ -732,7 +722,7 @@ def _worker(
index: int,
env_fn: callable,
pipe: Connection,
- parent_pipe: Connection,
+ parent_pipe: Optional[Connection],
shared_memory: bool,
error_queue: Queue,
auto_reset: bool = True,
@@ -756,7 +746,8 @@ def prepare_obs(observation):
observation = None
return observation
- parent_pipe.close()
+ if parent_pipe is not None:
+ parent_pipe.close()
try:
while True:
command, data = pipe.recv()
@@ -837,7 +828,7 @@ def prepare_obs(observation):
)
elif command == "_func_exec":
function, indices, args, kwargs = data
- if index in indices:
+ if indices is None or index in indices:
if callable(function):
pipe.send((function(env, *args, **kwargs), True))
else:
diff --git a/openrl/envs/vec_env/base_venv.py b/openrl/envs/vec_env/base_venv.py
index f2e54744..c10d0d0d 100644
--- a/openrl/envs/vec_env/base_venv.py
+++ b/openrl/envs/vec_env/base_venv.py
@@ -272,7 +272,7 @@ def exec_func_fetch(self, timeout: Union[int, float, None] = None) -> list:
"""
def exec_func(
- self, func: Callable, indices: List[int], *args, **kwargs
+ self, func: Callable, indices: Optional[List[int]] = None, *args, **kwargs
) -> List[Any]:
"""Call a method, or get a property, from each parallel environment.
diff --git a/openrl/envs/vec_env/sync_venv.py b/openrl/envs/vec_env/sync_venv.py
index a670ec33..1e208e4c 100644
--- a/openrl/envs/vec_env/sync_venv.py
+++ b/openrl/envs/vec_env/sync_venv.py
@@ -15,6 +15,7 @@
# limitations under the License.
""""""
+import time
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Optional, Sequence, Union
@@ -202,6 +203,7 @@ def _step(self, actions: ActType):
self._truncateds[i],
info,
) = returns
+
need_reset = _need_reset and (
all(self._terminateds[i]) or all(self._truncateds[i])
)
@@ -281,7 +283,9 @@ def env_name(self):
else:
return self.envs[0].unwrapped.spec.id
- def exec_func(self, func: Callable, indices: List[int], *args, **kwargs) -> tuple:
+ def exec_func(
+ self, func: Callable, indices: Optional[List[int]] = None, *args, **kwargs
+ ) -> tuple:
"""Calls the method with name and applies args and kwargs.
Args:
@@ -294,7 +298,7 @@ def exec_func(self, func: Callable, indices: List[int], *args, **kwargs) -> tupl
"""
results = []
for i, env in enumerate(self.envs):
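+ # indices=None (the new default) applies func to every sub-environment.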
- if i in indices:
+ if indices is None or i in indices:
if callable(func):
results.append(func(env, *args, **kwargs))
else:
diff --git a/openrl/envs/vec_env/wrappers/reward_wrapper.py b/openrl/envs/vec_env/wrappers/reward_wrapper.py
index d0a4d630..25cdc424 100644
--- a/openrl/envs/vec_env/wrappers/reward_wrapper.py
+++ b/openrl/envs/vec_env/wrappers/reward_wrapper.py
@@ -29,8 +29,8 @@ class RewardWrapper(VecEnvWrapper):
def __init__(self, env: BaseVecEnv, reward_class: BaseReward):
super().__init__(env)
self.reward_class = reward_class
- if len(self.reward_class.inner_rew_funcs) > 0:
- env.call("set_reward", **{"reward_fn": self.reward_class.inner_rew_funcs})
+ # if len(self.reward_class.inner_rew_funcs) > 0:
+ # env.call("set_reward", **{"reward_fn": self.reward_class.inner_rew_funcs})
def step(
self, action: ActType, extra_data: Optional[Dict[str, Any]]
diff --git a/openrl/envs/wrappers/extra_wrappers.py b/openrl/envs/wrappers/extra_wrappers.py
index da819a87..27359d9e 100644
--- a/openrl/envs/wrappers/extra_wrappers.py
+++ b/openrl/envs/wrappers/extra_wrappers.py
@@ -21,6 +21,9 @@
import gymnasium as gym
import numpy as np
from gymnasium import spaces
+from gymnasium.utils.step_api_compatibility import (
+ convert_to_terminated_truncated_step_api,
+)
from gymnasium.wrappers import AutoResetWrapper, StepAPICompatibility
from openrl.envs.wrappers import BaseObservationWrapper, BaseRewardWrapper, BaseWrapper
@@ -46,6 +49,76 @@ def step(self, action):
return obs, total_reward, term, trunc, info
+def convert_to_done_step_api(
+ step_returns,
+ is_vector_env: bool = False,
+):
+ if len(step_returns) == 4:
+ return step_returns
+ else:
+ assert len(step_returns) == 5
+ observations, rewards, terminated, truncated, infos = step_returns
+
+ # Cases to handle - info single env / info vector env (list) / info vector env (dict)
+ # if truncated[0]:
+ # import pdb;
+ # pdb.set_trace()
+
+ if is_vector_env is False:
+ if isinstance(terminated, list):
+ infos["TimeLimit.truncated"] = truncated[0] and not terminated[0]
+ done_return = np.logical_or(terminated, truncated)
+ else:
+ if truncated or terminated:
+ infos["TimeLimit.truncated"] = truncated and not terminated
+ done_return = terminated or truncated
+ return (
+ observations,
+ rewards,
+ done_return,
+ infos,
+ )
+ elif isinstance(infos, list):
+ for info, env_truncated, env_terminated in zip(
+ infos, truncated, terminated
+ ):
+ if env_truncated or env_terminated:
+ info["TimeLimit.truncated"] = env_truncated and not env_terminated
+ return (
+ observations,
+ rewards,
+ np.logical_or(terminated, truncated),
+ infos,
+ )
+ elif isinstance(infos, dict):
+ if np.logical_or(np.any(truncated), np.any(terminated)):
+ infos["TimeLimit.truncated"] = np.logical_and(
+ truncated, np.logical_not(terminated)
+ )
+ return (
+ observations,
+ rewards,
+ np.logical_or(terminated, truncated),
+ infos,
+ )
+ else:
+ raise TypeError(
+ "Unexpected value of infos, as is_vector_envs=False, expects `info` to"
+ f" be a list or dict, actual type: {type(infos)}"
+ )
+
+
+def step_api_compatibility(
+ step_returns,
+ output_truncation_bool: bool = True,
+ is_vector_env: bool = False,
+):
+ if output_truncation_bool:
+ return convert_to_terminated_truncated_step_api(step_returns, is_vector_env)
+ else:
+ return convert_to_done_step_api(step_returns, is_vector_env)
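+
+ # Illustrative usage (hypothetical values): for a single, non-vectorized env,
+ # step_api_compatibility((obs, 1.0, False, True, {}), output_truncation_bool=False)
+ # collapses the 5-tuple into (obs, 1.0, True, {"TimeLimit.truncated": True}),
+ # while output_truncation_bool=True returns the terminated/truncated 5-tuple form.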
+
+
class RemoveTruncated(StepAPICompatibility, BaseWrapper):
def __init__(
self,
@@ -54,6 +127,12 @@ def __init__(
output_truncation_bool = False
super().__init__(env, output_truncation_bool=output_truncation_bool)
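+
+ # step is overridden to route through the module-level step_api_compatibility
+ # defined above, which records info["TimeLimit.truncated"] when collapsing
+ # to the old done-based API.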
+ def step(self, action):
+ step_returns = self.env.step(action)
+ return step_api_compatibility(
+ step_returns, self.output_truncation_bool, self.is_vector_env
+ )
+
class FlattenObservation(BaseObservationWrapper):
def __init__(self, env: gym.Env):
diff --git a/openrl/envs/wrappers/pettingzoo_wrappers.py b/openrl/envs/wrappers/pettingzoo_wrappers.py
index 226fdb9f..c571ff79 100644
--- a/openrl/envs/wrappers/pettingzoo_wrappers.py
+++ b/openrl/envs/wrappers/pettingzoo_wrappers.py
@@ -96,8 +96,9 @@ def last(self, observe: bool = True):
winners = None
losers = None
+
for agent in self.terminations:
- if self.terminations[agent]:
+ if self.terminations[agent] or all(self.truncations.values()):
if winners is None:
winners = self.get_winners()
losers = [player for player in self.agents if player not in winners]
diff --git a/openrl/envs/wrappers/util.py b/openrl/envs/wrappers/util.py
index a0a97576..614a5879 100644
--- a/openrl/envs/wrappers/util.py
+++ b/openrl/envs/wrappers/util.py
@@ -41,7 +41,9 @@ def nest_expand_dim(input: Any) -> Any:
elif input is None:
return [input]
else:
- raise NotImplementedError("Not support type: {}".format(type(input)))
+ raise NotImplementedError(
+ "Not support type: {}, value={}".format(type(input), input)
+ )
def unwrap_wrapper(
diff --git a/openrl/modules/common/ppo_net.py b/openrl/modules/common/ppo_net.py
index 93dbaa64..7c537c91 100644
--- a/openrl/modules/common/ppo_net.py
+++ b/openrl/modules/common/ppo_net.py
@@ -15,7 +15,7 @@
# limitations under the License.
""""""
-
+import copy
from typing import Any, Dict, Optional, Tuple, Union
import gymnasium as gym
@@ -30,6 +30,23 @@
from openrl.utils.util import set_seed
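+
+ # Worked example of the masking below (illustrative values only, not part of
+ # the API): with env_num=2, agent_num=1, rnn_layers=1, hidden_size=4 and
+ # episode_starts=np.array([1, 0]), the mask has shape (2, 1, 1) and broadcasts
+ # over rnn_states of shape (2, 1, 4), zeroing the hidden state of the first
+ # (just-restarted) environment while leaving the second one unchanged.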
+def reset_rnn_states(
+ rnn_states, episode_starts, env_num, agent_num, rnn_layers, hidden_size
+):
+ # First we reshape the episode_starts to match the rnn_states shape
+ # Since episode_starts affects all agents in the environment, we repeat it agent_num times
+ episode_starts = np.repeat(copy.copy(episode_starts), agent_num)
+ # We then need to expand the dimensions of episode_starts to match rnn_states
+ # The new shape of episode_starts should be (env_num * agent_num, 1, 1) to broadcast correctly
+ episode_starts = episode_starts[:, None, None]
+ # Now, episode_starts should broadcast over the last two dimensions of rnn_states when multiplied
+ # We want to set rnn_states to zero where episode_starts is 1, so we invert the episode_starts as a mask
+ mask = 1 - episode_starts
+ # Apply the mask to rnn_states, setting the appropriate states to zero
+ rnn_states *= mask
+ return rnn_states
+
+
class PPONet(BaseNet):
def __init__(
self,
@@ -89,7 +106,18 @@ def act(
observation: Union[np.ndarray, Dict[str, np.ndarray]],
action_masks: Optional[np.ndarray] = None,
deterministic: bool = False,
+ episode_starts: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+ if episode_starts is not None:
+ self.rnn_states_actor = reset_rnn_states(
+ self.rnn_states_actor,
+ episode_starts,
+ self.env.parallel_env_num,
+ self.env.agent_num,
+ self.rnn_states_actor.shape[1],
+ self.rnn_states_actor.shape[2],
+ )
+
actions, self.rnn_states_actor = self.module.act(
obs=observation,
rnn_states_actor=self.rnn_states_actor,
diff --git a/openrl/modules/networks/policy_network.py b/openrl/modules/networks/policy_network.py
index 422eaa58..e3ebb025 100644
--- a/openrl/modules/networks/policy_network.py
+++ b/openrl/modules/networks/policy_network.py
@@ -56,6 +56,9 @@ def __init__(
self.use_half = use_half
self.tpdv = dict(dtype=torch.float32, device=device)
+ self._use_fp16 = cfg.use_fp16
+ assert not (cfg.use_fp16 and not cfg.use_deepspeed), "use_fp16 requires use_deepspeed"
+
policy_obs_shape = get_policy_obs_space(input_space)
if "Dict" in policy_obs_shape.__class__.__name__:
@@ -135,8 +138,9 @@ def forward_original(
policy_obs[key].half()
else:
policy_obs = check(policy_obs, self.use_half, self.tpdv)
- # if self.use_half:
- # obs = obs.half()
+ if self.use_half or self._use_fp16:
+ policy_obs = policy_obs.half()
+
rnn_states = check(rnn_states, self.use_half, self.tpdv)
masks = check(masks, self.use_half, self.tpdv)
@@ -165,6 +169,8 @@ def eval_actions(
obs[key] = check(obs[key], self.use_half, self.tpdv)
else:
obs = check(obs, self.use_half, self.tpdv)
+ if self._use_fp16:
+ obs = obs.half()
rnn_states = check(rnn_states, self.use_half, self.tpdv)
action = check(action, self.use_half, self.tpdv)
@@ -202,6 +208,8 @@ def get_policy_values(self, obs, rnn_states, masks):
obs[key] = check(obs[key], self.use_half, self.tpdv)
else:
obs = check(obs).to(**self.tpdv)
+ if self.use_half or self._use_fp16:
+ obs = obs.half()
rnn_states = check(rnn_states, self.use_half, self.tpdv)
masks = check(masks, self.use_half, self.tpdv)
diff --git a/openrl/modules/networks/policy_network_gpt.py b/openrl/modules/networks/policy_network_gpt.py
new file mode 100644
index 00000000..193094a7
--- /dev/null
+++ b/openrl/modules/networks/policy_network_gpt.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2021 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+from typing import Any, Dict, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from transformers.modeling_utils import unwrap_model
+
+from openrl.buffers.utils.util import get_policy_obs, get_policy_obs_space
+from openrl.envs.nlp.utils.distribution import CategoricalDistribution
+from openrl.modules.networks.base_policy_network import BasePolicyNetwork
+from openrl.modules.networks.utils.act import ACTLayer
+from openrl.modules.networks.utils.cnn import CNNBase
+from openrl.modules.networks.utils.mix import MIXBase
+from openrl.modules.networks.utils.mlp import MLPBase, MLPLayer
+from openrl.modules.networks.utils.popart import PopArt
+from openrl.modules.networks.utils.rnn import RNNLayer
+from openrl.modules.networks.utils.util import init
+from openrl.utils.util import check_v2 as check
+
+
+class PolicyNetworkGPT(BasePolicyNetwork):
+ def __init__(
+ self,
+ cfg,
+ input_space,
+ action_space,
+ device=torch.device("cpu"),
+ use_half=False,
+ disable_drop_out: bool = True,
+ extra_args=None,
+ ) -> None:
+ self.device = device
+ self.use_fp16 = cfg.use_fp16
+ self.use_deepspeed = cfg.use_deepspeed
+ self.use_half = False
+ self.use_data_parallel = not cfg.use_deepspeed # fall back to data parallel when DeepSpeed is disabled
+ self.use_model_parallel = False
+
+ assert not (self.use_deepspeed and self.use_data_parallel)
+ assert not (self.use_deepspeed and self.use_model_parallel)
+ assert not (self.use_data_parallel and self.use_model_parallel)
+
+ super(PolicyNetworkGPT, self).__init__(cfg, device)
+
+ self.disable_drop_out = disable_drop_out
+
+ self._action_dist = CategoricalDistribution(action_space.n)
+
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ config = AutoConfig.from_pretrained(cfg.model_path)
+ config_dict = config.to_dict()
+ for key in config_dict:
+ if "drop" in key:
+ config_dict[key] = 0.0
+ config = config.from_dict(config_dict)
+ self._policy_model = AutoModelForCausalLM.from_pretrained(
+ cfg.model_path, config=config
+ )
+ self._policy_model.config.use_cache = False
+
+ if torch.cuda.is_available():
+ if self.use_model_parallel:
+ self._policy_model.parallelize()
+ elif self.use_data_parallel:
+ if self.use_half:
+ self._policy_model = self._policy_model.half()
+ self._policy_model = torch.nn.DataParallel(self._policy_model)
+ self._policy_model = self._policy_model.to(self.device)
+
+ def forward(self, forward_type, *args, **kwargs):
+ if forward_type == "original":
+ return self.forward_original(*args, **kwargs)
+ elif forward_type == "eval_actions":
+ return self.eval_actions(*args, **kwargs)
+ else:
+ raise NotImplementedError
+
+ def _prepare_inputs_for_model(
+ self,
+ model: Any,
+ input_ids: torch.tensor,
+ model_kwargs: Optional[Dict[str, torch.tensor]] = None,
+ ):
+ model_inputs = unwrap_model(model).prepare_inputs_for_generation(
+ input_ids, **model_kwargs
+ )
+
+ if self.use_model_parallel:
+ model_inputs = {
+ key: (
+ value.to(model.transformer.first_device)
+ if isinstance(value, torch.Tensor)
+ and hasattr(model.transformer, "first_device")
+ else value
+ )
+ for key, value in model_inputs.items()
+ }
+
+ return model_inputs
+
+ def forward_original(
+ self, raw_obs, rnn_states, masks, action_masks=None, deterministic=False
+ ):
+ for key in raw_obs.keys():
+ raw_obs[key] = (
+ torch.from_numpy(raw_obs[key])
+ if type(raw_obs[key]) == np.ndarray
+ else raw_obs[key]
+ )
+ rnn_states = check(rnn_states)
+
+ if self.use_half:
+ input_ids = raw_obs["input_encoded_pt"].int()
+ attention_mask = raw_obs["input_attention_mask_pt"].int()
+ else:
+ input_ids = raw_obs["input_encoded_pt"].long()
+ attention_mask = raw_obs["input_attention_mask_pt"].long()
+
+ if self.use_data_parallel:
+ input_ids = input_ids.to(self.device)
+ attention_mask = attention_mask.to(self.device)
+ else:
+ input_ids = input_ids.to(self._policy_model.device)
+ attention_mask = attention_mask.to(self._policy_model.device)
+
+ past_model_kwargs = None
+
+ if past_model_kwargs is None:
+ past_model_kwargs = {
+ "attention_mask": attention_mask,
+ }
+
+ model_inputs = self._prepare_inputs_for_model(
+ self._policy_model, input_ids, past_model_kwargs
+ )
+
+ # forward pass to transformers
+ output = self._policy_model(**model_inputs)
+
+ # compute action probs - policy head
+ next_token_logits = output.logits[:, -1]
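+ # output.logits has shape (batch, seq_len, vocab_size); keeping only the
+ # last position reduces next-token prediction to a per-sequence categorical
+ # distribution, so each sampled action below is a single token id.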
+ dist = self._action_dist.proba_distribution(action_logits=next_token_logits)
+
+ actions = dist.mode() if deterministic else dist.sample()
+ action_log_probs = dist.log_prob(actions)
+
+ return actions.unsqueeze(-1), action_log_probs.unsqueeze(-1), rnn_states
+
+ def eval_actions(
+ self, obs, rnn_states, action, masks, action_masks=None, active_masks=None
+ ):
+ for key in obs.keys():
+ obs[key] = (
+ torch.from_numpy(obs[key]) if type(obs[key]) == np.ndarray else obs[key]
+ )
+ if self.use_data_parallel:
+ obs[key] = obs[key].to(self.device)
+ else:
+ obs[key] = obs[key].to(self._policy_model.device)
+ if self.use_data_parallel:
+ action = check(action).to(self.device).squeeze()
+ else:
+ action = check(action).to(self._policy_model.device).squeeze()
+ rnn_states = check(rnn_states)
+
+ if self.use_half:
+ input_ids = obs["input_encoded_pt"].int()
+ attention_mask = obs["input_attention_mask_pt"].int()
+ else:
+ input_ids = obs["input_encoded_pt"].long()
+ attention_mask = obs["input_attention_mask_pt"].long()
+
+ past_model_kwargs = None
+
+ if past_model_kwargs is None:
+ past_model_kwargs = {
+ "attention_mask": attention_mask,
+ }
+
+ model_inputs = self._prepare_inputs_for_model(
+ self._policy_model, input_ids, past_model_kwargs
+ )
+
+ # forward pass to transformers
+ output = self._policy_model(**model_inputs)
+
+ # compute action probs - policy head
+ next_token_logits = output.logits[:, -1]
+ dist = self._action_dist.proba_distribution(action_logits=next_token_logits)
+
+ action_log_probs = dist.log_prob(action)
+ dist_entropy = dist.entropy()
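+ # This network is policy-only, so no value estimate is produced here.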
+ values = None
+
+ return action_log_probs.unsqueeze(-1), dist_entropy.mean(), values
+
+ def get_policy_values(self, obs, rnn_states, masks):
+ raise NotImplementedError
diff --git a/openrl/modules/networks/policy_value_network_gpt.py b/openrl/modules/networks/policy_value_network_gpt.py
index e87e146b..85daef3a 100644
--- a/openrl/modules/networks/policy_value_network_gpt.py
+++ b/openrl/modules/networks/policy_value_network_gpt.py
@@ -37,6 +37,7 @@ def __init__(
self.disable_drop_out = disable_drop_out
self._use_valuenorm = cfg.use_valuenorm
super(CausalLMActorCriticPolicy, self).__init__(
+ cfg,
input_space,
action_space,
model_name=cfg.model_path,
@@ -45,6 +46,9 @@ def __init__(
self.use_half = use_half
self.tpdv = dict(dtype=torch.float32, device=device)
+ self._use_fp16 = cfg.use_fp16
+ assert not (cfg.use_fp16 and not cfg.use_deepspeed), "use_fp16 requires use_deepspeed"
+
def get_actor_para(self):
return self._policy_model.parameters()
@@ -66,6 +70,8 @@ def get_actions(
):
for key in obs.keys():
obs[key] = check(obs[key], self.use_half, self.tpdv)
+ if self._use_fp16:
+ obs[key] = obs[key].half()
rnn_states = check(rnn_states, self.use_half, self.tpdv)
past_model_kwargs = None
@@ -83,6 +89,8 @@ def eval_actions(
):
for key in obs.keys():
obs[key] = check(obs[key], self.use_half, self.tpdv)
+ if self._use_fp16:
+ obs[key] = obs[key].half()
action = check(action, self.use_half, self.tpdv).squeeze()
eval_output = super().evaluate_actions(obs, action)
@@ -95,20 +103,11 @@ def eval_actions(
def get_values(self, obs, rnn_states, masks):
for key in obs.keys():
obs[key] = check(obs[key], self.use_half, self.tpdv)
+ if self._use_fp16:
+ obs[key] = obs[key].half()
rnn_states = check(rnn_states, self.use_half, self.tpdv)
value_output = super().forward_value(obs)
values = value_output.values
return values, rnn_states
-
- def get_log_probs_ref_model(self, obs, action):
- for key in obs.keys():
- obs[key] = check(obs[key], self.use_half, self.tpdv)
- action = check(action, self.use_half, self.tpdv)
- action = action.squeeze(-1)
-
- policy_output = super().get_log_probs_ref_model(obs, action)
- action_log_probs = policy_output.log_probs
-
- return action_log_probs.detach().cpu().numpy()
diff --git a/openrl/modules/networks/utils/attention.py b/openrl/modules/networks/utils/attention.py
index c00a3f24..a05a9a84 100644
--- a/openrl/modules/networks/utils/attention.py
+++ b/openrl/modules/networks/utils/attention.py
@@ -234,10 +234,13 @@ def forward(self, x, self_idx=-1):
K = self.split_shape[i][0]
L = self.split_shape[i][1]
for j in range(K):
- torch.cat((x[i][:, (L * j) : (L * j + L)], self_x), dim=-1)
- exec("x1.append(self.fc_{}(temp))".format(i))
- x[self_idx]
- exec("x1.append(self.fc_{}(temp))".format(N - 1))
+ # torch.cat((x[i][:, (L * j) : (L * j + L)], self_x), dim=-1)
+ # exec("x1.append(self.fc_{}(temp))".format(i))
+ temp = torch.cat((x[i][:, (L * j) : (L * j + L)], self_x), dim=-1)
+ x1.append(getattr(self, "fc_" + str(i))(temp))
+ x1.append(getattr(self, "fc_" + str(N - 1))(self_x))
+ # x[self_idx]
+ # exec("x1.append(self.fc_{}(temp))".format(N - 1))
out = torch.stack(x1, 1)
@@ -278,8 +281,10 @@ def forward(self, x, self_idx=None):
K = self.split_shape[i][0]
L = self.split_shape[i][1]
for j in range(K):
- x[i][:, (L * j) : (L * j + L)]
- exec("x1.append(self.fc_{}(temp))".format(i))
+ # x[i][:, (L * j) : (L * j + L)]
+ # exec("x1.append(self.fc_{}(temp))".format(i))
+ temp = x[i][:, (L * j) : (L * j + L)]
+ x1.append(getattr(self, "fc_" + str(i))(temp))
out = torch.stack(x1, 1)
diff --git a/openrl/modules/networks/utils/distributions.py b/openrl/modules/networks/utils/distributions.py
index 340015a4..fd3ef8ca 100644
--- a/openrl/modules/networks/utils/distributions.py
+++ b/openrl/modules/networks/utils/distributions.py
@@ -68,7 +68,7 @@ def init_(m):
def forward(self, x, action_masks=None):
x = self.linear(x)
if action_masks is not None:
- x[action_masks == 0] = -1e10
+ x[action_masks == 0] = -6e4 # large negative value that still fits in the fp16 range (max ~6.5e4)
return FixedCategorical(logits=x)
diff --git a/openrl/modules/networks/utils/nlp/base_policy.py b/openrl/modules/networks/utils/nlp/base_policy.py
index 9051b886..dd0e2032 100644
--- a/openrl/modules/networks/utils/nlp/base_policy.py
+++ b/openrl/modules/networks/utils/nlp/base_policy.py
@@ -124,13 +124,14 @@ class GenerationOutputs:
class LMActorCriticPolicy(nn.Module):
def __init__(
self,
+ cfg: Any,
observation_space: DictSpace,
action_space: Discrete,
model_name: str,
optimizer_kwargs: Dict[str, Any] = {},
weight_decay: float = 1e-6,
use_sde: bool = None,
- apply_model_parallel: bool = True,
+ # apply_model_parallel: bool = True,
optimizer_class: torch.optim.Optimizer = torch.optim.AdamW,
generation_kwargs: Dict[str, Any] = {},
prompt_truncation_side: str = "left",
@@ -146,16 +147,16 @@ def __init__(
optimizer_kwargs (Dict[str, Any], optional): optimizer kwargs. Defaults to {}.
weight_decay (float, optional): weight decay. Defaults to 1e-6.
use_sde (bool, optional): Use state-dependent exploration. Defaults to None.
- apply_model_parallel (bool, optional): whether to apply model parallel. Defaults to True.
+ apply_model_parallel (bool, optional): model parallelism is applied by default when DeepSpeed is not used.
optimizer_class (torch.optim.Optimizer, optional): Optimizer class. Defaults to torch.optim.AdamW.
generation_kwargs (Dict[str, Any], optional): generation parameters for rollout. Defaults to {}.
prompt_truncation_side (str, optional): truncation side for prompt text. Defaults to "left".
"""
super().__init__()
+ self._use_deepspeed = cfg.use_deepspeed
self._action_space = action_space
- self._apply_model_parallel = apply_model_parallel
+ self._apply_model_parallel = not cfg.use_deepspeed # TODO
self._build_model_heads(model_name, config, device)
- self._setup_optimizer(optimizer_kwargs, weight_decay, optimizer_class)
self._action_dist = CategoricalDistribution(self._action_space.n)
self._generation_kwargs = generation_kwargs
self._prompt_truncation_side = prompt_truncation_side
diff --git a/openrl/modules/networks/utils/nlp/causal_policy.py b/openrl/modules/networks/utils/nlp/causal_policy.py
index dedfc4aa..f0b86d0d 100644
--- a/openrl/modules/networks/utils/nlp/causal_policy.py
+++ b/openrl/modules/networks/utils/nlp/causal_policy.py
@@ -15,16 +15,13 @@
PolicyType,
ValueOutput,
)
-from openrl.modules.networks.utils.nlp.hf_generation_utils import (
- override_generation_routines,
- unwrap_generation_routines,
-)
from openrl.modules.utils.valuenorm import ValueNorm
class CausalLMActorCriticPolicy(LMActorCriticPolicy):
def __init__(
self,
+ cfg: Any,
observation_space: DictSpace,
action_space: Discrete,
model_name: str,
@@ -40,6 +37,7 @@ def __init__(
device: str = "cpu",
):
super().__init__(
+ cfg,
observation_space,
action_space,
model_name,
@@ -65,29 +63,37 @@ def load_from_dict(self, state_dict: dict = None):
@property
def policy(self):
policy_model = self._policy_model
- policy_model.__class__ = unwrap_generation_routines(type(policy_model))
return policy_model
def _build_model_heads(self, model_name: str, config: str, device: str):
if self.disable_drop_out:
- config = AutoConfig.from_pretrained(model_name)
+ if model_name == "test_gpt2":
+ from transformers import GPT2Config
+
+ config = GPT2Config()
+
+ else:
+ config = AutoConfig.from_pretrained(model_name)
config_dict = config.to_dict()
for key in config_dict:
if "drop" in key:
config_dict[key] = 0.0
config = config.from_dict(config_dict)
- self._policy_model = AutoModelForCausalLM.from_pretrained(
- model_name, config=config
- )
+ if model_name == "test_gpt2":
+ from transformers import GPT2LMHeadModel
- self._policy_model.__class__ = override_generation_routines(
- type(self._policy_model)
- )
+ self._policy_model = GPT2LMHeadModel(config)
+ self._value_model = GPT2LMHeadModel(config)
- self._value_model = AutoModelForCausalLM.from_pretrained(
- model_name, config=config
- )
+ else:
+ self._policy_model = AutoModelForCausalLM.from_pretrained(
+ model_name, config=config
+ )
+
+ self._value_model = AutoModelForCausalLM.from_pretrained(
+ model_name, config=config
+ )
self._value_head = nn.Linear(
self._value_model.config.hidden_size, 1, bias=False
@@ -99,7 +105,17 @@ def _build_model_heads(self, model_name: str, config: str, device: str):
torch.multiprocessing.set_sharing_strategy("file_system")
# apply model parallel
if torch.cuda.is_available():
- if self._apply_model_parallel and self._policy_model.is_parallelizable:
+ if self._use_deepspeed:
+ if self.value_normalizer is not None:
+ import deepspeed
+
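+ # Under DeepSpeed ZeRO-3, parameters accessed outside their owning module
+ # must be registered as external parameters so they are gathered on use.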
+ para = self.value_normalizer.running_mean
+ deepspeed.zero.register_external_parameter(self, para)
+ para = self.value_normalizer.running_mean_sq
+ deepspeed.zero.register_external_parameter(self, para)
+ para = self.value_normalizer.debiasing_term
+ deepspeed.zero.register_external_parameter(self, para)
+ elif self._apply_model_parallel and self._policy_model.is_parallelizable:
self._policy_model.parallelize()
self._value_model.parallelize()
self._value_head = self._value_head.to(self.device)
@@ -126,17 +142,18 @@ def _prepare_inputs_for_model(
input_ids, **model_kwargs
)
- if self._apply_model_parallel and unwrap_model(model).is_parallelizable:
- # if model is in parallel mode, move the tensors to the first device
- model_inputs = {
- key: (
- value.to(model.transformer.first_device)
- if isinstance(value, torch.Tensor)
- and hasattr(model.transformer, "first_device")
- else value
- )
- for key, value in model_inputs.items()
- }
+ if not self._use_deepspeed:
+ if self._apply_model_parallel and unwrap_model(model).is_parallelizable:
+ # if model is in parallel mode, move the tensors to the first device
+ model_inputs = {
+ key: (
+ value.to(model.transformer.first_device)
+ if isinstance(value, torch.Tensor)
+ and hasattr(model.transformer, "first_device")
+ else value
+ )
+ for key, value in model_inputs.items()
+ }
return model_inputs
def forward_policy(
diff --git a/openrl/modules/networks/utils/nlp/hf_generation_utils.py b/openrl/modules/networks/utils/nlp/hf_generation_utils.py
deleted file mode 100644
index 37d80875..00000000
--- a/openrl/modules/networks/utils/nlp/hf_generation_utils.py
+++ /dev/null
@@ -1,4000 +0,0 @@
-# coding=utf-8
-# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-import warnings
-from dataclasses import dataclass
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
-
-import torch
-import torch.distributed as dist
-from torch import nn
-from transformers.generation_beam_constraints import (
- Constraint,
- DisjunctiveConstraint,
- PhrasalConstraint,
-)
-from transformers.generation_beam_search import (
- BeamScorer,
- BeamSearchScorer,
- ConstrainedBeamSearchScorer,
-)
-from transformers.generation_logits_process import (
- EncoderNoRepeatNGramLogitsProcessor,
- ExponentialDecayLengthPenalty,
- ForcedBOSTokenLogitsProcessor,
- ForcedEOSTokenLogitsProcessor,
- HammingDiversityLogitsProcessor,
- InfNanRemoveLogitsProcessor,
- LogitsProcessorList,
- MinLengthLogitsProcessor,
- NoBadWordsLogitsProcessor,
- NoRepeatNGramLogitsProcessor,
- PrefixConstrainedLogitsProcessor,
- RepetitionPenaltyLogitsProcessor,
- TemperatureLogitsWarper,
- TopKLogitsWarper,
- TopPLogitsWarper,
- TypicalLogitsWarper,
-)
-from transformers.generation_stopping_criteria import (
- MaxLengthCriteria,
- MaxTimeCriteria,
- StoppingCriteria,
- StoppingCriteriaList,
- validate_stopping_criteria,
-)
-from transformers.generation_utils import GenerationMixin
-from transformers.pytorch_utils import torch_int_div
-from transformers.utils import ModelOutput, logging
-
-logger = logging.get_logger(__name__)
-
-
-@dataclass
-class GreedySearchDecoderOnlyOutput(ModelOutput):
- """
- Base class for outputs of decoder-only generation models using greedy search.
-
-
- Args:
- sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
- if all batches finished early due to the `eos_token_id`.
- scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
- at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each
- tensor of shape `(batch_size, config.vocab_size)`).
- attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
- hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
- """
-
- sequences: torch.LongTensor = None
- scores: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-@dataclass
-class GreedySearchEncoderDecoderOutput(ModelOutput):
- """
- Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
- weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
- encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
-
-
- Args:
- sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
- if all batches finished early due to the `eos_token_id`.
- scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
- at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
- `(batch_size, config.vocab_size)`).
- encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
- sequence_length, sequence_length)`.
- encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`.
- decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
- cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
- decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
- """
-
- sequences: torch.LongTensor = None
- scores: Optional[Tuple[torch.FloatTensor]] = None
- encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
- encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-@dataclass
-class SampleDecoderOnlyOutput(ModelOutput):
- """
- Base class for outputs of decoder-only generation models using sampling.
-
-
- Args:
- sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
- The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
- if all batches finished early due to the `eos_token_id`.
- scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
- at each generation step. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each
- tensor of shape `(batch_size*num_return_sequences, config.vocab_size)`).
- attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(num_return_sequences*batch_size, num_heads, generated_length,
- sequence_length)`.
- hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`.
- """
-
- sequences: torch.LongTensor = None
- scores: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-@dataclass
-class SampleEncoderDecoderOutput(ModelOutput):
- """
- Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of
- the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
- attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
-
-
- Args:
- sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
- The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
- if all batches finished early due to the `eos_token_id`.
- scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
- at each generation step. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
- `(batch_size*num_return_sequences, config.vocab_size)`).
- encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape
- `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`.
- encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size*num_return_sequences, sequence_length, hidden_size)`.
- decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size*num_return_sequences, num_heads, generated_length,
- sequence_length)`.
- cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
- decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`.
- """
-
- sequences: torch.LongTensor = None
- scores: Optional[Tuple[torch.FloatTensor]] = None
- encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
- encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-@dataclass
-class BeamSearchDecoderOnlyOutput(ModelOutput):
- """
- Base class for outputs of decoder-only generation models using beam search.
-
- Args:
- sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
- The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
- if all batches finished early due to the `eos_token_id`.
- sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Final beam scores of the generated `sequences`.
- scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Beam transition scores for each vocabulary token at each generation step, consisting of the log
- probabilities of tokens conditioned on the log softmax of previously generated tokens in this beam.
- `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
- `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
- beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
- tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
- attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
- hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
- """
-
- sequences: torch.LongTensor = None
- sequences_scores: Optional[torch.FloatTensor] = None
- scores: Optional[Tuple[torch.FloatTensor]] = None
- beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
- attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-@dataclass
-class BeamSearchEncoderDecoderOutput(ModelOutput):
- """
- Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
- of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the
- decoder_hidden_states attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).
-
- Args:
- sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
- The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
- if all batches finished early due to the `eos_token_id`.
- sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Final beam scores of the generated `sequences`.
- scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Beam transition scores for each vocabulary token at each generation step, consisting of the log
- probabilities of tokens conditioned on the log softmax of previously generated tokens in this beam.
- `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
- config.vocab_size)`.
- beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
- tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
- encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
- sequence_length, sequence_length)`.
- encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
- decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
- sequence_length)`.
- cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
- decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
- """
-
- sequences: torch.LongTensor = None
- sequences_scores: Optional[torch.FloatTensor] = None
- scores: Optional[Tuple[torch.FloatTensor]] = None
- beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
- encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
- encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-@dataclass
-class BeamSampleDecoderOnlyOutput(ModelOutput):
- """
- Base class for outputs of decoder-only generation models using beam sample.
-
- Args:
- sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
- The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
- if all batches finished early due to the `eos_token_id`.
- sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Final beam scores of the generated `sequences`.
- scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Beam transition scores for each vocabulary token at each generation step, consisting of the log
- probabilities of tokens conditioned on the log softmax of previously generated tokens in this beam.
- `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
- `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
- beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
- tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
- attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
- hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
- """
-
- sequences: torch.LongTensor = None
- sequences_scores: Optional[torch.FloatTensor] = None
- scores: Optional[Tuple[torch.FloatTensor]] = None
- beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
- attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-@dataclass
-class BeamSampleEncoderDecoderOutput(ModelOutput):
- """
- Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention
- weights of the decoder (respectively the encoder) can be accessed via the decoder_attentions and the
- decoder_hidden_states attributes (respectively the encoder_attentions and the encoder_hidden_states attributes).
-
- Args:
- sequences (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
- The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
- if all batches finished early due to the `eos_token_id`.
- sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Final beam scores of the generated `sequences`.
- scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Beam transition scores for each vocabulary token at each generation step, consisting of the log
- probabilities of tokens conditioned on the log softmax of previously generated tokens in this beam.
- `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
- config.vocab_size)`.
- beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
- Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
- tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
- encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
- sequence_length, sequence_length)`.
- encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size*num_beams, sequence_length, hidden_size)`.
- decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
- cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
- decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
- `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
- """
-
- sequences: torch.LongTensor = None
- sequences_scores: Optional[torch.FloatTensor] = None
- scores: Optional[Tuple[torch.FloatTensor]] = None
- beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
- encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
- encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-
-
-GreedySearchOutput = Union[
- GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput
-]
-SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
-BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
-BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
-
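-
- # Rough usage sketch (illustrative only; helper name is hypothetical): `generate` returns one of the
- # unions above when `return_dict_in_generate=True`, otherwise a plain tensor of token ids.
- def unpack_generate_result(result):
-     if isinstance(result, torch.Tensor):
-         return result, None
-     # every dataclass variant exposes `sequences`; `scores` is None unless `output_scores=True`
-     return result.sequences, result.scores
-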
-
-class GenerationMixinWithRawScores:
- """
- A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
-
- The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for:
- - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
- `do_sample=False`.
- - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
- `do_sample=True`.
- - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
- `do_sample=False`.
- - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
- `num_beams>1` and `do_sample=True`.
- - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
- `num_beams>1` and `num_beam_groups>1`.
- - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`],
- if `constraints!=None` or `force_words_ids!=None`.
- """
-
- def _prepare_model_inputs(
- self,
- inputs: Optional[torch.Tensor] = None,
- bos_token_id: Optional[int] = None,
- model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
- ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
- """
- This function extracts the model-specific `inputs` for generation.
- """
- # 1. retrieve all kwargs that are non-None or non-model input related.
- # some encoder-decoder models have different names for model and encoder
- if (
- self.config.is_encoder_decoder
- and hasattr(self, "encoder")
- and self.encoder.main_input_name != self.main_input_name
- ):
- input_name = self.encoder.main_input_name
- else:
- input_name = self.main_input_name
-
- model_kwargs = {
- k: v for k, v in model_kwargs.items() if v is not None or k != input_name
- }
-
- # 2. check whether model_input_name is passed as kwarg
- # if yes and `inputs` is None use kwarg inputs
- inputs_kwarg = model_kwargs.pop(input_name, None)
- if inputs_kwarg is not None and inputs is not None:
- raise ValueError(
- f"`inputs`: {inputs}` were passed alongside "
- f"{input_name} which is not allowed."
- f"Make sure to either pass {inputs} or {input_name}=..."
- )
- elif inputs_kwarg is not None:
- inputs = inputs_kwarg
-
- # 3. models with `input_ids` can also make use of `inputs_embeds`
- if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
- inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
-
- # 4. Only encoder-decoder models can have non `input_ids` input format
- if not self.config.is_encoder_decoder and input_name != "input_ids":
- raise ValueError(
- f"If {input_name} is passed as model-specific keyword "
- "input then model has to be an encoder-decoder and not a "
- f"{self.__class__.__name__}."
- )
-
- # 5. if `inputs` is still None, try to create `input_ids` from BOS token
- if inputs is None:
- inputs = self._prepare_input_ids_for_generation(
- bos_token_id, model_kwargs.get("encoder_outputs")
- )
-
- return inputs, input_name, model_kwargs
-
- def _can_retrieve_inputs_from_name(
- self,
- inputs: Optional[torch.Tensor],
- name: str,
- model_kwargs: Dict[str, torch.Tensor],
- ) -> torch.Tensor:
- """
- If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved
- from name
- """
- can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
- inspect.signature(self.forward).parameters.keys()
- )
-
- if can_retrieve_inputs and inputs is not None:
- raise ValueError(
- f"Cannot only pass one of {name} and {self.main_input_name}"
- )
-
- return can_retrieve_inputs
-
- def prepare_inputs_for_generation(
- self, input_ids: torch.LongTensor, **kwargs
- ) -> Dict[str, Any]:
- """
- Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.
- """
- return {"input_ids": input_ids}
-
- def adjust_logits_during_generation(
- self, logits: torch.FloatTensor, **kwargs
- ) -> torch.FloatTensor:
- """
- Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
- """
- return logits
-
- def _prepare_input_ids_for_generation(
- self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
- ) -> torch.LongTensor:
- if self.config.is_encoder_decoder and encoder_outputs is not None:
- # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
- shape = encoder_outputs.last_hidden_state.size()[:-1]
- return torch.ones(shape, dtype=torch.long, device=self.device) * -100
-
- if bos_token_id is None:
- raise ValueError(
- "`bos_token_id` has to be defined when no `input_ids` are provided."
- )
- return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
-
- def _prepare_attention_mask_for_generation(
- self,
- inputs: torch.Tensor,
- pad_token_id: int,
- eos_token_id: int,
- ) -> torch.LongTensor:
- is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [
- torch.int,
- torch.long,
- ]
- is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
- is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
- (eos_token_id is not None) and (pad_token_id != eos_token_id)
- )
- # Check if input is input_ids and padded -> only then is attention_mask defined
- if (
- is_input_ids
- and is_pad_token_in_inputs
- and is_pad_token_not_equal_to_eos_token_id
- ):
- return inputs.ne(pad_token_id).long()
- else:
- return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device)
-
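- # Worked example of the rule above (hypothetical ids, pad_token_id=0, eos_token_id=2):
- #   inputs = torch.tensor([[11, 12, 0], [13, 0, 0]]) -> inputs.ne(0).long() == [[1, 1, 0], [1, 0, 0]]
- # If the pad token never appears, or it equals the eos token, an all-ones mask is returned instead.
-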
- def _prepare_encoder_decoder_kwargs_for_generation(
- self,
- inputs_tensor: torch.Tensor,
- model_kwargs,
- model_input_name: Optional[str] = None,
- ) -> Dict[str, Any]:
- # 1. get encoder
- encoder = self.get_encoder()
-
- # 2. prepare encoder args and encoder kwargs from model kwargs
- irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
- encoder_kwargs = {
- argument: value
- for argument, value in model_kwargs.items()
- if not any(argument.startswith(p) for p in irrelevant_prefix)
- }
-
- # 3. make sure that encoder returns `ModelOutput`
- model_input_name = (
- model_input_name if model_input_name is not None else self.main_input_name
- )
- encoder_kwargs["return_dict"] = True
- encoder_kwargs[model_input_name] = inputs_tensor
- model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
-
- return model_kwargs
-
- def _prepare_decoder_input_ids_for_generation(
- self,
- batch_size: int,
- decoder_start_token_id: int = None,
- bos_token_id: int = None,
- model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
- ) -> torch.LongTensor:
- if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
- return model_kwargs.pop("decoder_input_ids")
- else:
- decoder_start_token_id = self._get_decoder_start_token_id(
- decoder_start_token_id, bos_token_id
- )
- return (
- torch.ones((batch_size, 1), dtype=torch.long, device=self.device)
- * decoder_start_token_id
- )
-
- def _get_decoder_start_token_id(
- self, decoder_start_token_id: int = None, bos_token_id: int = None
- ) -> int:
- decoder_start_token_id = (
- decoder_start_token_id
- if decoder_start_token_id is not None
- else self.config.decoder_start_token_id
- )
- bos_token_id = (
- bos_token_id if bos_token_id is not None else self.config.bos_token_id
- )
-
- if decoder_start_token_id is not None:
- return decoder_start_token_id
- elif (
- hasattr(self.config, "decoder")
- and hasattr(self.config.decoder, "decoder_start_token_id")
- and self.config.decoder.decoder_start_token_id is not None
- ):
- return self.config.decoder.decoder_start_token_id
- elif bos_token_id is not None:
- return bos_token_id
- elif (
- hasattr(self.config, "decoder")
- and hasattr(self.config.decoder, "bos_token_id")
- and self.config.decoder.bos_token_id is not None
- ):
- return self.config.decoder.bos_token_id
- raise ValueError(
- "`decoder_start_token_id` or `bos_token_id` has to be defined for"
- " encoder-decoder generation."
- )
-
- @staticmethod
- def _expand_inputs_for_generation(
- input_ids: torch.LongTensor,
- expand_size: int = 1,
- is_encoder_decoder: bool = False,
- attention_mask: Optional[torch.LongTensor] = None,
- encoder_outputs: Optional[ModelOutput] = None,
- **model_kwargs,
- ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
- expanded_return_idx = (
- torch.arange(input_ids.shape[0])
- .view(-1, 1)
- .repeat(1, expand_size)
- .view(-1)
- .to(input_ids.device)
- )
- input_ids = input_ids.index_select(0, expanded_return_idx)
-
- if "token_type_ids" in model_kwargs:
- token_type_ids = model_kwargs["token_type_ids"]
- model_kwargs["token_type_ids"] = token_type_ids.index_select(
- 0, expanded_return_idx
- )
-
- if attention_mask is not None:
- model_kwargs["attention_mask"] = attention_mask.index_select(
- 0, expanded_return_idx
- )
-
- if is_encoder_decoder:
- if encoder_outputs is None:
- raise ValueError(
- "If `is_encoder_decoder` is True, make sure that `encoder_outputs`"
- " is defined."
- )
- encoder_outputs["last_hidden_state"] = (
- encoder_outputs.last_hidden_state.index_select(
- 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
- )
- )
- model_kwargs["encoder_outputs"] = encoder_outputs
- return input_ids, model_kwargs
-
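- # Interleaving sketch (hypothetical sizes): with batch_size=2 and expand_size=3, the row index
- # [0, 1] becomes [0, 0, 0, 1, 1, 1], so beams/return sequences of the same example stay adjacent.
-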
- @staticmethod
- def _update_model_kwargs_for_generation(
- outputs: ModelOutput,
- model_kwargs: Dict[str, Any],
- is_encoder_decoder: bool = False,
- ) -> Dict[str, Any]:
- # update past
- if "past_key_values" in outputs:
- model_kwargs["past"] = outputs.past_key_values
- elif "mems" in outputs:
- model_kwargs["past"] = outputs.mems
- elif "past_buckets_states" in outputs:
- model_kwargs["past"] = outputs.past_buckets_states
- else:
- model_kwargs["past"] = None
-
- # update token_type_ids with last value
- if "token_type_ids" in model_kwargs:
- token_type_ids = model_kwargs["token_type_ids"]
- model_kwargs["token_type_ids"] = torch.cat(
- [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
- )
-
- # update attention mask
- if not is_encoder_decoder:
- if "attention_mask" in model_kwargs:
- attention_mask = model_kwargs["attention_mask"]
- model_kwargs["attention_mask"] = torch.cat(
- [
- attention_mask,
- attention_mask.new_ones((attention_mask.shape[0], 1)),
- ],
- dim=-1,
- )
-
- return model_kwargs
-
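- # Step sketch (hypothetical values) of how the kwargs evolve during decoder-only generation:
- #   step 0: attention_mask = [[1, 1, 1]], past = None
- #   step 1: attention_mask = [[1, 1, 1, 1]], past = cache returned by the previous forward pass
- # i.e. the mask grows by one column per generated token and the cache is threaded through `past`.
-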
- def _reorder_cache(self, past, beam_idx):
- raise NotImplementedError(
- "Make sure that a `_reorder_cache` function is correctly implemented in"
- f" {self.__class__.__module__} to enable beam search for {self.__class__}"
- )
-
- def _get_logits_warper(
- self,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- typical_p: Optional[float] = None,
- temperature: Optional[float] = None,
- num_beams: Optional[int] = None,
- ) -> LogitsProcessorList:
- """
- This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
- used for multinomial sampling.
- """
-
- # init warp parameters
- top_k = top_k if top_k is not None else self.config.top_k
- top_p = top_p if top_p is not None else self.config.top_p
- typical_p = typical_p if typical_p is not None else self.config.typical_p
- temperature = (
- temperature if temperature is not None else self.config.temperature
- )
- # instantiate warpers list
- warpers = LogitsProcessorList()
-
- # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
- # all samplers can be found in `generation_utils_samplers.py`
- if temperature is not None and temperature != 1.0:
- warpers.append(TemperatureLogitsWarper(temperature))
- if top_k is not None and top_k != 0:
- warpers.append(
- TopKLogitsWarper(
- top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)
- )
- )
- if top_p is not None and top_p < 1.0:
- warpers.append(
- TopPLogitsWarper(
- top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)
- )
- )
- if typical_p is not None and typical_p < 1.0:
- warpers.append(
- TypicalLogitsWarper(
- mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)
- )
- )
- return warpers
-
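- # Ordering sketch (illustrative values, not defaults): with temperature=0.7, top_k=50, top_p=0.95,
- # the warpers run in the order appended above (temperature -> top-k -> top-p -> typical), each one
- # reshaping the next-token distribution before multinomial sampling.
-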
- def _get_logits_processor(
- self,
- repetition_penalty: float,
- no_repeat_ngram_size: int,
- encoder_no_repeat_ngram_size: int,
- input_ids_seq_length: int,
- encoder_input_ids: torch.LongTensor,
- bad_words_ids: List[List[int]],
- min_length: int,
- max_length: int,
- eos_token_id: int,
- forced_bos_token_id: int,
- forced_eos_token_id: int,
- prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
- num_beams: int,
- num_beam_groups: int,
- diversity_penalty: float,
- remove_invalid_values: bool,
- exponential_decay_length_penalty: Tuple,
- logits_processor: Optional[LogitsProcessorList],
- ) -> LogitsProcessorList:
- """
- This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
- instances used to modify the scores of the language model head.
- """
- processors = LogitsProcessorList()
-
- # init warp parameters
- repetition_penalty = (
- repetition_penalty
- if repetition_penalty is not None
- else self.config.repetition_penalty
- )
- no_repeat_ngram_size = (
- no_repeat_ngram_size
- if no_repeat_ngram_size is not None
- else self.config.no_repeat_ngram_size
- )
- encoder_no_repeat_ngram_size = (
- encoder_no_repeat_ngram_size
- if encoder_no_repeat_ngram_size is not None
- else self.config.encoder_no_repeat_ngram_size
- )
- bad_words_ids = (
- bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
- )
- min_length = min_length if min_length is not None else self.config.min_length
- eos_token_id = (
- eos_token_id if eos_token_id is not None else self.config.eos_token_id
- )
- diversity_penalty = (
- diversity_penalty
- if diversity_penalty is not None
- else self.config.diversity_penalty
- )
- forced_bos_token_id = (
- forced_bos_token_id
- if forced_bos_token_id is not None
- else self.config.forced_bos_token_id
- )
- forced_eos_token_id = (
- forced_eos_token_id
- if forced_eos_token_id is not None
- else self.config.forced_eos_token_id
- )
- remove_invalid_values = (
- remove_invalid_values
- if remove_invalid_values is not None
- else self.config.remove_invalid_values
- )
- exponential_decay_length_penalty = (
- exponential_decay_length_penalty
- if exponential_decay_length_penalty is not None
- else self.config.exponential_decay_length_penalty
- )
- # instantiate processors list
-
- # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
- # all samplers can be found in `generation_utils_samplers.py`
- if diversity_penalty is not None and diversity_penalty > 0.0:
- processors.append(
- HammingDiversityLogitsProcessor(
- diversity_penalty=diversity_penalty,
- num_beams=num_beams,
- num_beam_groups=num_beam_groups,
- )
- )
- if repetition_penalty is not None and repetition_penalty != 1.0:
- processors.append(
- RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
- )
- if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
- processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
- if (
- encoder_no_repeat_ngram_size is not None
- and encoder_no_repeat_ngram_size > 0
- ):
- if self.config.is_encoder_decoder:
- processors.append(
- EncoderNoRepeatNGramLogitsProcessor(
- encoder_no_repeat_ngram_size, encoder_input_ids
- )
- )
- else:
- raise ValueError(
- "It's impossible to use `encoder_no_repeat_ngram_size` with"
- " decoder-only architecture"
- )
- if bad_words_ids is not None:
- processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
- if min_length is not None and eos_token_id is not None and min_length > 0:
- processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
- if prefix_allowed_tokens_fn is not None:
- processors.append(
- PrefixConstrainedLogitsProcessor(
- prefix_allowed_tokens_fn, num_beams // num_beam_groups
- )
- )
- if forced_bos_token_id is not None:
- processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
- if forced_eos_token_id is not None:
- processors.append(
- ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)
- )
- if remove_invalid_values is True:
- processors.append(InfNanRemoveLogitsProcessor())
- if exponential_decay_length_penalty is not None:
- processors.append(
- ExponentialDecayLengthPenalty(
- exponential_decay_length_penalty, eos_token_id, input_ids_seq_length
- )
- )
- processors = self._merge_criteria_processor_list(processors, logits_processor)
- return processors
-
- def _get_stopping_criteria(
- self,
- max_length: Optional[int],
- max_time: Optional[float],
- stopping_criteria: Optional[StoppingCriteriaList],
- ) -> StoppingCriteriaList:
- criteria = StoppingCriteriaList()
- if max_length is not None:
- criteria.append(MaxLengthCriteria(max_length=max_length))
- if max_time is not None:
- criteria.append(MaxTimeCriteria(max_time=max_time))
- criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
- return criteria
-
- def _merge_criteria_processor_list(
- self,
- default_list: Union[LogitsProcessorList, StoppingCriteriaList],
- custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
- ) -> Union[LogitsProcessorList, StoppingCriteriaList]:
- if len(custom_list) == 0:
- return default_list
- for default in default_list:
- for custom in custom_list:
- if type(custom) is type(default):
- object_type = (
- "stopping criteria"
- if isinstance(custom, StoppingCriteria)
- else "logits processor"
- )
- raise ValueError(
- f"A custom {object_type} of type {type(custom)} with values"
- f" {custom} has been passed to `generate`, but it has already"
- f" been created with the values {default}. {default} has been"
- " created by passing the corresponding arguments to generate"
- " or by the model's config default values. If you just want to"
- f" change the default values of {object_type} consider passing"
- " them as arguments to `generate` instead of using a custom"
- f" {object_type}."
- )
- default_list.extend(custom_list)
- return default_list
-
- def compute_beam_search_raw_logits(
- self,
- sequences: torch.Tensor,
- scores: Tuple[torch.Tensor],
- beam_indices: torch.Tensor,
- eos_token_id: int = None,
- ):
- """Compute raw logits for beam search"""
-
- if not self.config.is_encoder_decoder:
- raise NotImplementedError(
- "Beam Search raw logits code is implemented only for enoder-decoder"
- " only models"
- )
-
- # since sequences can be shorter than scores (probably due to beam search finalization)
- # we always have to generate raw_logits only for generated sequences
- # cut off the start token from the generated sequences
- sequences = sequences.clone()
- sequences = sequences[:, 1:]
- gen_steps = sequences.shape[1]
-
- # align scores and beam indices according to gen_steps
- # scores: (gen_steps x (batch_size * num_beams) x vocab_size)
- scores = scores[:gen_steps]
- scores = torch.stack(scores)
- _, _, vocab_size = scores.shape
-
- beam_indices = torch.tensor(beam_indices).T.to(scores.device)
- beam_indices = beam_indices[:gen_steps, :]
- batch_size = beam_indices.shape[1]
-
- # gen_steps x batch_size x vocab_size
- beam_indices = beam_indices.unsqueeze(-1).repeat(1, 1, vocab_size)
- step_wise_logits = scores.gather(dim=1, index=beam_indices)
- assert step_wise_logits.shape == torch.Size((gen_steps, batch_size, vocab_size))
-
- # finally convert to tuples
- step_wise_logits = [(step_wise_logits[t], None) for t in range(gen_steps)]
- return step_wise_logits
-
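- # Shape sketch for the realignment above (hypothetical sizes): with num_beams=4, a batch of 2
- # (one returned sequence each) and 10 generated steps, `scores` stacks to (10, 8, vocab_size) and
- # `beam_indices` becomes (10, 2); the gather then picks, at every step, the row of the beam each
- # returned sequence actually came from, yielding step-wise logits of shape (10, 2, vocab_size).
-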
- @torch.no_grad()
- def generate(
- self,
- inputs: Optional[torch.Tensor] = None,
- max_length: Optional[int] = None,
- min_length: Optional[int] = None,
- do_sample: Optional[bool] = None,
- early_stopping: Optional[bool] = None,
- num_beams: Optional[int] = None,
- temperature: Optional[float] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- typical_p: Optional[float] = None,
- repetition_penalty: Optional[float] = None,
- bad_words_ids: Optional[Iterable[int]] = None,
- force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
- bos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[int] = None,
- length_penalty: Optional[float] = None,
- no_repeat_ngram_size: Optional[int] = None,
- encoder_no_repeat_ngram_size: Optional[int] = None,
- num_return_sequences: Optional[int] = None,
- max_time: Optional[float] = None,
- max_new_tokens: Optional[int] = None,
- decoder_start_token_id: Optional[int] = None,
- use_cache: Optional[bool] = None,
- num_beam_groups: Optional[int] = None,
- diversity_penalty: Optional[float] = None,
- prefix_allowed_tokens_fn: Optional[
- Callable[[int, torch.Tensor], List[int]]
- ] = None,
- logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
- stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
- constraints: Optional[List[Constraint]] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_scores: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- forced_bos_token_id: Optional[int] = None,
- forced_eos_token_id: Optional[int] = None,
- remove_invalid_values: Optional[bool] = None,
- synced_gpus: Optional[bool] = False,
- exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
- **model_kwargs,
- ) -> Union[
- GreedySearchOutput,
- SampleOutput,
- BeamSearchOutput,
- BeamSampleOutput,
- torch.LongTensor,
- ]:
- r"""
-
- Generates sequences of token ids for models with a language modeling head. The method supports the following
- generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
-
- - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
- `do_sample=False`.
- - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
- `do_sample=True`.
- - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
- `do_sample=False`.
- - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
- `num_beams>1` and `do_sample=True`.
- - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
- `num_beams>1` and `num_beam_groups>1`.
- - *constrained beam-search decoding* by calling
- [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or
- `force_words_ids!=None`.
-
-
-
- Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
- defined in the model's config (`config.json`) which in turn defaults to the
- [`~modeling_utils.PretrainedConfig`] of the model.
-
-
-
- Most of these parameters are explained in more detail in [this blog
- post](https://huggingface.co/blog/how-to-generate).
-
- Parameters:
- inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
- The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
- method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
- should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
- `input_ids`, `input_values`, `input_features`, or `pixel_values`.
- max_length (`int`, *optional*, defaults to `model.config.max_length`):
- The maximum length of the sequence to be generated.
- max_new_tokens (`int`, *optional*, defaults to None):
- The maximum number of tokens to generate, ignoring the number of tokens in the prompt. Use either
- `max_new_tokens` or `max_length`, but not both, as they serve the same purpose.
- min_length (`int`, *optional*, defaults to 10):
- The minimum length of the sequence to be generated.
- do_sample (`bool`, *optional*, defaults to `False`):
- Whether or not to use sampling; use greedy decoding otherwise.
- early_stopping (`bool`, *optional*, defaults to `False`):
- Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
- num_beams (`int`, *optional*, defaults to 1):
- Number of beams for beam search. 1 means no beam search.
- temperature (`float`, *optional*, defaults to 1.0):
- The value used to module the next token probabilities.
- top_k (`int`, *optional*, defaults to 50):
- The number of highest probability vocabulary tokens to keep for top-k-filtering.
- top_p (`float`, *optional*, defaults to 1.0):
- If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
- are kept for generation.
- repetition_penalty (`float`, *optional*, defaults to 1.0):
- The parameter for repetition penalty. 1.0 means no penalty. See [this
- paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
- pad_token_id (`int`, *optional*):
- The id of the *padding* token.
- bos_token_id (`int`, *optional*):
- The id of the *beginning-of-sequence* token.
- eos_token_id (`int`, *optional*):
- The id of the *end-of-sequence* token.
- length_penalty (`float`, *optional*, defaults to 1.0):
- Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
- model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
- sequences.
- no_repeat_ngram_size (`int`, *optional*, defaults to 0):
- If set to int > 0, all ngrams of that size can only occur once.
- encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
- If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
- `decoder_input_ids`.
- bad_words_ids(`List[List[int]]`, *optional*):
- List of token ids that are not allowed to be generated. In order to get the token ids of the words that
- should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
- add_special_tokens=False).input_ids`.
- force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
- List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple
- list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`,
- this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081),
- where one can allow different forms of each word.
- num_return_sequences(`int`, *optional*, defaults to 1):
- The number of independently computed returned sequences for each element in the batch.
- max_time(`float`, *optional*, defaults to None):
- The maximum amount of time you allow the computation to run for, in seconds. Generation will still
- finish the current pass after the allocated time has passed.
- attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
- that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape
- as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask)
- decoder_start_token_id (`int`, *optional*):
- If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should use the past last key/values attentions (if applicable to the model) to
- speed up decoding.
- num_beam_groups (`int`, *optional*, defaults to 1):
- Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
- beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
- diversity_penalty (`float`, *optional*, defaults to 0.0):
- This value is subtracted from a beam's score if it generates the same token as any beam from another
- group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
- enabled.
- prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
- If provided, this function constrains the beam search to allowed tokens only at each step. If not
- provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
- `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
- on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
- for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
- Retrieval](https://arxiv.org/abs/2010.00904).
- logits_processor (`LogitsProcessorList`, *optional*):
- Custom logits processors that complement the default logits processors built from arguments and a
- model's config. If a logit processor is passed that is already created with the arguments or a model's
- config an error is thrown. This feature is intended for advanced users.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
- Custom stopping criteria that complement the default stopping criteria built from arguments and a
- model's config. If a stopping criteria is passed that is already created with the arguments or a
- model's config an error is thrown. This feature is intended for advanced users.
- constraints (`List[Constraint]`, *optional*):
- Custom constraints that can be added to the generation to ensure that the output will contain the use
- of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
- output_attentions (`bool`, *optional*, defaults to `False`):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more details.
- output_hidden_states (`bool`, *optional*, defaults to `False`):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more details.
- output_scores (`bool`, *optional*, defaults to `False`):
- Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
- return_dict_in_generate (`bool`, *optional*, defaults to `False`):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- forced_bos_token_id (`int`, *optional*):
- The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
- for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
- the target language token.
- forced_eos_token_id (`int`, *optional*):
- The id of the token to force as the last generated token when `max_length` is reached.
- remove_invalid_values (`bool`, *optional*):
- Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method from
- crashing. Note that using `remove_invalid_values` can slow down generation.
- synced_gpus (`bool`, *optional*, defaults to `False`):
- Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
- exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
- This tuple adds an exponentially increasing length penalty after a certain number of tokens have been
- generated. The tuple shall consist of `(start_index, decay_factor)`, where `start_index` indicates
- where the penalty starts and `decay_factor` represents the factor of exponential decay.
-
- model_kwargs:
- Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
- is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
- should be prefixed with *decoder_*.
-
- Return:
- [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
- or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
-
- If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
- [`~utils.ModelOutput`] types are:
-
- - [`~generation_utils.GreedySearchDecoderOnlyOutput`],
- - [`~generation_utils.SampleDecoderOnlyOutput`],
- - [`~generation_utils.BeamSearchDecoderOnlyOutput`],
- - [`~generation_utils.BeamSampleDecoderOnlyOutput`]
-
- If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
- [`~utils.ModelOutput`] types are:
-
- - [`~generation_utils.GreedySearchEncoderDecoderOutput`],
- - [`~generation_utils.SampleEncoderDecoderOutput`],
- - [`~generation_utils.BeamSearchEncoderDecoderOutput`],
- - [`~generation_utils.BeamSampleEncoderDecoderOutput`]
-
- Examples:
-
- Greedy Decoding:
-
- ```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM
-
- >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
- >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
-
- >>> prompt = "Today I believe we can finally"
- >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-
- >>> # generate up to 30 tokens
- >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
- >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
- ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
- ```
-
- Multinomial Sampling:
-
- ```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM
- >>> import torch
-
- >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
- >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
-
- >>> prompt = "Today I believe we can finally"
- >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-
- >>> # sample up to 30 tokens
- >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
- >>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
- >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
- ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
- ```
-
- Beam-search decoding:
-
- ```python
- >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-
- >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
- >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
-
- >>> sentence = "Paris is one of the densest populated areas in Europe."
- >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids
-
- >>> outputs = model.generate(input_ids)
- >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
- ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
- ```"""
- # 1. Set generation parameters if not already defined
- bos_token_id = (
- bos_token_id if bos_token_id is not None else self.config.bos_token_id
- )
- num_beams = num_beams if num_beams is not None else self.config.num_beams
- length_penalty = (
- length_penalty if length_penalty is not None else self.config.length_penalty
- )
- early_stopping = (
- early_stopping if early_stopping is not None else self.config.early_stopping
- )
- num_beam_groups = (
- num_beam_groups
- if num_beam_groups is not None
- else self.config.num_beam_groups
- )
- do_sample = do_sample if do_sample is not None else self.config.do_sample
- num_return_sequences = (
- num_return_sequences
- if num_return_sequences is not None
- else self.config.num_return_sequences
- )
-
- pad_token_id = (
- pad_token_id if pad_token_id is not None else self.config.pad_token_id
- )
- eos_token_id = (
- eos_token_id if eos_token_id is not None else self.config.eos_token_id
- )
-
- if eos_token_id is None and hasattr(self.config, "decoder"):
- eos_token_id = self.config.decoder.eos_token_id
-
- if pad_token_id is None and eos_token_id is not None:
- # special case if pad_token_id is not defined
- # logger.warning(
- # f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-
- pad_token_id = eos_token_id
-
- output_scores = (
- output_scores if output_scores is not None else self.config.output_scores
- )
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict_in_generate = (
- return_dict_in_generate
- if return_dict_in_generate is not None
- else self.config.return_dict_in_generate
- )
-
- # 2. Define model inputs
- # inputs_tensor has to be defined
- # model_input_name is defined if model-specific keyword input is passed
- # otherwise model_input_name is None
- # all model-specific keyword inputs are removed from `model_kwargs`
- inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
- inputs, bos_token_id, model_kwargs
- )
- batch_size = inputs_tensor.shape[0]
-
- # 3. Define other model kwargs
- model_kwargs["output_attentions"] = output_attentions
- model_kwargs["output_hidden_states"] = output_hidden_states
- model_kwargs["use_cache"] = use_cache
-
- accepts_attention_mask = "attention_mask" in set(
- inspect.signature(self.forward).parameters.keys()
- )
- requires_attention_mask = "encoder_outputs" not in model_kwargs
-
- if (
- model_kwargs.get("attention_mask", None) is None
- and requires_attention_mask
- and accepts_attention_mask
- ):
- model_kwargs["attention_mask"] = (
- self._prepare_attention_mask_for_generation(
- inputs_tensor, pad_token_id, eos_token_id
- )
- )
-
- if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
- # if model is encoder decoder encoder_outputs are created
- # and added to `model_kwargs`
- model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
- inputs_tensor, model_kwargs, model_input_name
- )
-
- # 4. Prepare `input_ids` which will be used for auto-regressive generation
- if self.config.is_encoder_decoder:
- input_ids = self._prepare_decoder_input_ids_for_generation(
- batch_size,
- decoder_start_token_id=decoder_start_token_id,
- bos_token_id=bos_token_id,
- model_kwargs=model_kwargs,
- )
- else:
- # if decoder-only then inputs_tensor has to be `input_ids`
- input_ids = inputs_tensor
-
- input_ids_seq_length = input_ids.shape[-1]
-
- # 5. Prepare `max_length` depending on other stopping criteria
- # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens`
- if max_length is None and max_new_tokens is not None:
- max_length = max_new_tokens + input_ids_seq_length
- elif max_length is not None and max_new_tokens is not None:
- # Both are set, this is odd, raise a warning
- warnings.warn(
- (
- "Both `max_length` and `max_new_tokens` have been set "
- f"but they serve the same purpose. `max_length` {max_length} "
- f"will take priority over `max_new_tokens` {max_new_tokens}."
- ),
- UserWarning,
- )
- # default to config if still None
- max_length = max_length if max_length is not None else self.config.max_length
-
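- # e.g. (hypothetical numbers) a 6-token prompt with max_new_tokens=20 and no explicit max_length
- # yields max_length = 26; if both were given, max_length takes priority, as warned above.
-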
- if input_ids_seq_length >= max_length:
- input_ids_string = (
- "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
- )
- logger.warning(
- f"Input length of {input_ids_string} is {input_ids_seq_length}, but"
- f" ``max_length`` is set to {max_length}. This can lead to unexpected"
- " behavior. You should consider increasing ``config.max_length`` or"
- " ``max_length``."
- )
-
- # 6. determine generation mode
- is_constraint_gen_mode = constraints is not None or force_words_ids is not None
- is_greedy_gen_mode = (
- (num_beams == 1)
- and (num_beam_groups == 1)
- and do_sample is False
- and not is_constraint_gen_mode
- )
- is_sample_gen_mode = (
- (num_beams == 1)
- and (num_beam_groups == 1)
- and do_sample is True
- and not is_constraint_gen_mode
- )
- is_beam_gen_mode = (
- (num_beams > 1)
- and (num_beam_groups == 1)
- and do_sample is False
- and not is_constraint_gen_mode
- )
- is_beam_sample_gen_mode = (
- (num_beams > 1)
- and (num_beam_groups == 1)
- and do_sample is True
- and not is_constraint_gen_mode
- )
- is_group_beam_gen_mode = (
- (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode
- )
-
- if num_beam_groups > num_beams:
- raise ValueError(
- "`num_beam_groups` has to be smaller or equal to `num_beams`"
- )
- if is_group_beam_gen_mode and do_sample is True:
- raise ValueError(
- "Diverse beam search cannot be used in sampling mode. Make sure that"
- " `do_sample` is set to `False`."
- )
-
- # 7. prepare distribution pre_processing samplers
- logits_processor = self._get_logits_processor(
- repetition_penalty=repetition_penalty,
- no_repeat_ngram_size=no_repeat_ngram_size,
- encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
- input_ids_seq_length=input_ids_seq_length,
- encoder_input_ids=inputs_tensor,
- bad_words_ids=bad_words_ids,
- min_length=min_length,
- max_length=max_length,
- eos_token_id=eos_token_id,
- forced_bos_token_id=forced_bos_token_id,
- forced_eos_token_id=forced_eos_token_id,
- prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
- num_beams=num_beams,
- num_beam_groups=num_beam_groups,
- diversity_penalty=diversity_penalty,
- remove_invalid_values=remove_invalid_values,
- exponential_decay_length_penalty=exponential_decay_length_penalty,
- logits_processor=logits_processor,
- )
-
- # 8. prepare stopping criteria
- stopping_criteria = self._get_stopping_criteria(
- max_length=max_length,
- max_time=max_time,
- stopping_criteria=stopping_criteria,
- )
-
- # 9. go into different generation modes
- if is_greedy_gen_mode:
- if num_return_sequences > 1:
- raise ValueError(
- "num_return_sequences has to be 1, but is"
- f" {num_return_sequences} when doing greedy search."
- )
-
- # 10. run greedy search
- return self.greedy_search(
- input_ids,
- logits_processor=logits_processor,
- stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
- synced_gpus=synced_gpus,
- **model_kwargs,
- )
-
- elif is_sample_gen_mode:
- # 10. prepare logits warper
- logits_warper = self._get_logits_warper(
- top_k=top_k,
- top_p=top_p,
- typical_p=typical_p,
- temperature=temperature,
- num_beams=num_beams,
- )
-
- # 11. expand input_ids with `num_return_sequences` additional sequences per batch
- input_ids, model_kwargs = self._expand_inputs_for_generation(
- input_ids,
- expand_size=num_return_sequences,
- is_encoder_decoder=self.config.is_encoder_decoder,
- **model_kwargs,
- )
-
- # 12. run sample
- return self.sample(
- input_ids,
- logits_processor=logits_processor,
- logits_warper=logits_warper,
- stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
- synced_gpus=synced_gpus,
- **model_kwargs,
- )
-
- elif is_beam_gen_mode:
- if num_return_sequences > num_beams:
- raise ValueError(
- "`num_return_sequences` has to be smaller or equal to `num_beams`."
- )
-
- if stopping_criteria.max_length is None:
- raise ValueError(
- "`max_length` needs to be a stopping_criteria for now."
- )
-
- # 10. prepare beam search scorer
- beam_scorer = BeamSearchScorer(
- batch_size=batch_size,
- num_beams=num_beams,
- device=self.device,
- length_penalty=length_penalty,
- do_early_stopping=early_stopping,
- num_beam_hyps_to_keep=num_return_sequences,
- )
- # 11. interleave input_ids with `num_beams` additional sequences per batch
- input_ids, model_kwargs = self._expand_inputs_for_generation(
- input_ids,
- expand_size=num_beams,
- is_encoder_decoder=self.config.is_encoder_decoder,
- **model_kwargs,
- )
- # 12. run beam search
- return self.beam_search(
- input_ids,
- beam_scorer,
- logits_processor=logits_processor,
- stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
- synced_gpus=synced_gpus,
- **model_kwargs,
- )
-
- elif is_beam_sample_gen_mode:
- # 10. prepare logits warper
- logits_warper = self._get_logits_warper(
- top_k=top_k,
- top_p=top_p,
- typical_p=typical_p,
- temperature=temperature,
- num_beams=num_beams,
- )
-
- if stopping_criteria.max_length is None:
- raise ValueError(
- "`max_length` needs to be a stopping_criteria for now."
- )
- # 11. prepare beam search scorer
- beam_scorer = BeamSearchScorer(
- batch_size=batch_size * num_return_sequences,
- num_beams=num_beams,
- device=self.device,
- length_penalty=length_penalty,
- do_early_stopping=early_stopping,
- )
-
- # 12. interleave input_ids with `num_beams` additional sequences per batch
- input_ids, model_kwargs = self._expand_inputs_for_generation(
- input_ids,
- expand_size=num_beams * num_return_sequences,
- is_encoder_decoder=self.config.is_encoder_decoder,
- **model_kwargs,
- )
-
- # 13. run beam sample
- return self.beam_sample(
- input_ids,
- beam_scorer,
- logits_processor=logits_processor,
- logits_warper=logits_warper,
- stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
- synced_gpus=synced_gpus,
- **model_kwargs,
- )
-
- elif is_group_beam_gen_mode:
- if num_return_sequences > num_beams:
- raise ValueError(
- "`num_return_sequences` has to be smaller or equal to `num_beams`."
- )
-
- if num_beams % num_beam_groups != 0:
- raise ValueError(
- "`num_beams` should be divisible by `num_beam_groups` for group"
- " beam search."
- )
-
- if stopping_criteria.max_length is None:
- raise ValueError(
- "`max_length` needs to be a stopping_criteria for now."
- )
-
- # 10. prepare beam search scorer
- beam_scorer = BeamSearchScorer(
- batch_size=batch_size,
- num_beams=num_beams,
- max_length=stopping_criteria.max_length,
- device=self.device,
- length_penalty=length_penalty,
- do_early_stopping=early_stopping,
- num_beam_hyps_to_keep=num_return_sequences,
- num_beam_groups=num_beam_groups,
- )
- # 11. interleave input_ids with `num_beams` additional sequences per batch
- input_ids, model_kwargs = self._expand_inputs_for_generation(
- input_ids,
- expand_size=num_beams,
- is_encoder_decoder=self.config.is_encoder_decoder,
- **model_kwargs,
- )
- # 12. run beam search
- return self.group_beam_search(
- input_ids,
- beam_scorer,
- logits_processor=logits_processor,
- stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
- synced_gpus=synced_gpus,
- **model_kwargs,
- )
-
- elif is_constraint_gen_mode:
- if num_return_sequences > num_beams:
- raise ValueError(
- "`num_return_sequences` has to be smaller or equal to `num_beams`."
- )
-
- if stopping_criteria.max_length is None:
- raise ValueError(
- "`max_length` needs to be a stopping_criteria for now."
- )
-
- if num_beams <= 1:
- raise ValueError(
- "`num_beams` needs to be greater than 1 for constrained"
- " genertation."
- )
-
- if do_sample:
- raise ValueError(
- "`do_sample` needs to be false for constrained generation."
- )
-
- if num_beam_groups is not None and num_beam_groups > 1:
- raise ValueError(
- "`num_beam_groups` not supported yet for constrained generation."
- )
-
- final_constraints = []
- if constraints is not None:
- final_constraints = constraints
-
- if force_words_ids is not None:
-
- def typeerror():
- raise ValueError(
- "`force_words_ids` has to either be a `List[List[List[int]]]`"
- " or `List[List[int]]`of positive integers, but is"
- f" {force_words_ids}."
- )
-
- if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
- typeerror()
-
- for word_ids in force_words_ids:
- if isinstance(word_ids[0], list):
- if not isinstance(word_ids, list) or len(word_ids) == 0:
- typeerror()
- if any(
- not isinstance(token_ids, list) for token_ids in word_ids
- ):
- typeerror()
- if any(
- any(
- (not isinstance(token_id, int) or token_id < 0)
- for token_id in token_ids
- )
- for token_ids in word_ids
- ):
- typeerror()
-
- constraint = DisjunctiveConstraint(word_ids)
- else:
- if not isinstance(word_ids, list) or len(word_ids) == 0:
- typeerror()
- if any(
- (not isinstance(token_id, int) or token_id < 0)
- for token_id in word_ids
- ):
- typeerror()
-
- constraint = PhrasalConstraint(word_ids)
- final_constraints.append(constraint)
-
- # 10. prepare beam search scorer
- constrained_beam_scorer = ConstrainedBeamSearchScorer(
- constraints=final_constraints,
- batch_size=batch_size,
- num_beams=num_beams,
- device=self.device,
- length_penalty=length_penalty,
- do_early_stopping=early_stopping,
- num_beam_hyps_to_keep=num_return_sequences,
- )
- # 11. interleave input_ids with `num_beams` additional sequences per batch
- input_ids, model_kwargs = self._expand_inputs_for_generation(
- input_ids,
- expand_size=num_beams,
- is_encoder_decoder=self.config.is_encoder_decoder,
- **model_kwargs,
- )
- # 12. run beam search
- return self.constrained_beam_search(
- input_ids,
- constrained_beam_scorer=constrained_beam_scorer,
- logits_processor=logits_processor,
- stopping_criteria=stopping_criteria,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- output_scores=output_scores,
- return_dict_in_generate=return_dict_in_generate,
- synced_gpus=synced_gpus,
- **model_kwargs,
- )
-
- def greedy_search(
- self,
- input_ids: torch.LongTensor,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- max_length: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_scores: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- synced_gpus: Optional[bool] = False,
- **model_kwargs,
- ) -> Union[GreedySearchOutput, torch.LongTensor]:
- r"""
- Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
- used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
- Parameters:
-
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- The sequence used as a prompt for the generation.
- logits_processor (`LogitsProcessorList`, *optional*):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
- used to modify the prediction scores of the language modeling head applied at each generation step.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
- An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
- used to tell if the generation loop should stop.
-
- max_length (`int`, *optional*, defaults to 20):
- **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
- tokens. The maximum length of the sequence to be generated.
- pad_token_id (`int`, *optional*):
- The id of the *padding* token.
- eos_token_id (`int`, *optional*):
- The id of the *end-of-sequence* token.
- output_attentions (`bool`, *optional*, defaults to `False`):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more details.
- output_hidden_states (`bool`, *optional*, defaults to `False`):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more details.
- output_scores (`bool`, *optional*, defaults to `False`):
- Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
- return_dict_in_generate (`bool`, *optional*, defaults to `False`):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- synced_gpus (`bool`, *optional*, defaults to `False`):
- Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
- model_kwargs:
- Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
- If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
-
- Return:
- [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`]
- or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
- [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
- `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if
- `model.config.is_encoder_decoder=True`.
-
- Examples:
-
- ```python
- >>> from transformers import (
- ... AutoTokenizer,
- ... AutoModelForCausalLM,
- ... LogitsProcessorList,
- ... MinLengthLogitsProcessor,
- ... StoppingCriteriaList,
- ... MaxLengthCriteria,
- ... )
-
- >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
- >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
-
- >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
- >>> model.config.pad_token_id = model.config.eos_token_id
-
- >>> input_prompt = "It might be possible to"
- >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
-
- >>> # instantiate logits processors
- >>> logits_processor = LogitsProcessorList(
- ... [
- ... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
- ... ]
- ... )
- >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
-
- >>> outputs = model.greedy_search(
- ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
- ... )
-
- >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
- ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
- ```"""
- # init values
- logits_processor = (
- logits_processor if logits_processor is not None else LogitsProcessorList()
- )
- stopping_criteria = (
- stopping_criteria
- if stopping_criteria is not None
- else StoppingCriteriaList()
- )
- if max_length is not None:
- warnings.warn(
- (
- "`max_length` is deprecated in this function, use"
- " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])`"
- " instead."
- ),
- UserWarning,
- )
- stopping_criteria = validate_stopping_criteria(
- stopping_criteria, max_length
- )
- pad_token_id = (
- pad_token_id if pad_token_id is not None else self.config.pad_token_id
- )
- eos_token_id = (
- eos_token_id if eos_token_id is not None else self.config.eos_token_id
- )
- output_scores = (
- output_scores if output_scores is not None else self.config.output_scores
- )
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict_in_generate = (
- return_dict_in_generate
- if return_dict_in_generate is not None
- else self.config.return_dict_in_generate
- )
-
- # init attention / hidden states / scores tuples
- scores = () if (return_dict_in_generate and output_scores) else None
- decoder_attentions = (
- () if (return_dict_in_generate and output_attentions) else None
- )
- cross_attentions = (
- () if (return_dict_in_generate and output_attentions) else None
- )
- decoder_hidden_states = (
- () if (return_dict_in_generate and output_hidden_states) else None
- )
-
- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
- if return_dict_in_generate and self.config.is_encoder_decoder:
- encoder_attentions = (
- model_kwargs["encoder_outputs"].get("attentions")
- if output_attentions
- else None
- )
- encoder_hidden_states = (
- model_kwargs["encoder_outputs"].get("hidden_states")
- if output_hidden_states
- else None
- )
-
- # keep track of which sequences are already finished
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
- cur_len = input_ids.shape[-1]
-
- this_peer_finished = False # used by synced_gpus only
- while True:
- if synced_gpus:
- # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
- # The following logic allows an early break if all peers finished generating their sequence
- this_peer_finished_flag = torch.tensor(
- 0.0 if this_peer_finished else 1.0
- ).to(input_ids.device)
- # send 0.0 if we finished, 1.0 otherwise
- dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- break
-
- # prepare model inputs
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
- # forward pass to get next token
- outputs = self(
- **model_inputs,
- return_dict=True,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- )
-
- if synced_gpus and this_peer_finished:
- cur_len = cur_len + 1
- continue # don't waste resources running the code we don't need
-
- next_token_logits = outputs.logits[:, -1, :]
-
- # Store scores, attentions and hidden_states when required
- if return_dict_in_generate:
- if output_scores:
- scores += (next_token_logits,)
- if output_attentions:
- decoder_attentions += (
- (outputs.decoder_attentions,)
- if self.config.is_encoder_decoder
- else (outputs.attentions,)
- )
- if self.config.is_encoder_decoder:
- cross_attentions += (outputs.cross_attentions,)
-
- if output_hidden_states:
- decoder_hidden_states += (
- (outputs.decoder_hidden_states,)
- if self.config.is_encoder_decoder
- else (outputs.hidden_states,)
- )
-
- # pre-process distribution
- next_tokens_scores = logits_processor(
- input_ids, next_token_logits, model_inputs
- )
-
- # argmax
- next_tokens = torch.argmax(next_tokens_scores, dim=-1)
-
- # finished sentences should have their next token be a padding token
- if eos_token_id is not None:
- if pad_token_id is None:
- raise ValueError(
- "If `eos_token_id` is defined, make sure that `pad_token_id` is"
- " defined."
- )
- next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
- 1 - unfinished_sequences
- )
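- # Note: `unfinished_sequences` is a 0/1 mask, so the line above keeps the
- # predicted token for sequences that are still generating and writes
- # `pad_token_id` for sequences that have already produced an EOS token.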
-
- # update generated ids, model inputs, and length for next step
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
- model_kwargs = self._update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
- )
- cur_len = cur_len + 1
-
- # if eos_token was found in one sentence, set sentence to finished
- if eos_token_id is not None:
- unfinished_sequences = unfinished_sequences.mul(
- (next_tokens != eos_token_id).long()
- )
-
- # stop when each sentence is finished, or if we exceed the maximum length
- if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
- if not synced_gpus:
- break
- else:
- this_peer_finished = True
-
- if return_dict_in_generate:
- if self.config.is_encoder_decoder:
- return GreedySearchEncoderDecoderOutput(
- sequences=input_ids,
- scores=scores,
- encoder_attentions=encoder_attentions,
- encoder_hidden_states=encoder_hidden_states,
- decoder_attentions=decoder_attentions,
- cross_attentions=cross_attentions,
- decoder_hidden_states=decoder_hidden_states,
- )
- else:
- return GreedySearchDecoderOnlyOutput(
- sequences=input_ids,
- scores=scores,
- attentions=decoder_attentions,
- hidden_states=decoder_hidden_states,
- )
- else:
- return input_ids
-
- def sample(
- self,
- input_ids: torch.LongTensor,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- logits_warper: Optional[LogitsProcessorList] = None,
- max_length: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_scores: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- synced_gpus: Optional[bool] = False,
- **model_kwargs,
- ) -> Union[SampleOutput, torch.LongTensor]:
- r"""
- Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
- can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
- Parameters:
-
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- The sequence used as a prompt for the generation.
- logits_processor (`LogitsProcessorList`, *optional*):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
- used to modify the prediction scores of the language modeling head applied at each generation step.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
- An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
- used to tell if the generation loop should stop.
- logits_warper (`LogitsProcessorList`, *optional*):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
- to warp the prediction score distribution of the language modeling head applied before multinomial
- sampling at each generation step.
- max_length (`int`, *optional*, defaults to 20):
- **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
- tokens. The maximum length of the sequence to be generated.
- pad_token_id (`int`, *optional*):
- The id of the *padding* token.
- eos_token_id (`int`, *optional*):
- The id of the *end-of-sequence* token.
- output_attentions (`bool`, *optional*, defaults to `False`):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more details.
- output_hidden_states (`bool`, *optional*, defaults to `False`):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more details.
- output_scores (`bool`, *optional*, defaults to `False`):
- Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
- return_dict_in_generate (`bool`, *optional*, defaults to `False`):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- synced_gpus (`bool`, *optional*, defaults to `False`):
- Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
- model_kwargs:
- Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
- an encoder-decoder model the kwargs should include `encoder_outputs`.
-
- Return:
- [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or
- `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
- [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
- `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if
- `model.config.is_encoder_decoder=True`.
-
- Examples:
-
- ```python
- >>> from transformers import (
- ... AutoTokenizer,
- ... AutoModelForCausalLM,
- ... LogitsProcessorList,
- ... MinLengthLogitsProcessor,
- ... TopKLogitsWarper,
- ... TemperatureLogitsWarper,
- ... StoppingCriteriaList,
- ... MaxLengthCriteria,
- ... )
- >>> import torch
-
- >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
- >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
-
- >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
- >>> model.config.pad_token_id = model.config.eos_token_id
-
- >>> input_prompt = "Today is a beautiful day, and"
- >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
-
- >>> # instantiate logits processors
- >>> logits_processor = LogitsProcessorList(
- ... [
- ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
- ... ]
- ... )
- >>> # instantiate logits warpers
- >>> logits_warper = LogitsProcessorList(
- ... [
- ... TopKLogitsWarper(50),
- ... TemperatureLogitsWarper(0.7),
- ... ]
- ... )
-
- >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
-
- >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
- >>> outputs = model.sample(
- ... input_ids,
- ... logits_processor=logits_processor,
- ... logits_warper=logits_warper,
- ... stopping_criteria=stopping_criteria,
- ... )
-
- >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
- ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
- ```"""
-
- # init values
- logits_processor = (
- logits_processor if logits_processor is not None else LogitsProcessorList()
- )
- stopping_criteria = (
- stopping_criteria
- if stopping_criteria is not None
- else StoppingCriteriaList()
- )
- if max_length is not None:
- warnings.warn(
- (
- "`max_length` is deprecated in this function, use"
- " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`"
- " instead."
- ),
- UserWarning,
- )
- stopping_criteria = validate_stopping_criteria(
- stopping_criteria, max_length
- )
- logits_warper = (
- logits_warper if logits_warper is not None else LogitsProcessorList()
- )
- pad_token_id = (
- pad_token_id if pad_token_id is not None else self.config.pad_token_id
- )
- eos_token_id = (
- eos_token_id if eos_token_id is not None else self.config.eos_token_id
- )
- output_scores = (
- output_scores if output_scores is not None else self.config.output_scores
- )
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict_in_generate = (
- return_dict_in_generate
- if return_dict_in_generate is not None
- else self.config.return_dict_in_generate
- )
-
- # init attention / hidden states / scores tuples
- scores = () if (return_dict_in_generate and output_scores) else None
- decoder_attentions = (
- () if (return_dict_in_generate and output_attentions) else None
- )
- cross_attentions = (
- () if (return_dict_in_generate and output_attentions) else None
- )
- decoder_hidden_states = (
- () if (return_dict_in_generate and output_hidden_states) else None
- )
-
- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
- if return_dict_in_generate and self.config.is_encoder_decoder:
- encoder_attentions = (
- model_kwargs["encoder_outputs"].get("attentions")
- if output_attentions
- else None
- )
- encoder_hidden_states = (
- model_kwargs["encoder_outputs"].get("hidden_states")
- if output_hidden_states
- else None
- )
-
- # keep track of which sequences are already finished
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
- cur_len = input_ids.shape[-1]
-
- this_peer_finished = False # used by synced_gpus only
- # auto-regressive generation
- while True:
- if synced_gpus:
- # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
- # The following logic allows an early break if all peers finished generating their sequence
- this_peer_finished_flag = torch.tensor(
- 0.0 if this_peer_finished else 1.0
- ).to(input_ids.device)
- # send 0.0 if we finished, 1.0 otherwise
- dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- break
-
- # prepare model inputs
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
- # forward pass to get next token
- outputs = self(
- **model_inputs,
- return_dict=True,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- )
-
- if synced_gpus and this_peer_finished:
- cur_len = cur_len + 1
- continue # don't waste resources running the code we don't need
-
- next_token_logits_raw = outputs.logits[:, -1, :].clone()
- next_token_logits = outputs.logits[:, -1, :]
-
- # pre-process distribution
- next_token_scores = logits_processor(
- input_ids, next_token_logits, model_inputs=model_inputs
- )
- next_token_scores = logits_warper(input_ids, next_token_scores)
-
- # Store scores, attentions and hidden_states when required
- if return_dict_in_generate:
- if output_scores:
- scores += ((next_token_logits_raw, next_token_scores),)
- if output_attentions:
- decoder_attentions += (
- (outputs.decoder_attentions,)
- if self.config.is_encoder_decoder
- else (outputs.attentions,)
- )
- if self.config.is_encoder_decoder:
- cross_attentions += (outputs.cross_attentions,)
-
- if output_hidden_states:
- decoder_hidden_states += (
- (outputs.decoder_hidden_states,)
- if self.config.is_encoder_decoder
- else (outputs.hidden_states,)
- )
-
- # sample
- probs = nn.functional.softmax(next_token_scores, dim=-1)
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
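- # Note: unlike greedy search, one token per sequence is drawn from the
- # softmax of the warped scores (e.g. after top-k / top-p / temperature),
- # rather than taking the argmax.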
-
- # finished sentences should have their next token be a padding token
- if eos_token_id is not None:
- if pad_token_id is None:
- raise ValueError(
- "If `eos_token_id` is defined, make sure that `pad_token_id` is"
- " defined."
- )
- next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
- 1 - unfinished_sequences
- )
-
- # update generated ids, model inputs, and length for next step
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
- model_kwargs = self._update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
- )
- cur_len = cur_len + 1
-
- # if eos_token was found in one sentence, set sentence to finished
- if eos_token_id is not None:
- unfinished_sequences = unfinished_sequences.mul(
- (next_tokens != eos_token_id).long()
- )
-
- # stop when each sentence is finished, or if we exceed the maximum length
- if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
- if not synced_gpus:
- break
- else:
- this_peer_finished = True
-
- if return_dict_in_generate:
- if self.config.is_encoder_decoder:
- return SampleEncoderDecoderOutput(
- sequences=input_ids,
- scores=scores,
- encoder_attentions=encoder_attentions,
- encoder_hidden_states=encoder_hidden_states,
- decoder_attentions=decoder_attentions,
- cross_attentions=cross_attentions,
- decoder_hidden_states=decoder_hidden_states,
- )
- else:
- return SampleDecoderOnlyOutput(
- sequences=input_ids,
- scores=scores,
- attentions=decoder_attentions,
- hidden_states=decoder_hidden_states,
- )
- else:
- return input_ids
-
- def beam_search(
- self,
- input_ids: torch.LongTensor,
- beam_scorer: BeamScorer,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- max_length: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_scores: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- synced_gpus: Optional[bool] = False,
- **model_kwargs,
- ) -> Union[BeamSearchOutput, torch.LongTensor]:
- r"""
- Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
- can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
- Parameters:
-
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- The sequence used as a prompt for the generation.
- beam_scorer (`BeamScorer`):
- A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
- sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
- logits_processor (`LogitsProcessorList`, *optional*):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
- used to modify the prediction scores of the language modeling head applied at each generation step.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
- An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
- used to tell if the generation loop should stop.
- max_length (`int`, *optional*, defaults to 20):
- **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
- tokens. The maximum length of the sequence to be generated.
- pad_token_id (`int`, *optional*):
- The id of the *padding* token.
- eos_token_id (`int`, *optional*):
- The id of the *end-of-sequence* token.
- output_attentions (`bool`, *optional*, defaults to `False`):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more details.
- output_hidden_states (`bool`, *optional*, defaults to `False`):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more details.
- output_scores (`bool`, *optional*, defaults to `False`):
- Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
- return_dict_in_generate (`bool`, *optional*, defaults to `False`):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- synced_gpus (`bool`, *optional*, defaults to `False`):
- Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
- model_kwargs:
- Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
- an encoder-decoder model the kwargs should include `encoder_outputs`.
-
- Return:
- [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
- `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
- [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
- `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
- `model.config.is_encoder_decoder=True`.
-
-
- Examples:
-
- ```python
- >>> from transformers import (
- ... AutoTokenizer,
- ... AutoModelForSeq2SeqLM,
- ... LogitsProcessorList,
- ... MinLengthLogitsProcessor,
- ... BeamSearchScorer,
- ... )
- >>> import torch
-
- >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
- >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
-
- >>> encoder_input_str = "translate English to German: How old are you?"
- >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
-
-
- >>> # let's run beam search using 3 beams
- >>> num_beams = 3
- >>> # define decoder start token ids
- >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
- >>> input_ids = input_ids * model.config.decoder_start_token_id
-
- >>> # add encoder_outputs to model keyword arguments
- >>> model_kwargs = {
- ... "encoder_outputs": model.get_encoder()(
- ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
- ... )
- ... }
-
- >>> # instantiate beam scorer
- >>> beam_scorer = BeamSearchScorer(
- ... batch_size=1,
- ... num_beams=num_beams,
- ... device=model.device,
- ... )
-
- >>> # instantiate logits processors
- >>> logits_processor = LogitsProcessorList(
- ... [
- ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
- ... ]
- ... )
-
- >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
-
- >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
- ['Wie alt bist du?']
- ```"""
- # init values
- logits_processor = (
- logits_processor if logits_processor is not None else LogitsProcessorList()
- )
- stopping_criteria = (
- stopping_criteria
- if stopping_criteria is not None
- else StoppingCriteriaList()
- )
- if max_length is not None:
- warnings.warn(
- (
- "`max_length` is deprecated in this function, use"
- " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`"
- " instead."
- ),
- UserWarning,
- )
- stopping_criteria = validate_stopping_criteria(
- stopping_criteria, max_length
- )
- if len(stopping_criteria) == 0:
- warnings.warn(
- (
- "You don't have defined any stopping_criteria, this will likely"
- " loop forever"
- ),
- UserWarning,
- )
- pad_token_id = (
- pad_token_id if pad_token_id is not None else self.config.pad_token_id
- )
- eos_token_id = (
- eos_token_id if eos_token_id is not None else self.config.eos_token_id
- )
- output_scores = (
- output_scores if output_scores is not None else self.config.output_scores
- )
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict_in_generate = (
- return_dict_in_generate
- if return_dict_in_generate is not None
- else self.config.return_dict_in_generate
- )
-
- batch_size = len(beam_scorer._beam_hyps)
- num_beams = beam_scorer.num_beams
-
- batch_beam_size, cur_len = input_ids.shape
-
- if num_beams * batch_size != batch_beam_size:
- raise ValueError(
- f"Batch dimension of `input_ids` should be {num_beams * batch_size},"
- f" but is {batch_beam_size}."
- )
-
- # init attention / hidden states / scores tuples
- scores = () if (return_dict_in_generate and output_scores) else None
- beam_indices = (
- tuple(() for _ in range(batch_beam_size))
- if (return_dict_in_generate and output_scores)
- else None
- )
- decoder_attentions = (
- () if (return_dict_in_generate and output_attentions) else None
- )
- cross_attentions = (
- () if (return_dict_in_generate and output_attentions) else None
- )
- decoder_hidden_states = (
- () if (return_dict_in_generate and output_hidden_states) else None
- )
-
- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
- if return_dict_in_generate and self.config.is_encoder_decoder:
- encoder_attentions = (
- model_kwargs["encoder_outputs"].get("attentions")
- if output_attentions
- else None
- )
- encoder_hidden_states = (
- model_kwargs["encoder_outputs"].get("hidden_states")
- if output_hidden_states
- else None
- )
-
- beam_scores = torch.zeros(
- (batch_size, num_beams), dtype=torch.float, device=input_ids.device
- )
- beam_scores[:, 1:] = -1e9
- beam_scores = beam_scores.view((batch_size * num_beams,))
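- # Note: every beam except the first starts at -1e9 so that the first
- # decoding step effectively expands only beam 0; otherwise all beams
- # would propose the same top tokens.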
-
- this_peer_finished = False # used by synced_gpus only
- while True:
- if synced_gpus:
- # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
- # The following logic allows an early break if all peers finished generating their sequence
- this_peer_finished_flag = torch.tensor(
- 0.0 if this_peer_finished else 1.0
- ).to(input_ids.device)
- # send 0.0 if we finished, 1.0 otherwise
- dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- break
-
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
- outputs = self(
- **model_inputs,
- return_dict=True,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- )
-
- if synced_gpus and this_peer_finished:
- cur_len = cur_len + 1
- continue # don't waste resources running the code we don't need
-
- next_token_logits = outputs.logits[:, -1, :]
- next_token_logits_raw = next_token_logits.clone()
-
- # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
- # cannot be generated both before and after the `nn.functional.log_softmax` operation.
- next_token_logits = self.adjust_logits_during_generation(
- next_token_logits, cur_len=cur_len
- )
- next_token_scores = nn.functional.log_softmax(
- next_token_logits, dim=-1
- ) # (batch_size * num_beams, vocab_size)
-
- next_token_scores_processed = logits_processor(
- input_ids, next_token_scores, model_inputs=model_inputs
- )
- next_token_scores = next_token_scores_processed + beam_scores[
- :, None
- ].expand_as(next_token_scores)
-
- # Store scores, attentions and hidden_states when required
- if return_dict_in_generate:
- if output_scores:
- scores += (next_token_logits_raw,)
- if output_attentions:
- decoder_attentions += (
- (outputs.decoder_attentions,)
- if self.config.is_encoder_decoder
- else (outputs.attentions,)
- )
- if self.config.is_encoder_decoder:
- cross_attentions += (outputs.cross_attentions,)
-
- if output_hidden_states:
- decoder_hidden_states += (
- (outputs.decoder_hidden_states,)
- if self.config.is_encoder_decoder
- else (outputs.hidden_states,)
- )
-
- # reshape for beam search
- vocab_size = next_token_scores.shape[-1]
- next_token_scores = next_token_scores.view(
- batch_size, num_beams * vocab_size
- )
-
- next_token_scores, next_tokens = torch.topk(
- next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
- )
-
- next_indices = torch_int_div(next_tokens, vocab_size)
- next_tokens = next_tokens % vocab_size
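- # Note: scores were flattened to (batch_size, num_beams * vocab_size), so
- # each top-k index encodes both the originating beam (index // vocab_size)
- # and the token id (index % vocab_size); 2 * num_beams candidates are kept
- # so enough open hypotheses remain after EOS-ending ones are set aside.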
-
- # stateless
- beam_outputs = beam_scorer.process(
- input_ids,
- next_token_scores,
- next_tokens,
- next_indices,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- )
-
- beam_scores = beam_outputs["next_beam_scores"]
- beam_next_tokens = beam_outputs["next_beam_tokens"]
- beam_idx = beam_outputs["next_beam_indices"]
-
- input_ids = torch.cat(
- [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
- )
-
- model_kwargs = self._update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
- )
- if model_kwargs["past"] is not None:
- model_kwargs["past"] = self._reorder_cache(
- model_kwargs["past"], beam_idx
- )
-
- if return_dict_in_generate and output_scores:
- beam_indices = tuple(
- (
- beam_indices[beam_idx[i]] + (beam_idx[i],)
- for i in range(len(beam_indices))
- )
- )
-
- # increase cur_len
- cur_len = cur_len + 1
-
- if beam_scorer.is_done or stopping_criteria(input_ids, scores):
- if not synced_gpus:
- break
- else:
- this_peer_finished = True
-
- sequence_outputs = beam_scorer.finalize(
- input_ids,
- beam_scores,
- next_tokens,
- next_indices,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- max_length=stopping_criteria.max_length,
- )
-
- if return_dict_in_generate:
- if not output_scores:
- sequence_outputs["sequence_scores"] = None
- else:
- num_return_sequences = beam_scorer.num_beam_hyps_to_keep
- # return only as many indices as sequences
- beam_indices = tuple(
- (
- beam_indices[
- i * num_beams : i * num_beams + num_return_sequences
- ]
- for i in range(batch_size)
- )
- )
- beam_indices = sum(beam_indices, ())
-
- step_wise_raw_logits = self.compute_beam_search_raw_logits(
- sequence_outputs["sequences"].clone(),
- scores,
- beam_indices,
- eos_token_id,
- )
-
- if self.config.is_encoder_decoder:
- return BeamSearchEncoderDecoderOutput(
- sequences=sequence_outputs["sequences"],
- sequences_scores=sequence_outputs["sequence_scores"],
- scores=step_wise_raw_logits, # raw logits
- beam_indices=beam_indices,
- encoder_attentions=encoder_attentions,
- encoder_hidden_states=encoder_hidden_states,
- decoder_attentions=decoder_attentions,
- cross_attentions=cross_attentions,
- decoder_hidden_states=decoder_hidden_states,
- )
- else:
- return BeamSearchDecoderOnlyOutput(
- sequences=sequence_outputs["sequences"],
- sequences_scores=sequence_outputs["sequence_scores"],
- scores=scores,
- beam_indices=beam_indices,
- attentions=decoder_attentions,
- hidden_states=decoder_hidden_states,
- )
- else:
- return sequence_outputs["sequences"]
-
- def beam_sample(
- self,
- input_ids: torch.LongTensor,
- beam_scorer: BeamScorer,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- logits_warper: Optional[LogitsProcessorList] = None,
- max_length: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_scores: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- synced_gpus: Optional[bool] = False,
- **model_kwargs,
- ) -> Union[BeamSampleOutput, torch.LongTensor]:
- r"""
- Generates sequences of token ids for models with a language modeling head using **beam search multinomial
- sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
- Parameters:
-
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- The sequence used as a prompt for the generation.
- beam_scorer (`BeamScorer`):
- A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
- sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
- logits_processor (`LogitsProcessorList`, *optional*):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
- used to modify the prediction scores of the language modeling head applied at each generation step.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
- An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
- used to tell if the generation loop should stop.
- logits_warper (`LogitsProcessorList`, *optional*):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
- to warp the prediction score distribution of the language modeling head applied before multinomial
- sampling at each generation step.
- max_length (`int`, *optional*, defaults to 20):
- **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
- tokens. The maximum length of the sequence to be generated.
- pad_token_id (`int`, *optional*):
- The id of the *padding* token.
- eos_token_id (`int`, *optional*):
- The id of the *end-of-sequence* token.
- output_attentions (`bool`, *optional*, defaults to `False`):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more details.
- output_hidden_states (`bool`, *optional*, defaults to `False`):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more details.
- output_scores (`bool`, *optional*, defaults to `False`):
- Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
- return_dict_in_generate (`bool`, *optional*, defaults to `False`):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- synced_gpus (`bool`, *optional*, defaults to `False`):
- Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
- model_kwargs:
- Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
- an encoder-decoder model the kwargs should include `encoder_outputs`.
-
- Return:
- [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or
- `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
- [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
- `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if
- `model.config.is_encoder_decoder=True`.
-
- Examples:
-
- ```python
- >>> from transformers import (
- ... AutoTokenizer,
- ... AutoModelForSeq2SeqLM,
- ... LogitsProcessorList,
- ... MinLengthLogitsProcessor,
- ... TopKLogitsWarper,
- ... TemperatureLogitsWarper,
- ... BeamSearchScorer,
- ... )
- >>> import torch
-
- >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
- >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
-
- >>> encoder_input_str = "translate English to German: How old are you?"
- >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
-
- >>> # let's run beam search using 3 beams
- >>> num_beams = 3
- >>> # define decoder start token ids
- >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
- >>> input_ids = input_ids * model.config.decoder_start_token_id
-
- >>> # add encoder_outputs to model keyword arguments
- >>> model_kwargs = {
- ... "encoder_outputs": model.get_encoder()(
- ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
- ... )
- ... }
-
- >>> # instantiate beam scorer
- >>> beam_scorer = BeamSearchScorer(
- ... batch_size=1,
- ... max_length=model.config.max_length,
- ... num_beams=num_beams,
- ... device=model.device,
- ... )
-
- >>> # instantiate logits processors
- >>> logits_processor = LogitsProcessorList(
- ... [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
- ... )
- >>> # instantiate logits warpers
- >>> logits_warper = LogitsProcessorList(
- ... [
- ... TopKLogitsWarper(50),
- ... TemperatureLogitsWarper(0.7),
- ... ]
- ... )
-
- >>> outputs = model.beam_sample(
- ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
- ... )
-
- >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
- ['Wie alt bist du?']
- ```"""
- # init values
- logits_processor = (
- logits_processor if logits_processor is not None else LogitsProcessorList()
- )
- stopping_criteria = (
- stopping_criteria
- if stopping_criteria is not None
- else StoppingCriteriaList()
- )
- if max_length is not None:
- warnings.warn(
- (
- "`max_length` is deprecated in this function, use"
- " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`"
- " instead."
- ),
- UserWarning,
- )
- stopping_criteria = validate_stopping_criteria(
- stopping_criteria, max_length
- )
- pad_token_id = (
- pad_token_id if pad_token_id is not None else self.config.pad_token_id
- )
- eos_token_id = (
- eos_token_id if eos_token_id is not None else self.config.eos_token_id
- )
- output_scores = (
- output_scores if output_scores is not None else self.config.output_scores
- )
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict_in_generate = (
- return_dict_in_generate
- if return_dict_in_generate is not None
- else self.config.return_dict_in_generate
- )
-
- batch_size = len(beam_scorer._beam_hyps)
- num_beams = beam_scorer.num_beams
-
- batch_beam_size, cur_len = input_ids.shape
-
- # init attention / hidden states / scores tuples
- scores = () if (return_dict_in_generate and output_scores) else None
- beam_indices = (
- tuple(() for _ in range(batch_beam_size))
- if (return_dict_in_generate and output_scores)
- else None
- )
- decoder_attentions = (
- () if (return_dict_in_generate and output_attentions) else None
- )
- cross_attentions = (
- () if (return_dict_in_generate and output_attentions) else None
- )
- decoder_hidden_states = (
- () if (return_dict_in_generate and output_hidden_states) else None
- )
-
- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
- if return_dict_in_generate and self.config.is_encoder_decoder:
- encoder_attentions = (
- model_kwargs["encoder_outputs"].get("attentions")
- if output_attentions
- else None
- )
- encoder_hidden_states = (
- model_kwargs["encoder_outputs"].get("hidden_states")
- if output_hidden_states
- else None
- )
-
- beam_scores = torch.zeros(
- (batch_size, num_beams), dtype=torch.float, device=input_ids.device
- )
- beam_scores = beam_scores.view((batch_size * num_beams,))
-
- this_peer_finished = False # used by synced_gpus only
- while True:
- if synced_gpus:
- # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
- # The following logic allows an early break if all peers finished generating their sequence
- this_peer_finished_flag = torch.tensor(
- 0.0 if this_peer_finished else 1.0
- ).to(input_ids.device)
- # send 0.0 if we finished, 1.0 otherwise
- dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- break
-
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
- outputs = self(
- **model_inputs,
- return_dict=True,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- )
-
- if synced_gpus and this_peer_finished:
- cur_len = cur_len + 1
- continue # don't waste resources running the code we don't need
-
- next_token_logits_raw = outputs.logits[:, -1, :]
-
- # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
- # cannot be generated both before and after the `nn.functional.log_softmax` operation.
- next_token_logits = self.adjust_logits_during_generation(
- next_token_logits_raw, cur_len=cur_len
- )
- next_token_scores = nn.functional.log_softmax(
- next_token_logits, dim=-1
- ) # (batch_size * num_beams, vocab_size)
-
- next_token_scores_processed = logits_processor(
- input_ids, next_token_logits, model_inputs=model_inputs
- )
- next_token_scores = next_token_scores_processed + beam_scores[
- :, None
- ].expand_as(next_token_scores)
- next_token_scores = logits_warper(input_ids, next_token_scores)
-
- # Store scores, attentions and hidden_states when required
- if return_dict_in_generate:
- if output_scores:
- # store both the raw logits and the post-processed scores
- scores += ((next_token_logits_raw, next_token_scores),)
- if output_attentions:
- decoder_attentions += (
- (outputs.decoder_attentions,)
- if self.config.is_encoder_decoder
- else (outputs.attentions,)
- )
- if self.config.is_encoder_decoder:
- cross_attentions += (outputs.cross_attentions,)
-
- if output_hidden_states:
- decoder_hidden_states += (
- (outputs.decoder_hidden_states,)
- if self.config.is_encoder_decoder
- else (outputs.hidden_states,)
- )
-
- # reshape for beam search
- vocab_size = next_token_scores.shape[-1]
- next_token_scores = next_token_scores.view(
- batch_size, num_beams * vocab_size
- )
-
- probs = nn.functional.softmax(next_token_scores, dim=-1)
-
- next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
- next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
-
- next_token_scores, _indices = torch.sort(
- next_token_scores, descending=True, dim=1
- )
- next_tokens = torch.gather(next_tokens, -1, _indices)
-
- next_indices = torch_int_div(next_tokens, vocab_size)
- next_tokens = next_tokens % vocab_size
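- # Note: beam sampling draws 2 * num_beams candidates per batch entry from
- # the flattened distribution, sorts them by score, and then recovers the
- # originating beam and token id in the same way as beam search.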
-
- # stateless
- beam_outputs = beam_scorer.process(
- input_ids,
- next_token_scores,
- next_tokens,
- next_indices,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- )
- beam_scores = beam_outputs["next_beam_scores"]
- beam_next_tokens = beam_outputs["next_beam_tokens"]
- beam_idx = beam_outputs["next_beam_indices"]
-
- input_ids = torch.cat(
- [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
- )
-
- model_kwargs = self._update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
- )
- if model_kwargs["past"] is not None:
- model_kwargs["past"] = self._reorder_cache(
- model_kwargs["past"], beam_idx
- )
-
- if return_dict_in_generate and output_scores:
- beam_indices = tuple(
- (
- beam_indices[beam_idx[i]] + (beam_idx[i],)
- for i in range(len(beam_indices))
- )
- )
-
- # increase cur_len
- cur_len = cur_len + 1
-
- if beam_scorer.is_done or stopping_criteria(input_ids, scores):
- if not synced_gpus:
- break
- else:
- this_peer_finished = True
-
- sequence_outputs = beam_scorer.finalize(
- input_ids,
- beam_scores,
- next_tokens,
- next_indices,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- max_length=stopping_criteria.max_length,
- )
-
- if return_dict_in_generate:
- if not output_scores:
- sequence_outputs["sequence_scores"] = None
- else:
- num_return_sequences = beam_scorer.num_beam_hyps_to_keep
- # return only as many indices as sequences
- beam_indices = tuple(
- (
- beam_indices[
- i * num_beams : i * num_beams + num_return_sequences
- ]
- for i in range(batch_size)
- )
- )
- beam_indices = sum(beam_indices, ())
-
- if self.config.is_encoder_decoder:
- return BeamSampleEncoderDecoderOutput(
- sequences=sequence_outputs["sequences"],
- sequences_scores=sequence_outputs["sequence_scores"],
- scores=scores,
- beam_indices=beam_indices,
- encoder_attentions=encoder_attentions,
- encoder_hidden_states=encoder_hidden_states,
- decoder_attentions=decoder_attentions,
- cross_attentions=cross_attentions,
- decoder_hidden_states=decoder_hidden_states,
- )
- else:
- return BeamSampleDecoderOnlyOutput(
- sequences=sequence_outputs["sequences"],
- sequences_scores=sequence_outputs["sequence_scores"],
- scores=scores,
- beam_indices=beam_indices,
- attentions=decoder_attentions,
- hidden_states=decoder_hidden_states,
- )
- else:
- return sequence_outputs["sequences"]
-
- def group_beam_search(
- self,
- input_ids: torch.LongTensor,
- beam_scorer: BeamScorer,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- max_length: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_scores: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- synced_gpus: Optional[bool] = False,
- **model_kwargs,
- ):
- r"""
- Generates sequences of token ids for models with a language modeling head using **diverse beam search
- decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
- Parameters:
-
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- The sequence used as a prompt for the generation.
- beam_scorer (`BeamScorer`):
- A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
- sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
- logits_processor (`LogitsProcessorList`, *optional*):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
- used to modify the prediction scores of the language modeling head applied at each generation step.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
- An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
- used to tell if the generation loop should stop.
- max_length (`int`, *optional*, defaults to 20):
- **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
- tokens. The maximum length of the sequence to be generated.
- pad_token_id (`int`, *optional*):
- The id of the *padding* token.
- eos_token_id (`int`, *optional*):
- The id of the *end-of-sequence* token.
- output_attentions (`bool`, *optional*, defaults to `False`):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more details.
- output_hidden_states (`bool`, *optional*, defaults to `False`):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more details.
- output_scores (`bool`, *optional*, defaults to `False`):
- Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
- return_dict_in_generate (`bool`, *optional*, defaults to `False`):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- synced_gpus (`bool`, *optional*, defaults to `False`):
- Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-
- model_kwargs:
- Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
- model is an encoder-decoder model the kwargs should include `encoder_outputs`.
-
- Return:
- [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
- `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
- [`~generation_utils.BeamSearchDecoderOnlyOutput`] if
- `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a
- [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`.
-
- Examples:
-
- ```python
- >>> from transformers import (
- ... AutoTokenizer,
- ... AutoModelForSeq2SeqLM,
- ... LogitsProcessorList,
- ... MinLengthLogitsProcessor,
- ... HammingDiversityLogitsProcessor,
- ... BeamSearchScorer,
- ... )
- >>> import torch
-
- >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
- >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
-
- >>> encoder_input_str = "translate English to German: How old are you?"
- >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
-
-
- >>> # let's run diverse beam search using 6 beams
- >>> num_beams = 6
- >>> # define decoder start token ids
- >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
- >>> input_ids = input_ids * model.config.decoder_start_token_id
-
- >>> # add encoder_outputs to model keyword arguments
- >>> model_kwargs = {
- ... "encoder_outputs": model.get_encoder()(
- ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
- ... )
- ... }
-
- >>> # instantiate beam scorer
- >>> beam_scorer = BeamSearchScorer(
- ... batch_size=1,
- ... max_length=model.config.max_length,
- ... num_beams=num_beams,
- ... device=model.device,
- ... num_beam_groups=3,
- ... )
-
- >>> # instantiate logits processors
- >>> logits_processor = LogitsProcessorList(
- ... [
- ... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
- ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
- ... ]
- ... )
-
- >>> outputs = model.group_beam_search(
- ... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
- ... )
-
- >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
- ['Wie alt bist du?']
- ```"""
- # init values
- logits_processor = (
- logits_processor if logits_processor is not None else LogitsProcessorList()
- )
- stopping_criteria = (
- stopping_criteria
- if stopping_criteria is not None
- else StoppingCriteriaList()
- )
- if max_length is not None:
- warnings.warn(
- (
- "`max_length` is deprecated in this function, use"
- " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`"
- " instead."
- ),
- UserWarning,
- )
- stopping_criteria = validate_stopping_criteria(
- stopping_criteria, max_length
- )
- pad_token_id = (
- pad_token_id if pad_token_id is not None else self.config.pad_token_id
- )
- eos_token_id = (
- eos_token_id if eos_token_id is not None else self.config.eos_token_id
- )
- output_scores = (
- output_scores if output_scores is not None else self.config.output_scores
- )
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict_in_generate = (
- return_dict_in_generate
- if return_dict_in_generate is not None
- else self.config.return_dict_in_generate
- )
-
- batch_size = len(beam_scorer._beam_hyps)
- num_beams = beam_scorer.num_beams
- num_beam_groups = beam_scorer.num_beam_groups
- num_sub_beams = num_beams // num_beam_groups
- device = input_ids.device
-
- batch_beam_size, cur_len = input_ids.shape
-
- if return_dict_in_generate and output_scores:
- beam_indices = [
- tuple(() for _ in range(num_sub_beams * batch_size))
- for _ in range(num_beam_groups)
- ]
- else:
- beam_indices = None
-
- if num_beams * batch_size != batch_beam_size:
- raise ValueError(
- f"Batch dimension of `input_ids` should be {num_beams * batch_size},"
- f" but is {batch_beam_size}."
- )
-
- # init attention / hidden states / scores tuples
- scores = () if (return_dict_in_generate and output_scores) else None
- decoder_attentions = (
- () if (return_dict_in_generate and output_attentions) else None
- )
- cross_attentions = (
- () if (return_dict_in_generate and output_attentions) else None
- )
- decoder_hidden_states = (
- () if (return_dict_in_generate and output_hidden_states) else None
- )
-
- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
- if return_dict_in_generate and self.config.is_encoder_decoder:
- encoder_attentions = (
- model_kwargs["encoder_outputs"].get("attentions")
- if output_attentions
- else None
- )
- encoder_hidden_states = (
- model_kwargs["encoder_outputs"].get("hidden_states")
- if output_hidden_states
- else None
- )
-
- beam_scores = torch.full(
- (batch_size, num_beams), -1e9, dtype=torch.float, device=device
- )
- # initialise the score of the first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
- # the same group don't produce the same tokens every time.
- beam_scores[:, ::num_sub_beams] = 0
- beam_scores = beam_scores.view((batch_size * num_beams,))
-
- this_peer_finished = False # used by synced_gpus only
- while True:
- if synced_gpus:
- # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
- # The following logic allows an early break if all peers finished generating their sequence
- this_peer_finished_flag = torch.tensor(
- 0.0 if this_peer_finished else 1.0
- ).to(input_ids.device)
- # send 0.0 if we finished, 1.0 otherwise
- dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- break
-
- # predicted tokens in cur_len step
- current_tokens = torch.zeros(
- batch_size * num_beams, dtype=input_ids.dtype, device=device
- )
-
- # indices which will form the beams in the next time step
- reordering_indices = torch.zeros(
- batch_size * num_beams, dtype=torch.long, device=device
- )
-
- # do one decoder step on all beams of all sentences in batch
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
- outputs = self(
- **model_inputs,
- return_dict=True,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- )
-
- if synced_gpus and this_peer_finished:
- cur_len = cur_len + 1
- continue # don't waste resources running the code we don't need
-
- if output_scores:
- processed_score = torch.zeros_like(outputs.logits[:, -1, :])
-
- for beam_group_idx in range(num_beam_groups):
- group_start_idx = beam_group_idx * num_sub_beams
- group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
- group_size = group_end_idx - group_start_idx
-
- # indices of beams of current group among all sentences in batch
- batch_group_indices = []
-
- for batch_idx in range(batch_size):
- batch_group_indices.extend(
- [
- batch_idx * num_beams + idx
- for idx in range(group_start_idx, group_end_idx)
- ]
- )
- group_input_ids = input_ids[batch_group_indices]
-
- # select outputs of beams of current group only
- next_token_logits_raw = outputs.logits[batch_group_indices, -1, :]
-
- # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
- # cannot be generated both before and after the `nn.functional.log_softmax` operation.
- next_token_logits = self.adjust_logits_during_generation(
- next_token_logits_raw, cur_len=cur_len
- )
- next_token_scores = nn.functional.log_softmax(
- next_token_logits, dim=-1
- ) # (batch_size * group_size, vocab_size)
- vocab_size = next_token_scores.shape[-1]
-
- next_token_scores_processed = logits_processor(
- group_input_ids,
- next_token_scores,
- current_tokens=current_tokens,
- beam_group_idx=beam_group_idx,
- model_inputs=model_inputs,
- )
- next_token_scores = next_token_scores_processed + beam_scores[
- batch_group_indices
- ].unsqueeze(-1)
- next_token_scores = next_token_scores.expand_as(
- next_token_scores_processed
- )
-
- if output_scores:
- processed_score[batch_group_indices] = next_token_logits_raw
-
- # reshape for beam search
- next_token_scores = next_token_scores.view(
- batch_size, group_size * vocab_size
- )
-
- next_token_scores, next_tokens = torch.topk(
- next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
- )
-
- next_indices = torch_int_div(next_tokens, vocab_size)
- next_tokens = next_tokens % vocab_size
-
- # stateless
- beam_outputs = beam_scorer.process(
- group_input_ids,
- next_token_scores,
- next_tokens,
- next_indices,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- )
- beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
- beam_next_tokens = beam_outputs["next_beam_tokens"]
- beam_idx = beam_outputs["next_beam_indices"]
-
- if return_dict_in_generate and output_scores:
- beam_indices[beam_group_idx] = tuple(
- beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],)
- for i in range(len(beam_indices[0]))
- )
-
- input_ids[batch_group_indices] = group_input_ids[beam_idx]
- group_input_ids = torch.cat(
- [group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)],
- dim=-1,
- )
- current_tokens[batch_group_indices] = group_input_ids[:, -1]
-
- # (beam_idx // group_size) -> batch_idx
- # (beam_idx % group_size) -> offset of idx inside the group
- reordering_indices[batch_group_indices] = (
- num_beams * torch_int_div(beam_idx, group_size)
- + group_start_idx
- + (beam_idx % group_size)
- )
-
- # Store scores, attentions and hidden_states when required
- if return_dict_in_generate:
- if output_scores:
- scores += (processed_score,)
- if output_attentions:
- decoder_attentions += (
- (outputs.decoder_attentions,)
- if self.config.is_encoder_decoder
- else (outputs.attentions,)
- )
- if self.config.is_encoder_decoder:
- cross_attentions += (outputs.cross_attentions,)
-
- if output_hidden_states:
- decoder_hidden_states += (
- (outputs.decoder_hidden_states,)
- if self.config.is_encoder_decoder
- else (outputs.hidden_states,)
- )
-
- input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
-
- model_kwargs = self._update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
- )
- if model_kwargs["past"] is not None:
- model_kwargs["past"] = self._reorder_cache(
- model_kwargs["past"], reordering_indices
- )
-
- # increase cur_len
- cur_len = cur_len + 1
-
- if beam_scorer.is_done or stopping_criteria(input_ids, scores):
- if not synced_gpus:
- break
- else:
- this_peer_finished = True
-
- sequence_outputs = beam_scorer.finalize(
- input_ids,
- beam_scores,
- next_tokens,
- next_indices,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- max_length=stopping_criteria.max_length,
- )
-
- if return_dict_in_generate:
- if not output_scores:
- sequence_outputs["sequence_scores"] = None
- else:
- beam_indices = sum(beam_indices, ())
- num_return_sequences = beam_scorer.num_beam_hyps_to_keep
- # return only as many indices as sequences
- beam_indices = tuple(
- (
- beam_indices[
- i * num_beams : i * num_beams + num_return_sequences
- ]
- for i in range(batch_size)
- )
- )
- beam_indices = sum(beam_indices, ())
-
- if self.config.is_encoder_decoder:
- return BeamSearchEncoderDecoderOutput(
- sequences=sequence_outputs["sequences"],
- sequences_scores=sequence_outputs["sequence_scores"],
- scores=scores,
- beam_indices=beam_indices,
- encoder_attentions=encoder_attentions,
- encoder_hidden_states=encoder_hidden_states,
- decoder_attentions=decoder_attentions,
- cross_attentions=cross_attentions,
- decoder_hidden_states=decoder_hidden_states,
- )
- else:
- return BeamSearchDecoderOnlyOutput(
- sequences=sequence_outputs["sequences"],
- sequences_scores=sequence_outputs["sequence_scores"],
- scores=scores,
- attentions=decoder_attentions,
- hidden_states=decoder_hidden_states,
- )
- else:
- return sequence_outputs["sequences"]
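`group_beam_search` is rarely called directly; the same diverse beam search behaviour is reachable through the high-level `generate()` API by splitting the beams into groups and adding a diversity penalty. A hedged sketch (model name and penalty value are illustrative):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")

# 6 beams split into 3 groups with a Hamming diversity penalty -> diverse beam search.
outputs = model.generate(
    **inputs,
    num_beams=6,
    num_beam_groups=3,
    diversity_penalty=5.5,
    num_return_sequences=3,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```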
-
- def constrained_beam_search(
- self,
- input_ids: torch.LongTensor,
- constrained_beam_scorer: ConstrainedBeamSearchScorer,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- max_length: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_scores: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- synced_gpus: Optional[bool] = None,
- **model_kwargs,
- ) -> Union[BeamSearchOutput, torch.LongTensor]:
- r"""
- Generates sequences of token ids for models with a language modeling head using **constrained beam search
- decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
- Parameters:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- The sequence used as a prompt for the generation.
- constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
- A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
- sorted during generation, while satisfying a list of positive constraints. For more information, the
- documentation of [`ConstrainedBeamSearchScorer`] should be read.
- logits_processor (`LogitsProcessorList`, *optional*):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
- used to modify the prediction scores of the language modeling head applied at each generation step.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
- An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
- used to tell if the generation loop should stop.
- logits_warper (`LogitsProcessorList`, *optional*):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
- to warp the prediction score distribution of the language modeling head applied before multinomial
- sampling at each generation step.
- max_length (`int`, *optional*, defaults to 20):
- **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
- tokens. The maximum length of the sequence to be generated.
- pad_token_id (`int`, *optional*):
- The id of the *padding* token.
- eos_token_id (`int`, *optional*):
- The id of the *end-of-sequence* token.
- output_attentions (`bool`, *optional*, defaults to `False`):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more details.
- output_hidden_states (`bool`, *optional*, defaults to `False`):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more details.
- output_scores (`bool`, *optional*, defaults to `False`):
- Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
- return_dict_in_generate (`bool`, *optional*, defaults to `False`):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- synced_gpus (`bool`, *optional*, defaults to `False`):
- Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
- model_kwargs:
- Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
- an encoder-decoder model the kwargs should include `encoder_outputs`.
-
- Return:
- [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
- `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
- [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
- `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
- `model.config.is_encoder_decoder=True`.
-
-
- Examples:
-
- ```python
- >>> from transformers import (
- ... AutoTokenizer,
- ... AutoModelForSeq2SeqLM,
- ... LogitsProcessorList,
- ... MinLengthLogitsProcessor,
- ... ConstrainedBeamSearchScorer,
- ... PhrasalConstraint,
- ... )
- >>> import torch
-
- >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
- >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
-
- >>> encoder_input_str = "translate English to German: How old are you?"
- >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
-
-
- >>> # let's run beam search using 3 beams
- >>> num_beams = 3
- >>> # define decoder start token ids
- >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
- >>> input_ids = input_ids * model.config.decoder_start_token_id
-
- >>> # add encoder_outputs to model keyword arguments
- >>> model_kwargs = {
- ... "encoder_outputs": model.get_encoder()(
- ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
- ... )
- ... }
-
- >>> constraint_str = "Sie"
- >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # slice to remove eos token
- >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]
-
-
- >>> # instantiate beam scorer
- >>> beam_scorer = ConstrainedBeamSearchScorer(
- ... batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
- ... )
-
- >>> # instantiate logits processors
- >>> logits_processor = LogitsProcessorList(
- ... [
- ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
- ... ]
- ... )
-
- >>> outputs = model.constrained_beam_search(
- ... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
- ... )
-
- >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
- ['Wie alt sind Sie?']
- ```"""
- # init values
- logits_processor = (
- logits_processor if logits_processor is not None else LogitsProcessorList()
- )
- stopping_criteria = (
- stopping_criteria
- if stopping_criteria is not None
- else StoppingCriteriaList()
- )
- if max_length is not None:
- warnings.warn(
- (
- "`max_length` is deprecated in this function, use"
- " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`"
- " instead."
- ),
- UserWarning,
- )
- stopping_criteria = validate_stopping_criteria(
- stopping_criteria, max_length
- )
- if len(stopping_criteria) == 0:
- warnings.warn(
- (
- "You don't have defined any stopping_criteria, this will likely"
- " loop forever"
- ),
- UserWarning,
- )
- pad_token_id = (
- pad_token_id if pad_token_id is not None else self.config.pad_token_id
- )
- eos_token_id = (
- eos_token_id if eos_token_id is not None else self.config.eos_token_id
- )
- output_scores = (
- output_scores if output_scores is not None else self.config.output_scores
- )
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict_in_generate = (
- return_dict_in_generate
- if return_dict_in_generate is not None
- else self.config.return_dict_in_generate
- )
-
- # init attention / hidden states / scores tuples
- scores = () if (return_dict_in_generate and output_scores) else None
- decoder_attentions = (
- () if (return_dict_in_generate and output_attentions) else None
- )
- cross_attentions = (
- () if (return_dict_in_generate and output_attentions) else None
- )
- decoder_hidden_states = (
- () if (return_dict_in_generate and output_hidden_states) else None
- )
-
- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
- if return_dict_in_generate and self.config.is_encoder_decoder:
- encoder_attentions = (
- model_kwargs["encoder_outputs"].get("attentions")
- if output_attentions
- else None
- )
- encoder_hidden_states = (
- model_kwargs["encoder_outputs"].get("hidden_states")
- if output_hidden_states
- else None
- )
-
- batch_size = len(constrained_beam_scorer._beam_hyps)
- num_beams = constrained_beam_scorer.num_beams
-
- batch_beam_size, cur_len = input_ids.shape
-
- if num_beams * batch_size != batch_beam_size:
- raise ValueError(
- f"Batch dimension of `input_ids` should be {num_beams * batch_size},"
- f" but is {batch_beam_size}."
- )
-
- beam_scores = torch.zeros(
- (batch_size, num_beams), dtype=torch.float, device=input_ids.device
- )
- beam_scores[:, 1:] = -1e9
- beam_scores = beam_scores.view((batch_size * num_beams,))
-
- this_peer_finished = False # used by synced_gpus only
- while True:
- if synced_gpus:
- # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
- # The following logic allows an early break if all peers finished generating their sequence
- this_peer_finished_flag = torch.tensor(
- 0.0 if this_peer_finished else 1.0
- ).to(input_ids.device)
- # send 0.0 if we finished, 1.0 otherwise
- dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- break
-
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
- outputs = self(
- **model_inputs,
- return_dict=True,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- )
-
- if synced_gpus and this_peer_finished:
- cur_len = cur_len + 1
- continue # don't waste resources running the code we don't need
-
- next_token_logits_raw = outputs.logits[:, -1, :]
- # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
- # cannot be generated both before and after the `nn.functional.log_softmax` operation.
- next_token_logits = self.adjust_logits_during_generation(
- next_token_logits_raw, cur_len=cur_len
- )
- next_token_scores = nn.functional.log_softmax(
- next_token_logits, dim=-1
- ) # (batch_size * num_beams, vocab_size)
-
- next_token_scores_processed = logits_processor(
- input_ids, next_token_scores, model_inputs=model_inputs
- )
-
- scores_for_all_vocab = next_token_scores_processed.clone()
-
- next_token_scores = next_token_scores_processed + beam_scores[
- :, None
- ].expand_as(next_token_scores)
-
- # Store scores, attentions and hidden_states when required
- if return_dict_in_generate:
- if output_scores:
- scores += ((next_token_logits_raw, next_token_scores),)
- if output_attentions:
- decoder_attentions += (
- (outputs.decoder_attentions,)
- if self.config.is_encoder_decoder
- else (outputs.attentions,)
- )
- if self.config.is_encoder_decoder:
- cross_attentions += (outputs.cross_attentions,)
-
- if output_hidden_states:
- decoder_hidden_states += (
- (outputs.decoder_hidden_states,)
- if self.config.is_encoder_decoder
- else (outputs.hidden_states,)
- )
-
- # reshape for beam search
- vocab_size = next_token_scores.shape[-1]
- next_token_scores = next_token_scores.view(
- batch_size, num_beams * vocab_size
- )
-
- next_token_scores, next_tokens = torch.topk(
- next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
- )
-
- next_indices = (next_tokens / vocab_size).long()
- next_tokens = next_tokens % vocab_size
-
- # stateless
- beam_outputs = constrained_beam_scorer.process(
- input_ids,
- next_token_scores,
- next_tokens,
- next_indices,
- scores_for_all_vocab,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- )
- beam_scores = beam_outputs["next_beam_scores"]
- beam_next_tokens = beam_outputs["next_beam_tokens"]
- beam_idx = beam_outputs["next_beam_indices"]
-
- input_ids = torch.cat(
- [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
- )
- model_kwargs = self._update_model_kwargs_for_generation(
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
- )
- if model_kwargs["past"] is not None:
- model_kwargs["past"] = self._reorder_cache(
- model_kwargs["past"], beam_idx
- )
-
- # increase cur_len
- cur_len = cur_len + 1
-
- if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
- if not synced_gpus:
- break
- else:
- this_peer_finished = True
-
- sequence_outputs = constrained_beam_scorer.finalize(
- input_ids,
- beam_scores,
- next_tokens,
- next_indices,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- max_length=stopping_criteria.max_length,
- )
-
- if return_dict_in_generate:
- if not output_scores:
- sequence_outputs["sequence_scores"] = None
- if self.config.is_encoder_decoder:
- return BeamSearchEncoderDecoderOutput(
- sequences=sequence_outputs["sequences"],
- sequences_scores=sequence_outputs["sequence_scores"],
- scores=scores,
- encoder_attentions=encoder_attentions,
- encoder_hidden_states=encoder_hidden_states,
- decoder_attentions=decoder_attentions,
- cross_attentions=cross_attentions,
- decoder_hidden_states=decoder_hidden_states,
- )
- else:
- return BeamSearchDecoderOnlyOutput(
- sequences=sequence_outputs["sequences"],
- sequences_scores=sequence_outputs["sequence_scores"],
- scores=scores,
- attentions=decoder_attentions,
- hidden_states=decoder_hidden_states,
- )
- else:
- return sequence_outputs["sequences"]
-
-
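As with the other search loops, constrained decoding is normally reached through `generate()` rather than by calling `constrained_beam_search` directly. A hedged sketch reusing the docstring's t5-base setup:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PhrasalConstraint

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")

# Force the formal pronoun "Sie" to appear somewhere in the generated translation.
constraint = PhrasalConstraint(tokenizer("Sie", add_special_tokens=False).input_ids)

outputs = model.generate(**inputs, constraints=[constraint], num_beams=3)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```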
-def top_k_top_p_filtering(
- logits: torch.FloatTensor,
- top_k: int = 0,
- top_p: float = 1.0,
- filter_value: float = -float("Inf"),
- min_tokens_to_keep: int = 1,
-) -> torch.FloatTensor:
- """
- Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-
- Args:
- logits: logits distribution shape (batch size, vocabulary size)
- top_k (`int`, *optional*, defaults to 0):
- If > 0, only keep the top k tokens with highest probability (top-k filtering)
- top_p (`float`, *optional*, defaults to 1.0):
- If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
- filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
- min_tokens_to_keep (`int`, *optional*, defaults to 1):
- Minimum number of tokens we keep per batch example in the output.
-
- From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
- """
- if top_k > 0:
- logits = TopKLogitsWarper(
- top_k=top_k,
- filter_value=filter_value,
- min_tokens_to_keep=min_tokens_to_keep,
- )(None, logits)
-
- if 0 <= top_p <= 1.0:
- logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(
- None, logits
- )
-
- return logits
-
-
-def override_generation_routines(cls):
- bases = list(cls.__bases__)
- for base_ix in range(len(bases)):
- if bases[base_ix] == GenerationMixin:
- bases[base_ix] = GenerationMixinWithRawScores
-
- # recursively look up
- if bases[base_ix] != object:
- bases[base_ix] = override_generation_routines(bases[base_ix])
-
- cls.__bases__ = tuple(bases)
- return cls
-
-
-def unwrap_generation_routines(cls):
- bases = list(cls.__bases__)
- for base_ix in range(len(bases)):
- if bases[base_ix] == GenerationMixinWithRawScores:
- bases[base_ix] = GenerationMixin
-
- # recursively look up
- if bases[base_ix] != object:
- bases[base_ix] = unwrap_generation_routines(bases[base_ix])
-
- cls.__bases__ = tuple(bases)
- return cls
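`override_generation_routines` monkey-patches a model class by replacing `GenerationMixin` with `GenerationMixinWithRawScores` in its `__bases__` (recursing up the class hierarchy), and `unwrap_generation_routines` undoes the swap. A hedged usage sketch; "gpt2" is illustrative, and the extra raw-score behaviour comes from the mixin defined earlier in this file:

```python
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")

# Route all generation entry points through GenerationMixinWithRawScores.
override_generation_routines(type(model))

# ... run model.generate(...) or the search loops above ...

# Swap the stock GenerationMixin back in when finished.
unwrap_generation_routines(type(model))
```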
diff --git a/openrl/modules/networks/value_network.py b/openrl/modules/networks/value_network.py
index 187eb465..bce574c5 100644
--- a/openrl/modules/networks/value_network.py
+++ b/openrl/modules/networks/value_network.py
@@ -49,6 +49,7 @@ def __init__(
self._use_recurrent_policy = cfg.use_recurrent_policy
self._use_influence_policy = cfg.use_influence_policy
self._use_popart = cfg.use_popart
+ self._use_fp16 = cfg.use_fp16 and cfg.use_deepspeed
self._influence_layer_N = cfg.influence_layer_N
self._recurrent_N = cfg.recurrent_N
self.tpdv = dict(dtype=torch.float32, device=device)
@@ -118,6 +119,9 @@ def forward(self, critic_obs, rnn_states, masks):
rnn_states = check(rnn_states).to(**self.tpdv)
masks = check(masks).to(**self.tpdv)
+ if self._use_fp16:
+ critic_obs = critic_obs.half()
+
critic_features = self.base(critic_obs)
if self._use_naive_recurrent_policy or self._use_recurrent_policy:
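The new `_use_fp16` branch only casts the incoming observations; it assumes DeepSpeed has already converted the value network's parameters to half precision. A minimal sketch of that dtype handshake (the single linear layer stands in for the real critic base):

```python
import torch
import torch.nn as nn

use_fp16 = True                            # mirrors cfg.use_fp16 and cfg.use_deepspeed
base = nn.Linear(64, 1)
if use_fp16:
    base = base.half()                     # normally done by the DeepSpeed engine

critic_obs = torch.randn(8, 64)            # observations arrive as float32
if use_fp16:
    critic_obs = critic_obs.half()         # cast inputs to match the fp16 parameters
values = base(critic_obs)                  # fp16 activations throughout
```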
diff --git a/openrl/modules/networks/value_network_gpt.py b/openrl/modules/networks/value_network_gpt.py
new file mode 100644
index 00000000..0c5b1154
--- /dev/null
+++ b/openrl/modules/networks/value_network_gpt.py
@@ -0,0 +1,149 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2021 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+from typing import Any, Dict, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from transformers.modeling_utils import unwrap_model
+
+from openrl.buffers.utils.util import get_critic_obs_space
+from openrl.modules.networks.base_value_network import BaseValueNetwork
+from openrl.modules.networks.utils.cnn import CNNBase
+from openrl.modules.networks.utils.mix import MIXBase
+from openrl.modules.networks.utils.mlp import MLPBase, MLPLayer
+from openrl.modules.networks.utils.popart import PopArt
+from openrl.modules.networks.utils.rnn import RNNLayer
+from openrl.modules.networks.utils.util import init
+from openrl.modules.utils.valuenorm import ValueNorm
+from openrl.utils.util import check_v2 as check
+
+
+class ValueNetworkGPT(BaseValueNetwork):
+ def __init__(
+ self,
+ cfg,
+ input_space,
+ action_space=None,
+ use_half=False,
+ device=torch.device("cpu"),
+ extra_args=None,
+ ):
+ self.device = device
+
+ self.use_fp16 = cfg.use_fp16
+ self.use_deepspeed = cfg.use_deepspeed
+ self.use_half = False
+ self.use_data_parallel = not cfg.use_deepspeed
+ self.use_model_parallel = False
+ assert not (self.use_deepspeed and self.use_data_parallel)
+ assert not (self.use_deepspeed and self.use_model_parallel)
+ assert not (self.use_data_parallel and self.use_model_parallel)
+
+ super(ValueNetworkGPT, self).__init__(cfg, device)
+
+ from transformers import AutoModelForCausalLM
+
+ self._value_model = AutoModelForCausalLM.from_pretrained(cfg.model_path)
+ self._value_model.config.use_cache = False
+ self._value_head = nn.Linear(
+ self._value_model.config.n_embd,
+ 1,
+ bias=False, # gpt2
+ # self._value_model.config.word_embed_proj_dim, 1, bias=False # opt-x
+ )
+ self.value_normalizer = (
+ ValueNorm(1, device=device) if self._use_valuenorm else None
+ )
+
+ if self.use_deepspeed:
+ self._value_head.to(self.device)
+ else:
+ if self.use_model_parallel:
+ self._value_model.parallelize()
+ elif self.use_data_parallel:
+ if self.use_half:
+ self._value_model = self._value_model.half()
+ self._value_head = self._value_head.half()
+ self._value_model = torch.nn.DataParallel(self._value_model)
+ self._value_model = self._value_model.to(self.device)
+ self._value_head = torch.nn.DataParallel(self._value_head)
+ self._value_head = self._value_head.to(self.device)
+
+ def _prepare_inputs_for_model(
+ self,
+ model: Any,
+ input_ids: torch.tensor,
+ model_kwargs: Optional[Dict[str, torch.tensor]] = None,
+ ):
+ model_inputs = unwrap_model(model).prepare_inputs_for_generation(
+ input_ids, **model_kwargs
+ )
+
+ if self.use_model_parallel:
+ model_inputs = {
+ key: (
+ value.to(model.transformer.first_device)
+ if isinstance(value, torch.Tensor)
+ and hasattr(model.transformer, "first_device")
+ else value
+ )
+ for key, value in model_inputs.items()
+ }
+
+ return model_inputs
+
+ def forward(self, critic_obs, rnn_states, masks):
+ for key in critic_obs.keys():
+ critic_obs[key] = (
+ torch.from_numpy(critic_obs[key])
+ if type(critic_obs[key]) == np.ndarray
+ else critic_obs[key]
+ )
+ if self.use_data_parallel:
+ critic_obs[key] = critic_obs[key].to(self.device)
+ else:
+ critic_obs[key] = critic_obs[key].to(self._value_model.device)
+
+ rnn_states = check(rnn_states)
+
+ if self.use_half:
+ input_ids = critic_obs["input_encoded_pt"].int()
+ attention_mask = critic_obs["input_attention_mask_pt"].int()
+ else:
+ input_ids = critic_obs["input_encoded_pt"].long()
+ attention_mask = critic_obs["input_attention_mask_pt"].long()
+
+ past_model_kwargs = None
+ if not past_model_kwargs:
+ past_model_kwargs = {
+ "attention_mask": attention_mask,
+ }
+
+ model_inputs = self._prepare_inputs_for_model(
+ self._value_model, input_ids, past_model_kwargs
+ )
+ output = self._value_model(output_hidden_states=True, **model_inputs)
+ last_tokens_hidden = output.hidden_states[-1][:, -1]
+
+ if self.use_model_parallel:
+ last_tokens_hidden = last_tokens_hidden.to(self.device)
+
+ values = self._value_head.forward(last_tokens_hidden)
+
+ return values, rnn_states
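`ValueNetworkGPT` scores a prompt by running the causal LM, taking the final layer's hidden state at the last position, and projecting it to a scalar with the bias-free value head. A hedged sketch of that core step in isolation ("gpt2" and the example batch are illustrative; device placement and DeepSpeed handling are omitted):

```python
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")
value_head = nn.Linear(lm.config.n_embd, 1, bias=False)

batch = tokenizer(["How old are you?"], return_tensors="pt")
output = lm(**batch, output_hidden_states=True)

last_token_hidden = output.hidden_states[-1][:, -1]   # (batch, n_embd)
values = value_head(last_token_hidden)                # (batch, 1) value estimates
```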
diff --git a/openrl/modules/rl_module.py b/openrl/modules/rl_module.py
index 430e60d6..7b7e390e 100644
--- a/openrl/modules/rl_module.py
+++ b/openrl/modules/rl_module.py
@@ -55,6 +55,8 @@ def __init__(
self.rank = rank
self.world_size = world_size
+ self.use_deepspeed = cfg.use_deepspeed
+
use_half_actor = self.program_type == "actor" and cfg.use_half_actor
if model_configs is None:
@@ -70,18 +72,57 @@ def __init__(
use_half=use_half_actor,
extra_args=model_cg["extra_args"] if "extra_args" in model_cg else None,
)
- self.models.update({model_key: model})
if self.program_type == "actor":
continue
- optimizer = torch.optim.Adam(
- model.parameters(),
- lr=model_cg["lr"],
- eps=cfg.opti_eps,
- weight_decay=cfg.weight_decay,
- )
- self.optimizers.update({model_key: optimizer})
+ if not self.use_deepspeed:
+ optimizer = torch.optim.Adam(
+ model.parameters(),
+ lr=model_cg["lr"],
+ eps=cfg.opti_eps,
+ weight_decay=cfg.weight_decay,
+ )
+ self.models.update({model_key: model})
+ self.optimizers.update({model_key: optimizer})
+ else:
+ import json
+
+ import deepspeed
+ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+ from transformers import get_constant_schedule
+
+ self.use_fp16 = cfg.use_fp16
+ self.use_offload = cfg.use_offload
+
+ # Check for inconsistencies in configuration files
+ assert not (self.use_fp16 and not self.use_deepspeed)
+ assert not (self.use_offload and not self.use_deepspeed)
+ assert cfg.deepspeed_config is not None
+ with open(cfg.deepspeed_config) as file:
+ ds_config = json.load(file)
+ if "fp16" in ds_config:
+ assert ds_config["fp16"]["enabled"] == self.use_fp16
+
+ AdamOptimizer = DeepSpeedCPUAdam if self.use_offload else FusedAdam
+ optim_params = filter(lambda p: p.requires_grad, model.parameters())
+ optim = AdamOptimizer(
+ optim_params, lr=model_cg["lr"], betas=(0.9, 0.95)
+ )
+
+ # LR Scheduler
+ lr_scheduler = get_constant_schedule(
+ optimizer=optim,
+ )
+
+ engine, *_ = deepspeed.initialize(
+ args=cfg,
+ model=model,
+ optimizer=optim,
+ lr_scheduler=lr_scheduler,
+ )
+ self.models.update({model_key: engine})
+ self.optimizers.update({model_key: engine})
if cfg.use_amp:
self.scaler = torch.cuda.amp.GradScaler()
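The DeepSpeed branch above requires `cfg.deepspeed_config` to point at a JSON file whose `fp16.enabled` flag agrees with `cfg.use_fp16`. A hedged sketch of a minimal config that would pass that consistency check (all values are illustrative):

```python
import json

# Minimal DeepSpeed config consistent with the checks above:
# "fp16.enabled" must equal cfg.use_fp16.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)      # then pass the path via cfg.deepspeed_config
```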
diff --git a/openrl/modules/utils/valuenorm.py b/openrl/modules/utils/valuenorm.py
index bed1d705..43aaad9c 100644
--- a/openrl/modules/utils/valuenorm.py
+++ b/openrl/modules/utils/valuenorm.py
@@ -24,15 +24,21 @@ def __init__(
self.per_element_update = per_element_update
self.tpdv = dict(dtype=torch.float32, device=device)
- # self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv)
- # self.running_mean_sq = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv)
- # self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(**self.tpdv)
-
- self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False)
+ self.running_mean = nn.Parameter(
+ torch.zeros(input_shape), requires_grad=False
+ ).to(**self.tpdv)
self.running_mean_sq = nn.Parameter(
torch.zeros(input_shape), requires_grad=False
+ ).to(**self.tpdv)
+ self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(
+ **self.tpdv
)
- self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False)
+
+ # self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False)
+ # self.running_mean_sq = nn.Parameter(
+ # torch.zeros(input_shape), requires_grad=False
+ # )
+ # self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False)
self.reset_parameters()
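The hunk above constructs `ValueNorm`'s running statistics directly on the configured device. For context, a hedged sketch of the debiased running normalization these buffers implement (PopArt-style; not necessarily this file's exact code):

```python
import torch


class RunningValueNorm:
    """Debiased running mean/variance for value targets (sketch)."""

    def __init__(self, beta: float = 0.995):
        self.beta = beta
        self.running_mean = torch.zeros(1)
        self.running_mean_sq = torch.zeros(1)
        self.debiasing_term = torch.zeros(1)

    def update(self, batch: torch.Tensor) -> None:
        # Exponential moving averages of the first and second moments.
        self.running_mean = self.beta * self.running_mean + (1 - self.beta) * batch.mean()
        self.running_mean_sq = (
            self.beta * self.running_mean_sq + (1 - self.beta) * (batch**2).mean()
        )
        self.debiasing_term = self.beta * self.debiasing_term + (1 - self.beta)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the debiasing term (as in Adam) before computing the variance.
        mean = self.running_mean / self.debiasing_term.clamp(min=1e-5)
        mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=1e-5)
        var = (mean_sq - mean**2).clamp(min=1e-2)
        return (x - mean) / var.sqrt()
```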
diff --git a/openrl/modules/vdn_module.py b/openrl/modules/vdn_module.py
index 32987372..10a9b541 100644
--- a/openrl/modules/vdn_module.py
+++ b/openrl/modules/vdn_module.py
@@ -68,6 +68,8 @@ def __init__(
device=device,
)
self.cfg = cfg
+ self.obs_space = input_space
+ self.act_space = act_space
def lr_decay(self, episode, episodes):
update_linear_schedule(self.optimizers["q_net"], episode, episodes, self.lr)
diff --git a/openrl/rewards/nlp_reward.py b/openrl/rewards/nlp_reward.py
index c653c7c8..38cd306a 100644
--- a/openrl/rewards/nlp_reward.py
+++ b/openrl/rewards/nlp_reward.py
@@ -10,20 +10,34 @@
class NLPReward(BaseReward):
- def __init__(self, env: Env, ref_model: str, intent_model: str):
+ def __init__(
+ self,
+ env: Env,
+ ref_model: str,
+ intent_model: str,
+ use_deepspeed: bool = False,
+ ref_ds_config: str = "default",
+ intent_ds_config: str = "default",
+ ):
self.rew_infos = []
self.env_infos = []
- meteor_config = {
- "meteor_coeff": 0.5,
- }
- self.inner_rew_funcs = {
- "meteor": Meteor(**meteor_config),
- }
+ # NOTE: the Meteor inner reward is temporarily disabled due to an unfixed bug.
+ self.inner_rew_funcs = dict()
+
+ # meteor_config = {
+ # "meteor_coeff": 0.5,
+ # "test": ref_model == "builtin_ref",
+ # }
+ # self.inner_rew_funcs = {
+ # "meteor": Meteor(**meteor_config),
+ # }
kl_config = {
"action_space": env.action_space,
"ref_model": ref_model,
+ "use_deepspeed": use_deepspeed,
+ "ds_config": ref_ds_config,
}
self.step_rew_funcs = {
"kl_pen": KLPenalty(**kl_config),
@@ -32,6 +46,8 @@ def __init__(self, env: Env, ref_model: str, intent_model: str):
intent_config = {
"intent_model": intent_model,
"intent_coeff": 0.5,
+ "use_deepspeed": use_deepspeed,
+ "ds_config": intent_ds_config,
}
self.batch_rew_funcs = {
"intent_acc": Intent(**intent_config),
diff --git a/openrl/runners/common/ppo_agent.py b/openrl/runners/common/ppo_agent.py
index ad7d0a84..414ff409 100644
--- a/openrl/runners/common/ppo_agent.py
+++ b/openrl/runners/common/ppo_agent.py
@@ -136,6 +136,7 @@ def act(
observation: Union[np.ndarray, Dict[str, np.ndarray]],
info: Optional[List[Dict[str, Any]]] = None,
deterministic: bool = True,
+ episode_starts: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
assert self.net is not None, "net is None"
observation = ObsData.prepare_input(observation)
@@ -149,6 +150,7 @@ def act(
observation,
action_masks=action_masks,
deterministic=deterministic,
+ episode_starts=episode_starts,
)
action = np.array(np.split(_t2n(action), self.env_num))
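The new `episode_starts` argument lets callers flag environments that were just reset so a recurrent policy can clear their RNN states. A hedged evaluation-loop sketch (`agent` and `env` are assumed to exist already; the vectorized-env API follows the usual OpenRL pattern):

```python
import numpy as np

env_num = 9                                        # number of parallel environments
obs, info = env.reset()
episode_starts = np.ones((env_num,), dtype=bool)   # every env begins a fresh episode

while True:
    action, _ = agent.act(obs, deterministic=True, episode_starts=episode_starts)
    obs, reward, done, info = env.step(action)
    episode_starts = np.asarray(done)              # clear RNN states where episodes ended
    if np.all(done):
        break
```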
diff --git a/openrl/selfplay/callbacks/selfplay_api.py b/openrl/selfplay/callbacks/selfplay_api.py
index 3d148749..e2214ecb 100644
--- a/openrl/selfplay/callbacks/selfplay_api.py
+++ b/openrl/selfplay/callbacks/selfplay_api.py
@@ -50,14 +50,17 @@ def _init_callback(self) -> None:
)
self.bind = SelfplayAPIServer.bind()
- serve.run(self.bind)
+ serve.run(self.bind, route_prefix="/selfplay")
success = False
try_time = 10
while not success:
success = self.api_client.set_sample_strategy(self.sample_strategy)
try_time -= 1
if try_time <= 0:
- raise RuntimeError("Failed to set sample strategy.")
+ raise RuntimeError(
+ f"Failed to set sample strategy: {self.sample_strategy}. host:"
+ f" {self.host}, port: {self.port}"
+ )
def _on_step(self) -> bool:
# print("To send request to API server.")
@@ -72,5 +75,6 @@ def _on_training_end(self) -> None:
print(f"deleting {application_name}")
serve.delete(application_name)
del self.bind
+ serve.shutdown()
if self.verbose >= 2:
print(f"delete {application_name} done!")
diff --git a/openrl/selfplay/opponents/random_opponent.py b/openrl/selfplay/opponents/random_opponent.py
index 1f396c34..501d571a 100644
--- a/openrl/selfplay/opponents/random_opponent.py
+++ b/openrl/selfplay/opponents/random_opponent.py
@@ -47,11 +47,20 @@ def _sample_random_action(
action = []
for obs, space in zip(observation, action_space):
- mask = obs.get("action_mask", None)
- action.append(space.sample(mask))
+ if termination or truncation:
+ action.append(None)
+ else:
+ if isinstance(obs, dict):
+ mask = obs.get("action_mask", None)
+ else:
+ mask = None
+ action.append(space.sample(mask))
else:
- mask = observation.get("action_mask", None)
- action = action_space.sample(mask)
+ if termination or truncation:
+ action = None
+ else:
+ mask = observation.get("action_mask", None)
+ action = action_space.sample(mask)
return action
def _load(self, opponent_path: Union[str, Path]):
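The change above follows the PettingZoo AEC convention: once an agent is terminated or truncated, the only legal action to pass to `env.step` is `None`, and only dict observations carry an `action_mask`. A hedged sketch of that convention outside the opponent class (the environment is illustrative):

```python
from pettingzoo.classic import tictactoe_v3

env = tictactoe_v3.env()
env.reset()

for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    if termination or truncation:
        action = None                                # finished agents must step with None
    else:
        mask = observation.get("action_mask", None)  # dict observations may carry a mask
        action = env.action_space(agent).sample(mask)
    env.step(action)
env.close()
```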
diff --git a/openrl/selfplay/opponents/utils.py b/openrl/selfplay/opponents/utils.py
index d1d983d5..42ddbb2b 100644
--- a/openrl/selfplay/opponents/utils.py
+++ b/openrl/selfplay/opponents/utils.py
@@ -28,6 +28,9 @@
def check_opponent_template(opponent_template: Union[str, Path]):
+ assert isinstance(opponent_template, Path) or isinstance(
+ opponent_template, str
+ ), f"opponent_template {opponent_template} must be a Path or str"
if isinstance(opponent_template, str):
opponent_template = Path(opponent_template)
assert (
diff --git a/openrl/selfplay/selfplay_api/opponent_model.py b/openrl/selfplay/selfplay_api/opponent_model.py
index af9ec6b9..b836519b 100644
--- a/openrl/selfplay/selfplay_api/opponent_model.py
+++ b/openrl/selfplay/selfplay_api/opponent_model.py
@@ -49,7 +49,7 @@ def get_battle_info(self) -> Dict[str, Any]:
result = {}
result["win_rate"] = float(self.num_wins) / max(self.num_games, 1)
result["draw_rate"] = float(self.num_draws) / max(self.num_games, 1)
- result["loss_rate"] = float(self.num_losses) / max(self.num_games, 1)
+ result["lose_rate"] = float(self.num_losses) / max(self.num_games, 1)
result["total_games"] = self.num_games
return result
diff --git a/openrl/selfplay/selfplay_api/selfplay_api.py b/openrl/selfplay/selfplay_api/selfplay_api.py
index 2c346b46..307c4fcc 100644
--- a/openrl/selfplay/selfplay_api/selfplay_api.py
+++ b/openrl/selfplay/selfplay_api/selfplay_api.py
@@ -33,7 +33,7 @@
from openrl.selfplay.selfplay_api.opponent_model import BattleResult
-@serve.deployment(route_prefix="/selfplay")
+@serve.deployment()
@serve.ingress(app)
class SelfplayAPIServer(BaseSelfplayAPIServer):
@app.post("/set_sample_strategy")
diff --git a/openrl/selfplay/strategies/__init__.py b/openrl/selfplay/strategies/__init__.py
deleted file mode 100644
index 2908f8b4..00000000
--- a/openrl/selfplay/strategies/__init__.py
+++ /dev/null
@@ -1,41 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Copyright 2023 The OpenRL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-""""""
-from openrl.selfplay.strategies.strategies import (
- NaiveSelfplayStrategy,
- OnlyLatestSelfplayStrategy,
- VarExistEnemySelfplayStrategy,
- WeightExistEnemySelfplayStrategy,
- WeightSelfplayStrategy,
- WinRateSelfplayStrategy,
-)
-
-
-def make_strategy(strategy_name):
- if strategy_name == "Naive":
- selfplay_strategy = NaiveSelfplayStrategy
- elif strategy_name == "OnlyLatest":
- selfplay_strategy = OnlyLatestSelfplayStrategy
- elif strategy_name == "Weight":
- selfplay_strategy = WeightSelfplayStrategy
- elif strategy_name == "WinRate":
- selfplay_strategy = WinRateSelfplayStrategy
- elif strategy_name == "VarExistEnemy":
- selfplay_strategy = VarExistEnemySelfplayStrategy
- elif strategy_name == "WeightExistEnemy":
- selfplay_strategy = WeightExistEnemySelfplayStrategy
- return selfplay_strategy
diff --git a/openrl/selfplay/strategies/base_strategy.py b/openrl/selfplay/strategies/base_strategy.py
deleted file mode 100644
index 4e280b13..00000000
--- a/openrl/selfplay/strategies/base_strategy.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from abc import abstractmethod
-
-
-class BaseSelfplayStrategy:
- @abstractmethod
- def __init__(self, all_args, nenvs, exist_enemy_num):
- raise NotImplementedError
-
- @abstractmethod
- def getcnt(self):
- raise NotImplementedError
-
- @abstractmethod
- def update_enemy_ids(self, new_enemy_ids):
- raise NotImplementedError
-
- @abstractmethod
- def restore(self, model_dir):
- raise NotImplementedError
-
- @abstractmethod
- def get_qlist(self):
- raise NotImplementedError
-
- @abstractmethod
- def update_weight(self, enemy_loses):
- raise NotImplementedError
-
- @abstractmethod
- def update_win_rate(self, dones, enemy_wins):
- raise NotImplementedError
-
- @abstractmethod
- def push_newone(self):
- raise NotImplementedError
-
- @abstractmethod
- def get_plist(self):
- raise NotImplementedError
diff --git a/openrl/selfplay/strategies/strategies.py b/openrl/selfplay/strategies/strategies.py
deleted file mode 100644
index 28e492ec..00000000
--- a/openrl/selfplay/strategies/strategies.py
+++ /dev/null
@@ -1,413 +0,0 @@
-import json
-
-import numpy as np
-
-from openrl.selfplay.strategies.base_strategy import BaseSelfplayStrategy
-
-
-class SelfplayStrategy(BaseSelfplayStrategy):
- def __init__(self, all_args, nenvs, exist_enemy_num):
- # data structures for qlist and history_cnt
- self.all_args = all_args
- self.qlist = []
- self.history_cnt = 0
- self.enemy_ids = [0] * nenvs
- self.length = nenvs
-
- def getcnt(self):
- return self.history_cnt
-
- def update_enemy_ids(self, new_enemy_ids):
- self.enemy_ids = new_enemy_ids
-
- def restore(self, model_dir):
- with open(model_dir + "/enemy_history_info.json") as f_obj:
- enemy_info = json.load(f_obj)
- self.qlist = enemy_info["qlist"]
- self.history_cnt = enemy_info["history_cnt"]
-
- def get_qlist(self):
- return self.qlist
-
- def update_weight(self, enemy_loses):
- pass
-
- def update_win_rate(self, dones, enemy_wins):
- pass
-
- def push_newone(self):
- pass
-
-
-class RatioSelfplayStrategy(SelfplayStrategy):
- def __init__(self, all_args, nenvs, exist_enemy_num):
- super(RatioSelfplayStrategy, self).__init__(all_args, nenvs, exist_enemy_num)
-
- def push_newone(self):
- self.history_cnt += 1
-
- def get_plist(self):
- if self.history_cnt == 1:
- return [1]
- temp_plist = np.logspace(
- 0, self.history_cnt - 1, self.history_cnt, endpoint=True, base=1.5
- )
- temp_plist[-1] = sum(temp_plist[:-1]) * 4
- temp_plist /= sum(temp_plist)
- return temp_plist
-
-
-class NaiveSelfplayStrategy(SelfplayStrategy):
- def __init__(self, all_args, nenvs, exist_enemy_num):
- super(NaiveSelfplayStrategy, self).__init__(all_args, nenvs, exist_enemy_num)
-
- def push_newone(self):
- self.history_cnt += 1
-
- def get_plist(self):
- return [1] * (self.history_cnt - 1) + [4 * (self.history_cnt - 1)]
-
- def save_new_one(self):
- return True
-
-
-class OnlyLatestSelfplayStrategy(SelfplayStrategy):
- def __init__(self, all_args, nenvs, exist_enemy_num):
- super(OnlyLatestSelfplayStrategy, self).__init__(
- all_args, nenvs, exist_enemy_num
- )
- self.play_list = []
- self.max_play_num = all_args.max_play_num
- self.least_win_rate = all_args.least_win_rate
-
- def push_newone(self):
- self.play_list.append([])
- self.history_cnt += 1
-
- def get_plist(self):
- return [0] * (self.history_cnt - 1) + [1]
-
- def save_new_one(self, least_win_rate):
- if sum(np.array(self.play_list[-1]) == -1) >= least_win_rate * (
- len(self.play_list[-1]) + 1
- ) and len(self.play_list[-1]) >= (self.max_play_num - 10):
- return True
-
- def update_play_list(self, win_enemy_ids, tie_enemy_ids, lose_enemy_ids):
- for win_enemy_id in win_enemy_ids:
- self.play_list[win_enemy_id].append(1)
- for tie_enemy_id in tie_enemy_ids:
- self.play_list[tie_enemy_id].append(0)
- for lose_enemy_id in lose_enemy_ids:
- self.play_list[lose_enemy_id].append(-1)
- self.cut_overflow()
-
- def update_win_rate(self, enemy_wins, enemy_ties, enemy_loses):
- win_enemy_ids = np.array(self.enemy_ids)[enemy_wins]
- tie_enemy_ids = np.array(self.enemy_ids)[enemy_ties]
- lose_enemy_ids = np.array(self.enemy_ids)[enemy_loses]
- self.update_play_list(win_enemy_ids, tie_enemy_ids, lose_enemy_ids)
-
- def cut_overflow(self):
- for index in range(len(self.play_list)):
- if len(self.play_list[index]) > self.max_play_num:
- self.play_list[index] = self.play_list[index][
- (-1) * self.max_play_num :
- ]
-
- def get_info_list(self, info_list):
- return_info = []
- for info in info_list:
- if info == "win":
- equal_num = 1
- elif info == "tie":
- equal_num = 0
- elif info == "lose":
- equal_num = -1
- num_list = []
- for enemy_play_list in self.play_list:
- if info == "play":
- num_list.append(len(enemy_play_list))
- else:
- num_list.append(int(sum(np.array(enemy_play_list) == equal_num)))
- return_info.append(num_list)
- return tuple(return_info)
-
- def get_enemy_play_dict(self):
- win_num_list, tie_num_list, lose_num_list, play_num_list = self.get_info_list(
- ["win", "tie", "lose", "play"]
- )
- return {
- "win_num_list": list(win_num_list),
- "tie_num_list": list(tie_num_list),
- "lose_num_list": list(lose_num_list),
- "play_num_list": list(play_num_list),
- }
-
-
-class WeightSelfplayStrategy(SelfplayStrategy):
- def __init__(self, all_args, nenvs, exist_enemy_num):
- super(WeightSelfplayStrategy, self).__init__(all_args, nenvs, exist_enemy_num)
- self.recent_weight = 0.8
- self.recent_num = 3
- self.gama = 1 / (nenvs)
-
- def push_newone(self):
- self.history_cnt += 1
- if self.history_cnt <= self.recent_num:
- return
- elif self.history_cnt == self.recent_num + 1:
- self.qlist = [1]
- else:
- self.qlist.append(max(self.qlist))
-
- def get_plist(self):
- temp_plist = np.zeros([self.history_cnt])
- temp_plist[: (-1 * self.recent_num)] = (
- np.exp(self.qlist) / sum(np.exp(self.qlist)) * (1 - self.recent_weight)
- )
- temp_plist[(-1 * self.recent_num) :] = self.recent_weight / self.recent_num
- return temp_plist
-
- def update_weight(self, enemy_loses):
- if self.history_cnt < self.recent_num + 2:
- return
- lose_enemy_ids = np.array(self.enemy_ids)[
- enemy_loses
- ] # enemy_ids that lost and need updating; may contain duplicate enemy_id values
- for enemy_id in lose_enemy_ids:
- if enemy_id <= len(self.qlist) - 1:
- divide_num = (
- len(self.qlist)
- * np.exp(self.qlist[enemy_id])
- / sum(np.exp(self.qlist))
- )
- next_weight = self.qlist[enemy_id] - self.gama / divide_num
- self.qlist[enemy_id] = next_weight
-
-
-class WinRateSelfplayStrategy(SelfplayStrategy):
- def __init__(self, all_args, nenvs, exist_enemy_num):
- super(WinRateSelfplayStrategy, self).__init__(all_args, nenvs, exist_enemy_num)
- self.max_play_num = all_args.max_play_num
- self.play_list = (
- []
- ) # each opponent keeps a list of at most max_play_num results: 1 = opponent won, 0 = tie, -1 = we won
- self.recent_list = []
- self.recent_list_max_len = all_args.recent_list_max_len
- self.latest_weight = all_args.latest_weight
- self.least_win_rate = all_args.least_win_rate
- self.stage2_least_win_rate = all_args.least_win_rate
- self.stage = 1
- self.newest_pos = all_args.newest_pos
- self.newest_weight = all_args.newest_weight
-
- def push_newone(self):
- self.play_list.append([])
- self.history_cnt += 1
-
- def get_info_list(self, info_list):
- return_info = []
- for info in info_list:
- if info == "win":
- equal_num = 1
- elif info == "tie":
- equal_num = 0
- elif info == "lose":
- equal_num = -1
- num_list = []
- for enemy_play_list in self.play_list:
- if info == "play":
- num_list.append(len(enemy_play_list))
- else:
- num_list.append(int(sum(np.array(enemy_play_list) == equal_num)))
- return_info.append(num_list)
- return tuple(return_info)
-
- def get_plist(self):
- def f_hard(win_rate_list):
- p = 1
- return win_rate_list**p
-
- def f_var(win_rate_list):
- return (1 - win_rate_list) * win_rate_list
-
- win_num_list, tie_num_list, play_num_list = self.get_info_list(
- ["win", "tie", "play"]
- )
- win_rate_list = (
- np.array(win_num_list) + 0.5 * np.array(tie_num_list) + 0.5
- ) / (np.array(play_num_list) + 1)
- return f_hard(win_rate_list)
-
- def update_play_list(self, win_enemy_ids, tie_enemy_ids, lose_enemy_ids):
- if self.stage == 2:
- win_enemy_num = (np.array(win_enemy_ids) != self.newest_pos).sum()
- tie_enemy_num = (np.array(tie_enemy_ids) != self.newest_pos).sum()
- lose_enemy_num = (np.array(lose_enemy_ids) != self.newest_pos).sum()
- self.recent_list += (
- [1] * win_enemy_num + [0] * tie_enemy_num + [-1] * lose_enemy_num
- )
- for win_enemy_id in win_enemy_ids:
- self.play_list[win_enemy_id].append(1)
- for tie_enemy_id in tie_enemy_ids:
- self.play_list[tie_enemy_id].append(0)
- for lose_enemy_id in lose_enemy_ids:
- self.play_list[lose_enemy_id].append(-1)
- self.cut_overflow()
-
- def update_win_rate(self, enemy_wins, enemy_ties, enemy_loses):
- win_enemy_ids = np.array(self.enemy_ids)[enemy_wins]
- tie_enemy_ids = np.array(self.enemy_ids)[enemy_ties]
- lose_enemy_ids = np.array(self.enemy_ids)[enemy_loses]
- self.update_play_list(win_enemy_ids, tie_enemy_ids, lose_enemy_ids)
-
- def restore(self, model_dir):
- with open(model_dir + "/enemy_history_info.json") as f_obj:
- enemy_info = json.load(f_obj)
- self.history_cnt = enemy_info["history_cnt"]
- self.play_list = enemy_info["play_list"]
-
- def get_enemy_play_dict(self):
- win_num_list, tie_num_list, lose_num_list, play_num_list = self.get_info_list(
- ["win", "tie", "lose", "play"]
- )
- return {
- "win_num_list": list(win_num_list),
- "tie_num_list": list(tie_num_list),
- "lose_num_list": list(lose_num_list),
- "play_num_list": list(play_num_list),
- }
-
- def update_win_info(self, data):
- win_enemy_ids, tie_enemy_ids, lose_enemy_ids = (
- data["win_enemy_ids"],
- data["tie_enemy_ids"],
- data["lose_enemy_ids"],
- )
- self.update_play_list(win_enemy_ids, tie_enemy_ids, lose_enemy_ids)
-
- def cut_overflow(self):
- for index in range(len(self.play_list)):
- if len(self.play_list[index]) > self.max_play_num:
- self.play_list[index] = self.play_list[index][
- (-1) * self.max_play_num :
- ]
- if len(self.recent_list) > self.recent_list_max_len:
- self.recent_list = self.recent_list[(-1) * self.recent_list_max_len :]
-
- def save_new_one(self, least_win_rate):
- if self.stage == 1:
- if sum(np.array(self.play_list[-1]) == -1) >= least_win_rate * (
- len(self.play_list[-1]) + 1
- ) and len(self.play_list[-1]) >= (self.max_play_num - 10):
- if self.getcnt() - self.all_args.exist_enemy_num == 1:
- return True
- self.stage = 2
- print("switch to stage 2")
- if self.stage == 2:
- if sum(np.array(self.recent_list) == -1) >= self.stage2_least_win_rate * (
- len(self.recent_list) + 1
- ) and len(self.recent_list) >= (self.recent_list_max_len - 10):
- self.stage = 1
- self.recent_list = []
- return True
- return False
-
-
-class ExistEnemySelfplayStrategy(WinRateSelfplayStrategy):
- def __init__(self, all_args, nenvs, exist_enemy_num):
- super(ExistEnemySelfplayStrategy, self).__init__(
- all_args, nenvs, exist_enemy_num
- )
- self.all_args = all_args
- self.enemy_ids = [0] * nenvs  # updated on the first step, so the initial value does not matter
- # the first exist_enemy_num entries in the list are pre-existing opponents
- if exist_enemy_num > 0:
- self.play_list = [[]] * exist_enemy_num
- self.history_cnt = exist_enemy_num
- self.exist_enemy_num = exist_enemy_num
- self.max_enemy_num = all_args.max_enemy_num
-
- def get_final_plist(self, f_hard, f_var):
- raise NotImplementedError
-
- def get_plist(self):
- def f_hard(win_rate_list):
- p = 2
- return win_rate_list**p
-
- def f_var(win_rate_list):
- return (1 - win_rate_list) * win_rate_list
-
- plist = self.get_final_plist(f_hard, f_var)
- if self.max_enemy_num != -1:
- if self.history_cnt - self.exist_enemy_num > self.max_enemy_num:
- mask_index = np.array(
- list(
- range(
- self.exist_enemy_num, self.history_cnt - self.max_enemy_num
- )
- )
- )
- zero_vec = np.zeros(
- self.history_cnt - self.exist_enemy_num - self.max_enemy_num
- )
- plist[mask_index] = zero_vec
-
- return plist
-
-
-class VarExistEnemySelfplayStrategy(ExistEnemySelfplayStrategy):
- def __init__(self, all_args, nenvs, exist_enemy_num):
- super(VarExistEnemySelfplayStrategy, self).__init__(
- all_args, nenvs, exist_enemy_num
- )
-
- def get_final_plist(self, f_hard, f_var):
- win_num_list, tie_num_list, play_num_list = self.get_info_list(
- ["win", "tie", "play"]
- )
- win_rate_list = (
- np.array(win_num_list) + 0.5 * np.array(tie_num_list) + 0.5
- ) / (np.array(play_num_list) + 1)
- win_rate_list = f_var(win_rate_list)
-
- return win_rate_list
-
-
-class WeightExistEnemySelfplayStrategy(ExistEnemySelfplayStrategy):
- def __init__(self, all_args, nenvs, exist_enemy_num):
- super(WeightExistEnemySelfplayStrategy, self).__init__(
- all_args, nenvs, exist_enemy_num
- )
-
- def get_final_plist(self, f_hard, f_var):
- win_num_list, tie_num_list, play_num_list = self.get_info_list(
- ["win", "tie", "play"]
- )
- win_rate_list = (
- np.array(win_num_list) + 0.5 * np.array(tie_num_list) + 0.5
- ) / (np.array(play_num_list) + 1)
-
- if self.stage == 1:
- win_rate_list = f_hard(win_rate_list)[:-1]
- # if self.newest_pos != -1:
- # win_rate_list[self.newest_pos] = 0
- win_rate_list = (
- win_rate_list / (sum(win_rate_list) + 1e-8) * (1 - self.latest_weight)
- )
- return list(win_rate_list) + [self.latest_weight]
- elif self.stage == 2:
- win_rate_list = f_hard(win_rate_list)
- if self.newest_pos != -1:
- win_rate_list[self.newest_pos] = self.newest_weight
- index_without_newest = list(range(self.history_cnt))
- index_without_newest.remove(self.newest_pos)
- win_rate_list[index_without_newest] /= sum(
- win_rate_list[index_without_newest]
- )
- win_rate_list[index_without_newest] *= 1 - self.newest_weight
- else:
- win_rate_list /= sum(win_rate_list)
- return win_rate_list
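
The strategies removed above all weight historical opponents by a Laplace-smoothed win rate, (wins + 0.5*ties + 0.5) / (plays + 1), reshaped either by a power transform (`f_hard`) or by a variance transform (`f_var`). A minimal NumPy sketch of that weighting, outside OpenRL (the function name is illustrative, not part of the library):

```python
import numpy as np

def opponent_sampling_probs(wins, ties, plays, p=2.0, mode="hard"):
    # Laplace-smoothed win rate against each historical opponent.
    wins, ties, plays = map(np.asarray, (wins, ties, plays))
    win_rate = (wins + 0.5 * ties + 0.5) / (plays + 1)
    # "hard" focuses on opponents we already beat; "var" on ~50% win-rate opponents.
    weights = win_rate**p if mode == "hard" else win_rate * (1 - win_rate)
    return weights / (weights.sum() + 1e-8)

# Example: three historical opponents with different records.
print(opponent_sampling_probs(wins=[8, 4, 1], ties=[1, 2, 0], plays=[10, 10, 2]))
```
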
diff --git a/openrl/selfplay/wrappers/base_multiplayer_wrapper.py b/openrl/selfplay/wrappers/base_multiplayer_wrapper.py
index a3de3c0f..ca8d1e95 100644
--- a/openrl/selfplay/wrappers/base_multiplayer_wrapper.py
+++ b/openrl/selfplay/wrappers/base_multiplayer_wrapper.py
@@ -104,6 +104,7 @@ def reset(self, *, seed: Optional[int] = None, **kwargs):
action = self.get_opponent_action(
player_name, observation, reward, termination, truncation, info
)
+
self.env.step(action)
def on_episode_end(
@@ -147,10 +148,18 @@ def _step(self, action):
if termination or truncation:
return (
copy.copy(self.env.observe(self.self_player)),
- self.env.rewards[self.self_player],
+ (
+ self.env.rewards[self.self_player]
+ if self.self_player in self.env.rewards
+ else 0
+ ),
termination,
truncation,
- self.env.infos[self.self_player],
+ (
+ self.env.infos[self.self_player]
+ if self.self_player in self.env.infos
+ else {}
+ ),
)
else:
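
The hunk above guards the terminal reward/info lookups because PettingZoo environments can drop a player from `env.rewards`/`env.infos` once it is done. A tiny stand-alone sketch of the fallback behaviour, with plain dicts standing in for the wrapper's attributes:

```python
# Plain dicts standing in for env.rewards / env.infos after the self player
# has already been removed at episode end.
rewards = {"player_1": 1.0}
infos = {"player_1": {}}
self_player = "player_0"

# Equivalent to the conditional expressions in the diff: fall back to a
# neutral reward and an empty info dict instead of raising KeyError.
reward = rewards[self_player] if self_player in rewards else 0
info = infos[self_player] if self_player in infos else {}
assert reward == 0 and info == {}
```
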
diff --git a/openrl/utils/callbacks/checkpoint_callback.py b/openrl/utils/callbacks/checkpoint_callback.py
index a4b3f5b6..56bf31b8 100644
--- a/openrl/utils/callbacks/checkpoint_callback.py
+++ b/openrl/utils/callbacks/checkpoint_callback.py
@@ -72,9 +72,7 @@ def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> st
"""
return os.path.join(
self.save_path,
- (
- f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}"
- ),
+ f"{self.name_prefix}_{checkpoint_type}{self.num_time_steps}_steps{'.' if extension else ''}{extension}",
)
def _on_step(self) -> bool:
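
The checkpoint hunk only drops a redundant pair of parentheses around the f-string; the generated path is unchanged. For reference, a quick sketch of what the template expands to, using hypothetical values:

```python
import os

# Hypothetical values; the real ones come from the callback's configuration.
save_path, name_prefix, checkpoint_type, num_time_steps, extension = (
    "./checkpoints", "rl_model", "", 10000, "zip"
)
path = os.path.join(
    save_path,
    f"{name_prefix}_{checkpoint_type}{num_time_steps}_steps{'.' if extension else ''}{extension}",
)
print(path)  # ./checkpoints/rl_model_10000_steps.zip
```
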
diff --git a/openrl/utils/evaluation.py b/openrl/utils/evaluation.py
index d603daa5..c008c437 100644
--- a/openrl/utils/evaluation.py
+++ b/openrl/utils/evaluation.py
@@ -68,12 +68,10 @@ def evaluate_policy(
if not is_monitor_wrapped and warn:
warnings.warn(
- (
- "Evaluation environment is not wrapped with a ``Monitor`` wrapper. This"
- " may result in reporting modified episode lengths and rewards, if"
- " other wrappers happen to modify these. Consider wrapping environment"
- " first with ``Monitor`` wrapper."
- ),
+ "Evaluation environment is not wrapped with a ``Monitor`` wrapper. This"
+ " may result in reporting modified episode lengths and rewards, if"
+ " other wrappers happen to modify these. Consider wrapping environment"
+ " first with ``Monitor`` wrapper.",
UserWarning,
)
@@ -97,9 +95,13 @@ def evaluate_policy(
episode_starts = np.ones((env.parallel_env_num,), dtype=bool)
while (episode_counts < episode_count_targets).any():
+ if not np.all(episode_starts == 0):
+ episode_starts_tmp = episode_starts
+ else:
+ episode_starts_tmp = None
+
actions, states = agent.act(
- observations,
- deterministic=deterministic,
+ observations, deterministic=deterministic, episode_starts=episode_starts_tmp
)
observations, rewards, dones, infos = env.step(actions)
rewards = np.squeeze(rewards, axis=-1)
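
The evaluation hunk forwards `episode_starts` to `agent.act` only while at least one parallel environment has just (re)started, so recurrent policies can reset their hidden state; otherwise it passes `None`. A minimal sketch of that gating (array shape assumed, the `agent.act` call is only indicated):

```python
import numpy as np

episode_starts = np.array([True, False, False, True])  # one flag per parallel env

# Same condition as in the diff: keep the flags while any env is starting.
episode_starts_tmp = episode_starts if not np.all(episode_starts == 0) else None

# actions, states = agent.act(
#     observations, deterministic=True, episode_starts=episode_starts_tmp
# )
```
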
diff --git a/openrl/utils/logger.py b/openrl/utils/logger.py
index 0f2f0e2e..d9c49f34 100644
--- a/openrl/utils/logger.py
+++ b/openrl/utils/logger.py
@@ -32,9 +32,9 @@ class Logger:
def __init__(
self,
cfg,
- project_name: str,
- scenario_name: str,
- wandb_entity: str,
+ project_name: str = "openrl",
+ scenario_name: str = "openrl",
+ wandb_entity: str = "openrl",
exp_name: Optional[str] = None,
log_path: Optional[str] = None,
use_wandb: bool = False,
@@ -46,6 +46,10 @@ def __init__(
self.use_wandb = use_wandb
self.use_tensorboard = use_tensorboard
+ self.skip_logging = False
+ if cfg.use_deepspeed and cfg.local_rank != 0:
+ self.skip_logging = True
+
self.log_level = log_level
self.log_path = log_path
self.project_name = project_name
@@ -126,20 +130,21 @@ def _init(self) -> None:
)
if self.use_wandb:
- wandb.init(
- config=self.cfg,
- project=self.project_name,
- entity=self.wandb_entity,
- notes=socket.gethostname(),
- name=self.scenario_name
- + "_"
- + str(self.exp_name)
- + "_seed"
- + str(self.cfg.seed),
- dir=str(run_dir),
- job_type="training",
- reinit=True,
- )
+ if not self.skip_logging:
+ wandb.init(
+ config=self.cfg,
+ project=self.project_name,
+ entity=self.wandb_entity,
+ notes=socket.gethostname(),
+ name=self.scenario_name
+ + "_"
+ + str(self.exp_name)
+ + "_seed"
+ + str(self.cfg.seed),
+ dir=str(run_dir),
+ job_type="training",
+ reinit=True,
+ )
elif self.use_tensorboard:
from tensorboardX import SummaryWriter
@@ -152,7 +157,8 @@ def _init(self) -> None:
def close(self):
if self.use_wandb:
- wandb.finish()
+ if not self.skip_logging:
+ wandb.finish()
def info(self, msg: str):
logging.info(msg)
@@ -167,7 +173,8 @@ def log_learner_info(
return
for k, v in infos.items():
if self.use_wandb:
- wandb.log({"Learner_{}/{}".format(leaner_id, k): v}, step=step)
+ if not self.skip_logging:
+ wandb.log({"Learner_{}/{}".format(leaner_id, k): v}, step=step)
elif self.use_tensorboard:
self.writter.add_scalars(
"Learner_{}/{}".format(leaner_id, k),
@@ -192,7 +199,8 @@ def log_info(
logging_info_str += f"\t{k}: {v}\n"
if self.use_wandb:
- wandb.log({k: v}, step=step)
+ if not self.skip_logging:
+ wandb.log({k: v}, step=step)
elif self.use_tensorboard:
self.writter.add_scalars(k, {k: v}, step)
if self.log_to_terminal:
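
Under DeepSpeed the logger now talks to wandb only from local rank 0; every other rank sets `skip_logging` and becomes a no-op. A condensed sketch of that gating pattern (class and method names here are illustrative, not OpenRL's):

```python
class RankGatedLogger:
    def __init__(self, use_deepspeed: bool, local_rank: int, use_wandb: bool):
        self.use_wandb = use_wandb
        # Only the process with local_rank == 0 is allowed to log under DeepSpeed.
        self.skip_logging = use_deepspeed and local_rank != 0

    def log(self, metrics: dict, step: int):
        if self.use_wandb and not self.skip_logging:
            import wandb  # assumed installed when use_wandb is True
            wandb.log(metrics, step=step)
```
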
diff --git a/openrl/utils/type_aliases.py b/openrl/utils/type_aliases.py
index 25991e24..d9012d7d 100644
--- a/openrl/utils/type_aliases.py
+++ b/openrl/utils/type_aliases.py
@@ -13,9 +13,7 @@
GymEnv = Union[gym.Env, vec_env.BaseVecEnv]
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
-GymStepReturn = Union[
- Tuple[GymObs, float, bool, Dict], Tuple[GymObs, float, bool, bool, Dict]
-]
+
TensorDict = Dict[Union[str, int], th.Tensor]
OptimizerStateDict = Dict[str, Any]
MaybeCallback = Union[
diff --git a/setup.py b/setup.py
index 494e4c4c..28cffd3c 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@
def get_install_requires() -> list:
return [
"setuptools>=67.0",
- "gymnasium",
+ "gymnasium>=0.29",
"click",
"termcolor",
"gym",
@@ -60,16 +60,39 @@ def get_extra_requires() -> dict:
"mpe": ["pyglet==1.5.27"],
"nlp": [
"transformers==4.18.0",
- "datasets",
+ "datasets==2.13",
"nltk",
"evaluate",
"icetk",
],
- "selfplay": ["ray[default]", "ray[serve]", "pettingzoo[classic]", "trueskill"],
+ "nlp_test": [
+ "transformers",
+ "datasets==2.13",
+ "evaluate",
+ ],
+ "selfplay": [
+ "ray[default]>=2.7",
+ "ray[serve]",
+ "async_timeout",
+ "pettingzoo[classic]",
+ "trueskill",
+ ],
+ "selfplay_test": [
+ "ray[default]>=2.7",
+ "ray[serve]",
+ "async_timeout",
+ "fastapi",
+ "pettingzoo[mpe]",
+ "pettingzoo[butterfly]",
+ ],
"retro": ["gym-retro"],
"super_mario": ["gym-super-mario-bros"],
+ "atari": ["gymnasium[atari]", "gymnasium[accept-rom-license]"],
}
req["test"].extend(req["selfplay"])
+ req["test"].extend(req["selfplay_test"])
+ req["test"].extend(req["atari"])
+ req["test"].extend(req["nlp_test"])
return req
diff --git a/tests/test_algorithm/test_a2c_algorithm.py b/tests/test_algorithm/test_a2c_algorithm.py
new file mode 100644
index 00000000..0f4f7226
--- /dev/null
+++ b/tests/test_algorithm/test_a2c_algorithm.py
@@ -0,0 +1,95 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import sys
+
+import numpy as np
+import pytest
+from gymnasium import spaces
+
+
+@pytest.fixture
+def obs_space():
+ return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32)
+
+
+@pytest.fixture
+def act_space():
+ return spaces.Discrete(2)
+
+
+@pytest.fixture(
+ scope="module", params=["--use_share_model false", "--use_share_model true"]
+)
+def config(request):
+ from openrl.configs.config import create_config_parser
+
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+@pytest.fixture
+def amp_config():
+ from openrl.configs.config import create_config_parser
+
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args("")
+ return cfg
+
+
+@pytest.fixture
+def init_module(config, obs_space, act_space):
+ from openrl.modules.ppo_module import PPOModule
+
+ module = PPOModule(
+ config,
+ policy_input_space=obs_space,
+ critic_input_space=obs_space,
+ act_space=act_space,
+ share_model=config.use_share_model,
+ )
+ return module
+
+
+@pytest.fixture
+def buffer_data(config, obs_space, act_space):
+ from openrl.buffers.normal_buffer import NormalReplayBuffer
+
+ buffer = NormalReplayBuffer(
+ config,
+ num_agents=1,
+ obs_space=obs_space,
+ act_space=act_space,
+ data_client=None,
+ episode_length=100,
+ )
+ return buffer.data
+
+
+@pytest.mark.unittest
+def test_a2c_algorithm(config, init_module, buffer_data):
+ from openrl.algorithms.a2c import A2CAlgorithm
+
+ a2c_algo = A2CAlgorithm(config, init_module)
+
+ a2c_algo.train(buffer_data)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_algorithm/test_bc_algorithm.py b/tests/test_algorithm/test_bc_algorithm.py
new file mode 100644
index 00000000..fa073174
--- /dev/null
+++ b/tests/test_algorithm/test_bc_algorithm.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import sys
+
+import numpy as np
+import pytest
+from gymnasium import spaces
+
+
+@pytest.fixture
+def obs_space():
+ return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32)
+
+
+@pytest.fixture
+def act_space():
+ return spaces.Discrete(2)
+
+
+@pytest.fixture(scope="module", params=["", "--use_share_model true"])
+def config(request):
+ from openrl.configs.config import create_config_parser
+
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+@pytest.fixture
+def init_module(config, obs_space, act_space):
+ from openrl.modules.bc_module import BCModule
+
+ module = BCModule(
+ config,
+ policy_input_space=obs_space,
+ critic_input_space=obs_space,
+ act_space=act_space,
+ share_model=config.use_share_model,
+ )
+ return module
+
+
+@pytest.fixture
+def buffer_data(config, obs_space, act_space):
+ from openrl.buffers.normal_buffer import NormalReplayBuffer
+
+ buffer = NormalReplayBuffer(
+ config,
+ num_agents=1,
+ obs_space=obs_space,
+ act_space=act_space,
+ data_client=None,
+ episode_length=100,
+ )
+ return buffer.data
+
+
+@pytest.mark.unittest
+def test_bc_algorithm(config, init_module, buffer_data):
+ from openrl.algorithms.behavior_cloning import BCAlgorithm
+
+ bc_algo = BCAlgorithm(config, init_module)
+
+ bc_algo.train(buffer_data)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_algorithm/test_ddpg_algorithm.py b/tests/test_algorithm/test_ddpg_algorithm.py
new file mode 100644
index 00000000..b31a56df
--- /dev/null
+++ b/tests/test_algorithm/test_ddpg_algorithm.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import sys
+
+import numpy as np
+import pytest
+from gymnasium import spaces
+
+
+@pytest.fixture
+def obs_space():
+ return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32)
+
+
+@pytest.fixture
+def act_space():
+ return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32)
+
+
+@pytest.fixture(scope="module", params=[""])
+def config(request):
+ from openrl.configs.config import create_config_parser
+
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+@pytest.fixture
+def amp_config():
+ from openrl.configs.config import create_config_parser
+
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args("")
+ return cfg
+
+
+@pytest.fixture
+def init_module(config, obs_space, act_space):
+ from openrl.modules.ddpg_module import DDPGModule
+
+ module = DDPGModule(
+ config,
+ input_space=obs_space,
+ act_space=act_space,
+ )
+ return module
+
+
+@pytest.fixture
+def buffer_data(config, obs_space, act_space):
+ from openrl.buffers.offpolicy_buffer import OffPolicyReplayBuffer
+
+ buffer = OffPolicyReplayBuffer(
+ config,
+ num_agents=1,
+ obs_space=obs_space,
+ act_space=act_space,
+ data_client=None,
+ episode_length=5000,
+ )
+ return buffer.data
+
+
+@pytest.mark.unittest
+def test_ddpg_algorithm(config, init_module, buffer_data):
+ from openrl.algorithms.ddpg import DDPGAlgorithm
+
+ ddpg_algo = DDPGAlgorithm(config, init_module)
+
+ ddpg_algo.train(buffer_data)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_algorithm/test_ppo_algorithm.py b/tests/test_algorithm/test_ppo_algorithm.py
index 8ac5c865..98a8a5d4 100644
--- a/tests/test_algorithm/test_ppo_algorithm.py
+++ b/tests/test_algorithm/test_ppo_algorithm.py
@@ -33,7 +33,9 @@ def act_space():
return spaces.Discrete(2)
-@pytest.fixture(scope="module", params=["", "--use_share_model true"])
+@pytest.fixture(
+ scope="module", params=["--use_share_model false", "--use_share_model true"]
+)
def config(request):
from openrl.configs.config import create_config_parser
diff --git a/tests/test_algorithm/test_sac_algorithm.py b/tests/test_algorithm/test_sac_algorithm.py
new file mode 100644
index 00000000..80447a3a
--- /dev/null
+++ b/tests/test_algorithm/test_sac_algorithm.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import sys
+
+import numpy as np
+import pytest
+from gymnasium import spaces
+
+
+@pytest.fixture
+def obs_space():
+ return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32)
+
+
+@pytest.fixture
+def act_space():
+ return spaces.Box(low=-np.inf, high=+np.inf, shape=(1,), dtype=np.float32)
+
+
+@pytest.fixture(scope="module", params=[""])
+def config(request):
+ from openrl.configs.config import create_config_parser
+
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+@pytest.fixture
+def amp_config():
+ from openrl.configs.config import create_config_parser
+
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args("")
+ return cfg
+
+
+@pytest.fixture
+def init_module(config, obs_space, act_space):
+ from openrl.modules.sac_module import SACModule
+
+ module = SACModule(
+ config,
+ input_space=obs_space,
+ act_space=act_space,
+ )
+ return module
+
+
+@pytest.fixture
+def buffer_data(config, obs_space, act_space):
+ from openrl.buffers.offpolicy_buffer import OffPolicyReplayBuffer
+
+ buffer = OffPolicyReplayBuffer(
+ config,
+ num_agents=1,
+ obs_space=obs_space,
+ act_space=act_space,
+ data_client=None,
+ episode_length=5000,
+ )
+ return buffer.data
+
+
+@pytest.mark.unittest
+def test_sac_algorithm(config, init_module, buffer_data):
+ from openrl.algorithms.sac import SACAlgorithm
+
+ sac_algo = SACAlgorithm(config, init_module)
+
+ sac_algo.train(buffer_data)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_arena/test_new_envs.py b/tests/test_arena/test_new_envs.py
new file mode 100644
index 00000000..7a5dc01d
--- /dev/null
+++ b/tests/test_arena/test_new_envs.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import sys
+
+import pytest
+from pettingzoo.butterfly import cooperative_pong_v5
+from pettingzoo.classic import connect_four_v3, go_v5, texas_holdem_no_limit_v6
+from pettingzoo.mpe import simple_push_v3
+
+from examples.custom_env.rock_paper_scissors import RockPaperScissors
+from openrl.arena import make_arena
+from openrl.arena.agents.local_agent import LocalAgent
+from openrl.arena.agents.random_agent import RandomAgent
+from openrl.envs.PettingZoo.registration import register
+from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner
+
+
+def ConnectFourEnv(render_mode, **kwargs):
+ return connect_four_v3.env(render_mode)
+
+
+def RockPaperScissorsEnv(render_mode, **kwargs):
+ return RockPaperScissors(render_mode)
+
+
+def GoEnv(render_mode, **kwargs):
+ return go_v5.env(render_mode=render_mode, board_size=5, komi=7.5)
+
+
+def TexasHoldemEnv(render_mode, **kwargs):
+ return texas_holdem_no_limit_v6.env(render_mode=render_mode)
+
+
+# MPE
+def SimplePushEnv(render_mode, **kwargs):
+ return simple_push_v3.env(render_mode=render_mode)
+
+
+def CooperativePongEnv(render_mode, **kwargs):
+ return cooperative_pong_v5.env(render_mode=render_mode)
+
+
+def register_new_envs():
+ new_env_dict = {
+ "connect_four_v3": ConnectFourEnv,
+ "RockPaperScissors": RockPaperScissorsEnv,
+ "go_v5": GoEnv,
+ "texas_holdem_no_limit_v6": TexasHoldemEnv,
+ "simple_push_v3": SimplePushEnv,
+ "cooperative_pong_v5": CooperativePongEnv,
+ }
+
+ for env_id, env in new_env_dict.items():
+ register(env_id, env)
+ return new_env_dict.keys()
+
+
+def run_arena(
+ env_id: str,
+ parallel: bool = True,
+ seed=0,
+ total_games: int = 10,
+ max_game_onetime: int = 5,
+):
+ env_wrappers = [RecordWinner]
+
+ arena = make_arena(env_id, env_wrappers=env_wrappers, use_tqdm=False)
+
+ agent1 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent")
+ agent2 = RandomAgent()
+
+ arena.reset(
+ agents={"agent1": agent1, "agent2": agent2},
+ total_games=total_games,
+ max_game_onetime=max_game_onetime,
+ seed=seed,
+ )
+ result = arena.run(parallel=parallel)
+ arena.close()
+ return result
+
+
+@pytest.mark.unittest
+def test_new_envs():
+ env_ids = register_new_envs()
+ seed = 0
+ for env_id in env_ids:
+ run_arena(env_id=env_id, seed=seed, parallel=False, total_games=1)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_arena/test_reproducibility.py b/tests/test_arena/test_reproducibility.py
index 0d186ab0..9ced525c 100644
--- a/tests/test_arena/test_reproducibility.py
+++ b/tests/test_arena/test_reproducibility.py
@@ -22,6 +22,7 @@
from openrl.arena import make_arena
from openrl.arena.agents.local_agent import LocalAgent
+from openrl.arena.agents.random_agent import RandomAgent
from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner
@@ -41,7 +42,7 @@ def run_arena(
arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=False)
agent1 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent")
- agent2 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent")
+ agent2 = RandomAgent()
arena.reset(
agents={"agent1": agent1, "agent2": agent2},
diff --git a/tests/test_buffer/test_generator.py b/tests/test_buffer/test_generator.py
new file mode 100644
index 00000000..27763635
--- /dev/null
+++ b/tests/test_buffer/test_generator.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import sys
+
+import pytest
+
+from openrl.envs.common import make
+from openrl.modules.common import PPONet as Net
+from openrl.runners.common import PPOAgent as Agent
+
+
+@pytest.fixture(scope="module", params=["--episode_length 10"])
+def episode_length(request):
+ return request.param
+
+
+@pytest.fixture(
+ scope="module",
+ params=[
+ "--use_recurrent_policy true --use_joint_action_loss true",
+ "--use_recurrent_policy true --use_joint_action_loss false",
+ "--use_recurrent_policy false --use_naive_recurrent true",
+ "--use_recurrent_policy false --use_naive_recurrent false",
+ ],
+)
+def generator_type(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=["--use_gae true", "--use_gae false"])
+def use_gae(request):
+ return request.param
+
+
+@pytest.fixture(
+ scope="module",
+ params=["--use_proper_time_limits true", "--use_proper_time_limits false"],
+)
+def use_proper_time_limits(request):
+ return request.param
+
+
+@pytest.fixture(
+ scope="module",
+ params=[
+ "--use_popart true --use_valuenorm false",
+ "--use_popart false --use_valuenorm true",
+ "--use_popart false --use_valuenorm false",
+ ],
+)
+def use_popart(request):
+ return request.param
+
+
+@pytest.fixture(scope="module")
+def config(use_proper_time_limits, use_popart, use_gae, generator_type, episode_length):
+ config_str = (
+ use_proper_time_limits
+ + " "
+ + use_popart
+ + " "
+ + use_gae
+ + " "
+ + generator_type
+ + " "
+ + episode_length
+ )
+
+ from openrl.configs.config import create_config_parser
+
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(config_str.split())
+ return cfg
+
+
+@pytest.mark.unittest
+def test_buffer_generator(config):
+ env = make("CartPole-v1", env_num=2)
+ agent = Agent(Net(env, cfg=config))
+ agent.train(total_time_steps=50)
+ env.close()
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_buffer/test_offpolicy_generator.py b/tests/test_buffer/test_offpolicy_generator.py
new file mode 100644
index 00000000..ec960973
--- /dev/null
+++ b/tests/test_buffer/test_offpolicy_generator.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import sys
+
+import pytest
+
+from openrl.envs.common import make
+from openrl.modules.common import DQNNet as Net
+from openrl.runners.common import DQNAgent as Agent
+
+
+@pytest.fixture(scope="module", params=["--episode_length 10"])
+def episode_length(request):
+ return request.param
+
+
+@pytest.fixture(
+ scope="module",
+ params=[
+ "--use_recurrent_policy false --use_joint_action_loss false",
+ ],
+)
+def generator_type(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=["--use_proper_time_limits false"])
+def use_proper_time_limits(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=["--use_popart false --use_valuenorm false"])
+def use_popart(request):
+ return request.param
+
+
+@pytest.fixture(scope="module")
+def config(use_proper_time_limits, use_popart, generator_type, episode_length):
+ config_str = (
+ use_proper_time_limits
+ + " "
+ + use_popart
+ + " "
+ + generator_type
+ + " "
+ + episode_length
+ )
+
+ from openrl.configs.config import create_config_parser
+
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(config_str.split())
+ return cfg
+
+
+@pytest.mark.unittest
+def test_buffer_generator(config):
+ env = make("CartPole-v1", env_num=2)
+ agent = Agent(Net(env, cfg=config))
+ agent.train(total_time_steps=50)
+ env.close()
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_dataset/test_expert_dataset.py b/tests/test_dataset/test_expert_dataset.py
new file mode 100644
index 00000000..1eed9125
--- /dev/null
+++ b/tests/test_dataset/test_expert_dataset.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import sys
+
+import pytest
+import torch
+
+from openrl.datasets.expert_dataset import ExpertDataset
+from openrl.envs.common import make
+from openrl.envs.vec_env.wrappers.gen_data import GenDataWrapper
+from openrl.envs.wrappers.monitor import Monitor
+
+env_wrappers = [
+ Monitor,
+]
+
+
+def gen_data(total_episode, data_save_path):
+ # Generate a short rollout to use as expert data.
+ # Create a single IdentityEnv (env_num=1) wrapped with Monitor; no rendering is needed.
+
+ env = make(
+ "IdentityEnv",
+ env_num=1,
+ asynchronous=True,
+ env_wrappers=env_wrappers,
+ )
+
+ env = GenDataWrapper(
+ env, data_save_path=data_save_path, total_episode=total_episode
+ )
+ env.reset()
+ done = False
+ ep_length = 0
+ while not done:
+ obs, r, done, info = env.step(env.random_action())
+ ep_length += 1
+ env.close()
+ return ep_length
+
+
+@pytest.mark.unittest
+def test_expert_dataset(tmp_path):
+ total_episode = 1
+ data_save_path = tmp_path / "data.pkl"
+ ep_length = gen_data(total_episode, data_save_path)
+
+ dataset = ExpertDataset(
+ data_save_path,
+ num_trajectories=None,
+ subsample_frequency=1,
+ seed=None,
+ env_id=0,
+ env_num=1,
+ )
+ assert len(dataset) == ep_length, "len(dataset)={}, ep_length={}".format(
+ len(dataset), ep_length
+ )
+ assert len(dataset[0]) == 2, "len(dataset[0])={}".format(len(dataset[0]))
+
+ data_loader = torch.utils.data.DataLoader(
+ dataset=dataset, batch_size=1, shuffle=False, drop_last=True
+ )
+
+ step = 0
+ for batch_data in data_loader:
+ assert len(batch_data) == 2, "len(batch_data)={}".format(len(batch_data))
+ step += 1
+ assert step == ep_length, "step={},ep_length={}".format(step, ep_length)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_env/test_mpe_env.py b/tests/test_env/test_mpe_env.py
index 2dd664b6..0b555bb2 100644
--- a/tests/test_env/test_mpe_env.py
+++ b/tests/test_env/test_mpe_env.py
@@ -18,15 +18,16 @@
import os
import sys
+import numpy as np
import pytest
+from openrl.envs.common import make
+
@pytest.mark.unittest
def test_mpe():
- from openrl.envs.common import make
-
- env_num = 6
- env = make("simple_spread", env_num=6)
+ env_num = 3
+ env = make("simple_spread", env_num=env_num)
obs, info = env.reset()
obs, reward, done, info = env.step(env.random_action())
assert env.agent_num == 3
@@ -34,5 +35,27 @@ def test_mpe():
env.close()
+@pytest.mark.unittest
+def test_mpe_render():
+ render_mode = "human"
+ env_num = 2
+ env = make(
+ "simple_spread", render_mode=render_mode, env_num=env_num, asynchronous=False
+ )
+
+ env.reset(seed=0)
+ done = False
+ step = 0
+ total_reward = 0
+ while not np.any(done):
+ # Step the environment with a random action.
+
+ obs, r, done, info = env.step(env.random_action())
+ step += 1
+ total_reward += np.mean(r)
+
+ env.close()
+
+
if __name__ == "__main__":
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_env/test_nlp/test_DailyDialogEnv.py b/tests/test_env/test_nlp/test_DailyDialogEnv.py
new file mode 100644
index 00000000..6f0ac1df
--- /dev/null
+++ b/tests/test_env/test_nlp/test_DailyDialogEnv.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+import os
+import sys
+
+import pytest
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+
+
+@pytest.fixture(
+ scope="module",
+ params=["--env.args {'data_path':None,'tokenizer_path':'builtin_BPE'}"],
+)
+def config(request):
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+@pytest.mark.unittest
+def test_DailyDialogEnv(config):
+ env = make("daily_dialog", env_num=1, asynchronous=False, cfg=config)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_env/test_snake_env.py b/tests/test_env/test_snake_env.py
new file mode 100644
index 00000000..8e558231
--- /dev/null
+++ b/tests/test_env/test_snake_env.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import sys
+
+import gymnasium as gym
+import numpy as np
+import pytest
+from gymnasium import spaces
+
+from openrl.envs.common import make
+from openrl.envs.wrappers.base_wrapper import BaseObservationWrapper
+from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper
+
+
+class ConvertObs(BaseObservationWrapper):
+ def __init__(self, env: gym.Env):
+ BaseObservationWrapper.__init__(self, env)
+ self.observation_space = spaces.Box(
+ low=-np.inf, high=np.inf, shape=(576,), dtype=np.float32
+ )
+
+ def observation(self, observation):
+ new_obs = np.zeros((len(observation), 576), dtype=int)
+ return new_obs
+
+
+@pytest.mark.unittest
+def test_snake():
+ env_num = 2
+ for i in [1, 3]:
+ env = make(
+ f"snakes_{i}v{i}",
+ env_num=env_num,
+ asynchronous=False,
+ opponent_wrappers=[RandomOpponentWrapper],
+ env_wrappers=[ConvertObs],
+ auto_reset=False,
+ )
+ ep_num = 3
+ for ep_now in range(ep_num):
+ obs, info = env.reset()
+ done = False
+ step = 0
+
+ while not np.any(done):
+ obs, r, done, info = env.step(env.random_action())
+ step += 1
+
+ env.close()
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_env/test_vec_env/test_async_env.py b/tests/test_env/test_vec_env/test_async_env.py
new file mode 100644
index 00000000..2c2301d3
--- /dev/null
+++ b/tests/test_env/test_vec_env/test_async_env.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+import multiprocessing as mp
+import os
+import sys
+
+import pytest
+from gymnasium.wrappers import EnvCompatibility
+
+from openrl.envs.toy_envs import make_toy_envs
+from openrl.envs.vec_env.async_venv import AsyncVectorEnv, _worker
+
+
+class CustomEnvCompatibility(EnvCompatibility):
+ def reset(self, **kwargs):
+ return super().reset(**kwargs)[0]
+
+
+def init_envs():
+ env_wrappers = [CustomEnvCompatibility]
+ env_fns = make_toy_envs(
+ id="IdentityEnv",
+ env_num=2,
+ env_wrappers=env_wrappers,
+ )
+ return env_fns
+
+
+def assert_env_name(env, env_name):
+ if isinstance(env.metadata["name"], str):
+ assert env.metadata["name"] == env_name
+ else:
+ assert env.metadata["name"].__name__ == env_name
+
+
+@pytest.mark.unittest
+def test_async_env():
+ env_name = "IdentityEnv"
+ env = AsyncVectorEnv(init_envs(), shared_memory=True)
+ assert (
+ env._env_name == env_name
+ ), "AsyncVectorEnv should have the same metadata as the wrapped env"
+ env.exec_func(assert_env_name, indices=None, env_name=env_name)
+ env.call("render")
+ env_name_new = "IdentityEnvNew"
+ env.set_attr("metadata", {"name": env_name_new})
+ env.exec_func(assert_env_name, indices=None, env_name=env_name_new)
+
+
+def main_control(parent_pipe, child_pipe):
+ child_pipe.close()
+
+ parent_pipe.send(("reset", {"seed": 0}))
+ result, success = parent_pipe.recv()
+ assert success, result
+
+ parent_pipe.send(("step", [0]))
+ result, success = parent_pipe.recv()
+ assert success, result
+
+ parent_pipe.send(("_call", ("render", [], {})))
+ result, success = parent_pipe.recv()
+ assert success, result
+
+ parent_pipe.send(("_setattr", ("metadata", {"name": "IdentityEnvNew"})))
+ result, success = parent_pipe.recv()
+ assert success, result
+
+ parent_pipe.send(
+ ("_func_exec", (assert_env_name, None, [], {"env_name": "IdentityEnvNew"}))
+ )
+ result, success = parent_pipe.recv()
+ assert success, result
+
+ parent_pipe.send(("close", None))
+ result, success = parent_pipe.recv()
+ assert success, result
+
+
+@pytest.mark.unittest
+def test_worker():
+ for auto_reset in [True, False]:
+ ctx = mp.get_context(None)
+ parent_pipe, child_pipe = ctx.Pipe()
+
+ error_queue = ctx.Queue()
+
+ process = ctx.Process(
+ target=main_control,
+ name="test",
+ args=(parent_pipe, child_pipe),
+ )
+ process.daemon = True
+ process.start()
+ _worker(0, init_envs()[0], child_pipe, None, False, error_queue, auto_reset)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
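
`test_worker` drives the vector-env worker directly over a multiprocessing pipe: every request is a `(command, payload)` tuple and every reply a `(result, success)` tuple. A stand-alone toy version of that request/response loop (this is not OpenRL's `_worker`, just the protocol shape the test relies on):

```python
import multiprocessing as mp

def toy_worker(pipe):
    # Reply to each (command, payload) request with a (result, success) tuple.
    while True:
        command, payload = pipe.recv()
        if command == "close":
            pipe.send((None, True))
            break
        pipe.send((f"handled {command} with payload {payload!r}", True))

if __name__ == "__main__":
    parent_pipe, child_pipe = mp.Pipe()
    process = mp.Process(target=toy_worker, args=(child_pipe,))
    process.start()
    parent_pipe.send(("reset", {"seed": 0}))
    print(parent_pipe.recv())   # result string plus success flag
    parent_pipe.send(("close", None))
    parent_pipe.recv()
    process.join()
```
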
diff --git a/tests/test_env/test_vec_env/test_sync_env.py b/tests/test_env/test_vec_env/test_sync_env.py
new file mode 100644
index 00000000..fb3d5d0b
--- /dev/null
+++ b/tests/test_env/test_vec_env/test_sync_env.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+import os
+import sys
+
+import pytest
+from gymnasium.wrappers import EnvCompatibility
+
+from openrl.envs.toy_envs import make_toy_envs
+from openrl.envs.vec_env.sync_venv import SyncVectorEnv
+
+
+class CustomEnvCompatibility(EnvCompatibility):
+ def reset(self, **kwargs):
+ return super().reset(**kwargs)[0]
+
+
+def init_envs():
+ env_wrappers = [CustomEnvCompatibility]
+ env_fns = make_toy_envs(
+ id="IdentityEnv",
+ env_num=2,
+ env_wrappers=env_wrappers,
+ )
+ return env_fns
+
+
+def assert_env_name(env, env_name):
+ assert env.metadata["name"].__name__ == env_name
+
+
+@pytest.mark.unittest
+def test_sync_env():
+ env_name = "IdentityEnv"
+ env = SyncVectorEnv(init_envs())
+ env.exec_func(assert_env_name, indices=None, env_name=env_name)
+ env.call("render")
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_env/test_vec_env/test_vec_wrappers.py b/tests/test_env/test_vec_env/test_vec_wrappers.py
new file mode 100644
index 00000000..9b5735c9
--- /dev/null
+++ b/tests/test_env/test_vec_env/test_vec_wrappers.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import pickle
+import sys
+
+import numpy as np
+import pytest
+
+from openrl.envs.common import make
+from openrl.envs.vec_env.wrappers.gen_data import GenDataWrapper, GenDataWrapper_v1
+from openrl.envs.vec_env.wrappers.zero_reward_wrapper import ZeroRewardWrapper
+from openrl.envs.wrappers.monitor import Monitor
+
+
+@pytest.mark.unittest
+def test_zero_reward_wrapper():
+ env = make("IdentityEnv", env_num=1)
+ env = ZeroRewardWrapper(env)
+ env.reset(seed=0)
+ while True:
+ obs, reward, done, info = env.step(env.random_action())
+ assert np.all(reward == 0), "reward should be zero"
+ if done:
+ break
+ env.close()
+
+
+@pytest.mark.unittest
+def test_gen_data(tmp_path):
+ total_episode = 4
+ env = make("IdentityEnv", env_wrappers=[Monitor], env_num=1)
+ data_save_path = tmp_path / "data.pkl"
+ env = GenDataWrapper(
+ env, data_save_path=str(data_save_path), total_episode=total_episode
+ )
+ obs, info = env.reset(seed=0)
+ done = False
+ while not done:
+ obs, r, done, info = env.step(env.random_action())
+ env.close()
+
+ save_data = pickle.load(open(data_save_path, "rb"))
+ assert len(save_data["episode_lengths"]) == total_episode, (
+ f"episode_lengths {len(save_data['episode_lengths'])} "
+ f"should be equal to total_episode {total_episode}"
+ )
+
+
+@pytest.mark.unittest
+def test_gen_data_old(tmp_path):
+ total_episode = 4
+ env = make("IdentityEnv", env_wrappers=[Monitor], env_num=1)
+ data_save_path = tmp_path / "data.pkl"
+ env = GenDataWrapper_v1(
+ env, data_save_path=str(data_save_path), total_episode=total_episode
+ )
+ obs, info = env.reset(seed=0)
+ done = False
+ while not done:
+ obs, r, done, info = env.step(env.random_action())
+ env.close()
+
+ save_data = pickle.load(open(data_save_path, "rb"))
+ assert save_data["total_episode"] == total_episode, (
+ f"episode_lengths {save_data['total_episode']} "
+ f"should be equal to total_episode {total_episode}"
+ )
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_env/test_wrappers.py b/tests/test_env/test_wrappers.py
new file mode 100644
index 00000000..6042eccf
--- /dev/null
+++ b/tests/test_env/test_wrappers.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import sys
+
+import pytest
+
+
+@pytest.mark.unittest
+def test_atari_wrappers():
+ import gymnasium
+
+ from openrl.envs.wrappers.atari_wrappers import (
+ ClipRewardEnv,
+ EpisodicLifeEnv,
+ FireResetEnv,
+ NoopResetEnv,
+ WarpFrame,
+ )
+
+ env = gymnasium.make("ALE/Breakout-v5")
+ env = FireResetEnv(EpisodicLifeEnv(ClipRewardEnv(WarpFrame(NoopResetEnv(env)))))
+ env.reset(seed=0)
+ while True:
+ obs, reward, done, truncated, info = env.step(0)
+ if done:
+ break
+ env.close()
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_examples/test_nlp.py b/tests/test_examples/test_nlp.py
index 99111fdc..1524ae65 100644
--- a/tests/test_examples/test_nlp.py
+++ b/tests/test_examples/test_nlp.py
@@ -17,20 +17,26 @@
# """"""
#
+import os
+import sys
+import pytest
+
+from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
-def config():
- from openrl.configs.config import create_config_parser
-
+# @pytest.fixture(scope="module", params=["--env.args {'data_path':None,'tokenizer_path':'builtin_BPE'}"])
+@pytest.fixture(scope="module", params=[""])
+def config(request):
cfg_parser = create_config_parser()
- cfg = cfg_parser.parse_args()
+ cfg = cfg_parser.parse_args(request.param.split())
return cfg
+@pytest.mark.unittest
def test_train_nlp(config):
env = make("fake_dialog_data", env_num=3, cfg=config)
agent = Agent(Net(env))
@@ -38,5 +44,4 @@ def test_train_nlp(config):
if __name__ == "__main__":
- cfg = config()
- test_train_nlp(cfg)
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_examples/test_train_atari.py b/tests/test_examples/test_train_atari.py
new file mode 100644
index 00000000..4d8d2166
--- /dev/null
+++ b/tests/test_examples/test_train_atari.py
@@ -0,0 +1,74 @@
+""""""
+
+import os
+import sys
+
+import numpy as np
+import pytest
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from openrl.envs.wrappers.atari_wrappers import (
+ ClipRewardEnv,
+ FireResetEnv,
+ NoopResetEnv,
+ WarpFrame,
+)
+from openrl.envs.wrappers.image_wrappers import TransposeImage
+from openrl.envs.wrappers.monitor import Monitor
+from openrl.modules.common import PPONet as Net
+from openrl.runners.common import PPOAgent as Agent
+
+env_wrappers = [
+ Monitor,
+ NoopResetEnv,
+ FireResetEnv,
+ WarpFrame,
+ ClipRewardEnv,
+ TransposeImage,
+]
+
+
+@pytest.fixture(
+ scope="module",
+ params=[
+ "--episode_length 5 --use_recurrent_policy false --vec_info_class.id"
+ " EPS_RewardInfo --use_valuenorm true --use_adv_normalize true"
+ " --use_share_model True --entropy_coef 0.01"
+ ],
+)
+def config(request):
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+@pytest.mark.unittest
+def test_train_atari(config):
+ env_num = 2
+ env = make(
+ "ALE/Pong-v5",
+ env_num=env_num,
+ cfg=config,
+ asynchronous=True,
+ env_wrappers=env_wrappers,
+ )
+ net = Net(env, cfg=config)
+ agent = Agent(net)
+ agent.train(total_time_steps=30)
+ agent.save("./ppo_agent/")
+ agent.load("./ppo_agent/")
+ agent.set_env(env)
+ obs, info = env.reset(seed=0)
+ step = 0
+ while step < 5:
+ action, _ = agent.act(obs, deterministic=True)
+ obs, r, done, info = env.step(action)
+ if np.any(done):
+ break
+ step += 1
+ env.close()
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_examples/test_train_gail.py b/tests/test_examples/test_train_gail.py
new file mode 100644
index 00000000..656ff2d0
--- /dev/null
+++ b/tests/test_examples/test_train_gail.py
@@ -0,0 +1,75 @@
+""""""
+
+import os
+import sys
+
+import pytest
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from openrl.envs.vec_env.wrappers.gen_data import GenDataWrapper
+from openrl.envs.wrappers.extra_wrappers import ZeroRewardWrapper
+from openrl.envs.wrappers.monitor import Monitor
+from openrl.modules.common import GAILNet as Net
+from openrl.modules.common import PPONet
+from openrl.runners.common import GAILAgent as Agent
+from openrl.runners.common import PPOAgent
+
+
+@pytest.fixture(scope="function")
+def gen_data(tmpdir):
+ tmp_data_path = os.path.join(tmpdir, "data.pkl")
+ env_wrappers = [
+ Monitor,
+ ]
+ print("generate data....")
+ env = make(
+ "CartPole-v1",
+ env_num=2,
+ asynchronous=True,
+ env_wrappers=env_wrappers,
+ )
+ agent = PPOAgent(PPONet(env))
+ env = GenDataWrapper(env, data_save_path=tmp_data_path, total_episode=5)
+ obs, info = env.reset()
+ done = False
+ while not done:
+ # Based on environmental observation input, predict next action.
+ action, _ = agent.act(obs, deterministic=True)
+ obs, r, done, info = env.step(action)
+ env.close()
+ print("generate data done!")
+ return tmp_data_path
+
+
+@pytest.fixture(
+ scope="function", params=[" --gail_use_action false", " --gail_use_action true"]
+)
+def config(request, gen_data):
+ input_str = (
+ "--episode_length 5 --use_recurrent_policy true --use_joint_action_loss true"
+ " --use_valuenorm true --use_adv_normalize true --reward_class.id GAILReward"
+ )
+ input_str += request.param
+ input_str += " --expert_data " + gen_data
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(input_str.split())
+ return cfg
+
+
+@pytest.mark.unittest
+def test_train_gail(config):
+ env = make("CartPole-v1", env_num=2, cfg=config, env_wrappers=[ZeroRewardWrapper])
+
+ net = Net(
+ env,
+ cfg=config,
+ )
+ # initialize the trainer
+ agent = Agent(net)
+ agent.train(total_time_steps=200)
+ env.close()
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_examples/test_train_mpe.py b/tests/test_examples/test_train_mpe.py
index 419b3dab..36e3e689 100644
--- a/tests/test_examples/test_train_mpe.py
+++ b/tests/test_examples/test_train_mpe.py
@@ -1,4 +1,5 @@
""""""
+
import os
import sys
diff --git a/tests/test_modules/test_common/test_ddpg_net.py b/tests/test_modules/test_common/test_ddpg_net.py
new file mode 100644
index 00000000..a4c03354
--- /dev/null
+++ b/tests/test_modules/test_common/test_ddpg_net.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+import os
+import sys
+
+import pytest
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from openrl.envs.wrappers.extra_wrappers import AddStep
+from openrl.modules.common import DDPGNet as Net
+from openrl.runners.common import DDPGAgent as Agent
+
+env_wrappers = [AddStep]
+
+
+@pytest.fixture(scope="module", params=[""])
+def config(request):
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+def train(Agent, Net, env_name, env_num, total_time_steps, config):
+ cfg = config
+ env = make(env_name, env_num=env_num, cfg=cfg, env_wrappers=env_wrappers)
+
+ net = Net(
+ env,
+ cfg=cfg,
+ )
+ # initialize the trainer
+ agent = Agent(net)
+ # start training for the requested number of time steps
+ agent.train(total_time_steps=total_time_steps)
+ env.close()
+
+
+@pytest.mark.unittest
+def test_ddpg_net(config):
+ train(Agent, Net, "IdentityEnvcontinuous", 2, 100, config)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_modules/test_common/test_dqn_net.py b/tests/test_modules/test_common/test_dqn_net.py
new file mode 100644
index 00000000..292c08b4
--- /dev/null
+++ b/tests/test_modules/test_common/test_dqn_net.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+import os
+import sys
+
+import pytest
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from openrl.envs.wrappers.extra_wrappers import AddStep
+from openrl.modules.common import DQNNet as Net
+from openrl.runners.common import DQNAgent as Agent
+
+env_wrappers = [AddStep]
+
+
+@pytest.fixture(scope="module", params=[""])
+def config(request):
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+def train(Agent, Net, env_name, env_num, total_time_steps, config):
+ cfg = config
+ env = make(env_name, env_num=env_num, cfg=cfg, env_wrappers=env_wrappers)
+
+ net = Net(
+ env,
+ cfg=cfg,
+ )
+ # initialize the trainer
+ agent = Agent(net)
+ # start training for the requested number of time steps
+ agent.train(total_time_steps=total_time_steps)
+ env.close()
+
+
+@pytest.mark.unittest
+def test_dqn_net(config):
+ train(Agent, Net, "IdentityEnv", 2, 100, config)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_modules/test_common/test_sac_net.py b/tests/test_modules/test_common/test_sac_net.py
new file mode 100644
index 00000000..8839986e
--- /dev/null
+++ b/tests/test_modules/test_common/test_sac_net.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+import os
+import sys
+
+import pytest
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from openrl.envs.wrappers.extra_wrappers import AddStep
+from openrl.modules.common import SACNet as Net
+from openrl.runners.common import SACAgent as Agent
+
+env_wrappers = [AddStep]
+
+
+@pytest.fixture(scope="module", params=[""])
+def config(request):
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+def train(Agent, Net, env_name, env_num, total_time_steps, config):
+ cfg = config
+ env = make(env_name, env_num=env_num, cfg=cfg, env_wrappers=env_wrappers)
+
+ net = Net(
+ env,
+ cfg=cfg,
+ )
+ # initialize the trainer
+ agent = Agent(net)
+    # start training for the requested number of time steps
+ agent.train(total_time_steps=total_time_steps)
+ env.close()
+
+
+@pytest.mark.unittest
+def test_sac_net(config):
+ train(Agent, Net, "IdentityEnvcontinuous", 2, 100, config)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_modules/test_common/test_vdn_net.py b/tests/test_modules/test_common/test_vdn_net.py
new file mode 100644
index 00000000..29f1f58f
--- /dev/null
+++ b/tests/test_modules/test_common/test_vdn_net.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+import os
+import sys
+
+import pytest
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from openrl.envs.wrappers.mat_wrapper import MATWrapper
+from openrl.modules.common import VDNNet
+from openrl.runners.common import VDNAgent as Agent
+
+
+@pytest.fixture(scope="module", params=[""])
+def config(request):
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+@pytest.mark.unittest
+def test_vdn_net(config):
+ env_num = 2
+ env = make(
+ "simple_spread",
+ env_num=env_num,
+ asynchronous=True,
+ )
+ env = MATWrapper(env)
+
+ net = VDNNet(env, cfg=config)
+ # initialize the trainer
+ agent = Agent(net)
+ # start training
+ agent.train(total_time_steps=100)
+ env.close()
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_modules/test_networks/test_MAT_network.py b/tests/test_modules/test_networks/test_MAT_network.py
new file mode 100644
index 00000000..25b4fd00
--- /dev/null
+++ b/tests/test_modules/test_networks/test_MAT_network.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+import os
+import sys
+
+import numpy as np
+import pytest
+from gymnasium import spaces
+
+from openrl.configs.config import create_config_parser
+from openrl.modules.networks.MAT_network import MultiAgentTransformer
+
+
+@pytest.fixture(scope="module", params=[""])
+def config(request):
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+@pytest.mark.unittest
+def test_MAT_network(config):
+ net = MultiAgentTransformer(
+ config,
+ input_space=spaces.Discrete(2),
+ action_space=spaces.Discrete(2),
+ )
+ net.get_actor_para()
+ net.get_critic_para()
+
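+    # Dummy zero-valued inputs, used only to exercise the forward passes.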
+ obs = np.zeros([1, 2])
+ rnn_states = np.zeros(2)
+ masks = np.zeros(2)
+ action = np.zeros(1)
+ net.get_actions(obs=obs, masks=masks)
+ net.eval_actions(
+ obs=obs, rnn_states=rnn_states, action=action, masks=masks, action_masks=None
+ )
+ net.get_values(critic_obs=obs, masks=masks)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_modules/test_networks/test_attention.py b/tests/test_modules/test_networks/test_attention.py
new file mode 100644
index 00000000..599d0000
--- /dev/null
+++ b/tests/test_modules/test_networks/test_attention.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+import os
+import sys
+
+import pytest
+import torch
+
+from openrl.configs.config import create_config_parser
+from openrl.modules.networks.utils.attention import Encoder
+
+
+@pytest.fixture(
+ scope="module", params=["--use_average_pool True", "--use_average_pool False"]
+)
+def config(request):
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+@pytest.mark.unittest
+def test_attention(config):
+ for cat_self in [False, True]:
+ net = Encoder(cfg=config, split_shape=[[1, 1], [1, 1]], cat_self=cat_self)
+ net(torch.zeros((1, 1)))
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_modules/test_networks/test_policy_value_network_gpt.py b/tests/test_modules/test_networks/test_policy_value_network_gpt.py
new file mode 100644
index 00000000..66a6caaa
--- /dev/null
+++ b/tests/test_modules/test_networks/test_policy_value_network_gpt.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+import os
+import sys
+
+import numpy as np
+import pytest
+from gymnasium import spaces
+
+from openrl.configs.config import create_config_parser
+from openrl.modules.networks.policy_value_network_gpt import (
+ PolicyValueNetworkGPT as PolicyValueNetwork,
+)
+
+
+@pytest.fixture(scope="module", params=["--model_path test_gpt2"])
+def config(request):
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+@pytest.mark.unittest
+def test_gpt_network(config):
+ net = PolicyValueNetwork(
+ cfg=config,
+ input_space=spaces.Discrete(2),
+ action_space=spaces.Discrete(2),
+ )
+
+ net.get_actor_para()
+ net.get_critic_para()
+
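+    # Fake tokenized observation: zero token ids and attention mask of shape [1, 2].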
+ obs = {
+ "input_encoded_pt": np.zeros([1, 2]),
+ "input_attention_mask_pt": np.zeros([1, 2]),
+ }
+ rnn_states = np.zeros(2)
+ masks = np.zeros(2)
+ action = np.zeros(1)
+ net.get_actions(obs=obs, rnn_states=rnn_states, masks=masks)
+ net.eval_actions(
+ obs=obs, rnn_states=rnn_states, action=action, masks=masks, action_masks=None
+ )
+ net.get_values(obs=obs, rnn_states=rnn_states, masks=masks)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_rewards/test_nlp_reward.py b/tests/test_rewards/test_nlp_reward.py
new file mode 100644
index 00000000..739943ef
--- /dev/null
+++ b/tests/test_rewards/test_nlp_reward.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+import os
+import sys
+
+import numpy as np
+import pytest
+
+from openrl.buffers.normal_buffer import NormalReplayBuffer
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from openrl.rewards import RewardFactory
+
+
+@pytest.fixture(
+ scope="module",
+ params=[
+ "--reward_class.id NLPReward --reward_class.args"
+ " {'intent_model':'builtin_intent','ref_model':'builtin_ref','use_deepspeed':False}"
+ ],
+)
+def config(request):
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(request.param.split())
+ return cfg
+
+
+@pytest.mark.unittest
+def test_nlp_reward(config):
+ env = make("fake_dialog_data", env_num=1)
+ reward = RewardFactory.get_reward_class(config.reward_class, env)
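+    # Build a minimal batch of 32 fake transitions for a single reward step.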
+ data = {}
+ data["rewards"] = np.zeros(32)
+ env_info = {}
+ env_info["final_info"] = {
+ "prompt_texts": "hello",
+ "generated_texts": "hello",
+ "meta_infos": {"intent": [1]},
+ }
+ data["infos"] = [env_info] * 32
+ data["step"] = 0
+ data["actions"] = [0]
+ data["action_log_probs"] = np.zeros(32)
+ buffer = NormalReplayBuffer(
+ config,
+ num_agents=env.agent_num,
+ obs_space=env.observation_space,
+ act_space=env.action_space,
+ data_client=None,
+ episode_length=1,
+ )
+ data["buffer"] = buffer
+ reward.step_reward(data=data)
+ reward.batch_rewards(buffer=buffer)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_selfplay/test_selfplay_strategy.py b/tests/test_selfplay/test_selfplay_strategy.py
deleted file mode 100644
index 61b04052..00000000
--- a/tests/test_selfplay/test_selfplay_strategy.py
+++ /dev/null
@@ -1,91 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Copyright 2023 The OpenRL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-""""""
-import os
-import sys
-
-import pytest
-
-from openrl.selfplay.strategies import (
- NaiveSelfplayStrategy,
- OnlyLatestSelfplayStrategy,
- VarExistEnemySelfplayStrategy,
- WeightExistEnemySelfplayStrategy,
- WeightSelfplayStrategy,
- WinRateSelfplayStrategy,
-)
-
-
-@pytest.fixture(scope="module", params=[""])
-def config(request):
- from openrl.configs.config import create_config_parser
-
- cfg_parser = create_config_parser()
- cfg = cfg_parser.parse_args(request.param.split())
- return cfg
-
-
-@pytest.mark.unittest
-def test_naive_selfplay(config):
- strategy = NaiveSelfplayStrategy(config, 1, 1)
- strategy.get_plist()
- strategy.update_weight(enemy_loses=1)
- strategy.update_win_rate(dones=True, enemy_wins=1)
- strategy.push_newone()
-
-
-@pytest.mark.unittest
-def test_only_latest_selfplay(config):
- strategy = OnlyLatestSelfplayStrategy(config, 1, 1)
- strategy.get_plist()
- strategy.update_weight(enemy_loses=1)
- strategy.push_newone()
-
-
-@pytest.mark.unittest
-def test_weight_selfplay(config):
- strategy = WeightSelfplayStrategy(config, 1, 1)
- strategy.get_plist()
- strategy.update_weight(enemy_loses=1)
- strategy.push_newone()
-
-
-@pytest.mark.unittest
-def test_win_rate_selfplay(config):
- strategy = WinRateSelfplayStrategy(config, 1, 1)
- strategy.get_plist()
- strategy.update_weight(enemy_loses=1)
-
-
-@pytest.mark.unittest
-def test_var_exist_enemy_selfplay(config):
- strategy = VarExistEnemySelfplayStrategy(config, 1, 1)
- strategy.get_plist()
- strategy.update_weight(enemy_loses=1)
- strategy.push_newone()
-
-
-@pytest.mark.unittest
-def test_weight_exist_enemy_selfplay(config):
- strategy = WeightExistEnemySelfplayStrategy(config, 1, 1)
- strategy.get_plist()
- strategy.update_weight(enemy_loses=1)
- strategy.push_newone()
-
-
-if __name__ == "__main__":
- sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
diff --git a/tests/test_selfplay/test_train_selfplay.py b/tests/test_selfplay/test_train_selfplay.py
new file mode 100644
index 00000000..a67ea964
--- /dev/null
+++ b/tests/test_selfplay/test_train_selfplay.py
@@ -0,0 +1,134 @@
+import os
+import socket
+import sys
+
+import numpy as np
+import pytest
+import torch
+
+from openrl.configs.config import create_config_parser
+from openrl.envs.common import make
+from openrl.envs.wrappers import FlattenObservation
+from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner
+from openrl.modules.common import PPONet as Net
+from openrl.runners.common import PPOAgent as Agent
+from openrl.selfplay.wrappers.opponent_pool_wrapper import OpponentPoolWrapper
+from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper
+
+
+def find_free_port():
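+    # Bind to port 0 so the OS assigns an unused port for the selfplay API server.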
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ return s.getsockname()[1]
+
+
+@pytest.fixture(
+ scope="module",
+ params=[
+ {"port": find_free_port(), "strategy": "RandomOpponent"},
+ {"port": find_free_port(), "strategy": "LastOpponent"},
+ ],
+)
+def config(request):
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args(["--config", "./examples/selfplay/selfplay.yaml"])
+ cfg.selfplay_api.port = request.param["port"]
+ print("port:", request.param["port"])
+ for i, c in enumerate(cfg.callbacks):
+ if c["id"] == "SelfplayCallback":
+ c["args"][
+ "opponent_template"
+ ] = "./examples/selfplay/opponent_templates/tictactoe_opponent"
+ port = c["args"]["api_address"].split(":")[-1].split("/")[0]
+ c["args"]["api_address"] = c["args"]["api_address"].replace(
+ port, str(request.param["port"])
+ )
+ cfg.callbacks[i] = c
+ elif c["id"] == "SelfplayAPI":
+ c["args"]["sample_strategy"] = request.param["strategy"]
+ c["args"]["port"] = request.param["port"]
+ cfg.callbacks[i] = c
+
+ else:
+ pass
+
+ return cfg
+
+
+def train(cfg):
+ # Create environment
+ env_num = 2
+    render_mode = None
+ env = make(
+ "tictactoe_v3",
+        render_mode=render_mode,
+ env_num=env_num,
+ asynchronous=False,
+ opponent_wrappers=[RecordWinner, OpponentPoolWrapper],
+ env_wrappers=[FlattenObservation],
+ cfg=cfg,
+ )
+ # Create neural network
+
+ net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
+ # Create agent
+ agent = Agent(net)
+ # Begin training
+ agent.train(total_time_steps=10)
+ env.close()
+ agent.save("./selfplay_agent/")
+ return agent
+
+
+def evaluation():
+ print("Evaluation...")
+ env_num = 1
+ env = make(
+ "tictactoe_v3",
+ env_num=env_num,
+ asynchronous=True,
+ opponent_wrappers=[RandomOpponentWrapper],
+ env_wrappers=[FlattenObservation],
+ auto_reset=False,
+ )
+
+ cfg_parser = create_config_parser()
+ cfg = cfg_parser.parse_args([])
+ net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
+
+ agent = Agent(net)
+
+ agent.load("./selfplay_agent/")
+ agent.set_env(env)
+ env.reset(seed=0)
+
+ total_reward = 0.0
+ ep_num = 2
+ for ep_now in range(ep_num):
+ obs, info = env.reset()
+ done = False
+ step = 0
+
+ while not np.any(done):
+ # predict next action based on the observation
+ action, _ = agent.act(obs, info, deterministic=True)
+ obs, r, done, info = env.step(action)
+ step += 1
+
+ if np.any(done):
+ total_reward += np.mean(r) > 0
+ print(f"{ep_now}/{ep_num}: reward: {np.mean(r)}")
+ print(f"win rate: {total_reward/ep_num}")
+ env.close()
+ print("Evaluation finished.")
+
+
+@pytest.mark.unittest
+def test_train_selfplay(config):
+ train(config)
+ evaluation()
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))