Skip to content

Commit

Permalink
add roformer-v2
Browse files Browse the repository at this point in the history
  • Loading branch information
JunnYu committed Mar 21, 2022
1 parent 88bc606 commit 2ffcf6a
Show file tree
Hide file tree
Showing 9 changed files with 254 additions and 77 deletions.
106 changes: 72 additions & 34 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
# PyTorch RoFormer
原版Tensorflow权重(https://github.com/ZhuiyiTechnology/roformer)
- [chinese_roformer_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1fiss862YsGCwf2HvU_Jm-g) (提取码:xy9x)
- [chinese_roformer_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1iIXgZHHCgrYGXVRRSSCVPg) (提取码:gy97)
- [chinese_roformer-char_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1Q1pq8F4Fsl6bTipUAkqeDQ) (提取码:bt94)
- [chinese_roformer-char_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1cc281-M0Rsjlwws5phqzbQ)(提取码:a44c)
- [chinese_roformer-sim-char_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1f1FB288nv1a6jYjsNCordg)(提取码:2cgz)
- [chinese_roformer-sim-char_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1r0eJ7shGwQ0RzV9BTFFW4g)(提取码:h68q)

已经转化为PyTorch权重
- [chinese_roformer_small.zip](https://pan.baidu.com/s/1Cx7lhtojTyRF61IKHWXEHw) (提取码:8znw)
- [chinese_roformer_base.zip](https://pan.baidu.com/s/10W5BYDQSeLyajTWjexZeoQ) (提取码:bimr)
- [chinese_roformer_char_base.zip](https://pan.baidu.com/s/18bgJ1t_1ke0BXq_Xg02qSQ) (提取码:oqb5)

## 安装(代码已经加入到huggingface仓库)
# PyTorch RoFormer & RoFormer-V2
RoFormer模型和RoFormer-V2模型

## 更新
- 2022/03/21 添加`roformer-v2`的权重, 注:必须使用本仓库的代码,不能使用transformers仓库的代码!!!

## 安装(代码已经加入到huggingface仓库),V2版本需要使用本仓库的代码
transformers v4.7版本已经发布,可以直接安装使用
```bash
pip install -U transformers
```

## 模型权重对照表

### 中文模型
### 中文模型 roformer-v2
| huggingface.co | bert4keras |
| ---------------------------------- | ------------------------------------------------ |
| [roformer_v2_chinese_char_small](https://huggingface.co/junnyu/roformer_v2_chinese_char_small) | [chinese_roformer-v2-char_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1huUrC9P60Afggo8AfiUcmA) (download code:ttn4) |
| [roformer_v2_chinese_char_base](https://huggingface.co/junnyu/roformer_v2_chinese_char_base) | [chinese_roformer-v2-char_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1qcnN4LVKVe0-mnHlkN3-6Q) (download code:pfoh) |
| [roformer_v2_chinese_char_large](https://huggingface.co/junnyu/roformer_v2_chinese_char_large) | [chinese_roformer-v2-char_L-24_H-1024_A-16.zip](https://pan.baidu.com/s/1QiJWSZrGxn8vek-8myvL6w) (download code:npfv) |


### 中文模型 roformer-v1
| huggingface.co | bert4keras |
| ---------------------------------- | ------------------------------------------------ |
| [roformer_chinese_base](https://huggingface.co/junnyu/roformer_chinese_base) | [chinese_roformer_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1fiss862YsGCwf2HvU_Jm-g) (download code:xy9x) |
Expand All @@ -38,34 +39,69 @@ pip install -U transformers
|[roformer_small_generator](https://huggingface.co/junnyu/roformer_small_generator)|
|[roformer_small_discriminator](https://huggingface.co/junnyu/roformer_small_discriminator)|


## 使用
## roformer-v2 MLM测试
```python
import torch
from transformers import RoFormerModel, RoFormerTokenizer, TFRoFormerModel
tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
pt_model = RoFormerModel.from_pretrained("junnyu/roformer_chinese_base")
tf_model = TFRoFormerModel.from_pretrained("junnyu/roformer_chinese_base",
from_pt=True)
text = "这里基本保留了唐宋遗留下来的坊巷格局和大量明清古建筑,其中各级文保单位29处,被誉为“里坊制度的活化石”“明清建筑博物馆”!"
import tensorflow as tf
from transformers import BertTokenizer
from roformer import RoFormerForMaskedLM, TFRoFormerForMaskedLM

text = "今天[MASK]很好,我[MASK]去公园玩。"
tokenizer = BertTokenizer.from_pretrained("junnyu/roformer_v2_chinese_char_base")
pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_v2_chinese_char_base")
tf_model = TFRoFormerForMaskedLM.from_pretrained(
"junnyu/roformer_v2_chinese_char_base", from_pt=True
)
pt_inputs = tokenizer(text, return_tensors="pt")
tf_inputs = tokenizer(text, return_tensors="tf")
# pytorch
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).last_hidden_state
print(pt_outputs.shape)
tf_outputs = tf_model(**tf_inputs, training=False).last_hidden_state
print(tf_outputs.shape)
pt_outputs = pt_model(**pt_inputs).logits[0]
pt_outputs_sentence = "pytorch: "
for i, id in enumerate(tokenizer.encode(text)):
if id == tokenizer.mask_token_id:
tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1])
pt_outputs_sentence += "[" + "||".join(tokens) + "]"
else:
pt_outputs_sentence += "".join(
tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
)
print(pt_outputs_sentence)
# tf
tf_outputs = tf_model(**tf_inputs, training=False).logits[0]
tf_outputs_sentence = "tf: "
for i, id in enumerate(tokenizer.encode(text)):
if id == tokenizer.mask_token_id:
tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1])
tf_outputs_sentence += "[" + "||".join(tokens) + "]"
else:
tf_outputs_sentence += "".join(
tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
)
print(tf_outputs_sentence)
# small
# pytorch: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。
# tf: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。
# base
# pytorch: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。
# tf: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。
# large
# pytorch: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。
# tf: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。
```
## MLM测试

## roformer-v1 MLM测试
```python
import torch
import tensorflow as tf
from transformers import RoFormerForMaskedLM, RoFormerTokenizer, TFRoFormerForMaskedLM

text = "今天[MASK]很好,我[MASK]去公园玩。"
tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
tf_model = TFRoFormerForMaskedLM.from_pretrained(
"junnyu/roformer_chinese_base", from_pt=True)
"junnyu/roformer_chinese_base", from_pt=True
)
pt_inputs = tokenizer(text, return_tensors="pt")
tf_inputs = tokenizer(text, return_tensors="tf")
# pytorch
Expand All @@ -78,22 +114,24 @@ for i, id in enumerate(tokenizer.encode(text)):
pt_outputs_sentence += "[" + "||".join(tokens) + "]"
else:
pt_outputs_sentence += "".join(
tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
)
print(pt_outputs_sentence)
# tf
tf_outputs = tf_model(**tf_inputs, training=False).logits[0]
tf_outputs_sentence = "tf: "
for i, id in enumerate(tokenizer.encode(text)):
if id == tokenizer.mask_token_id:
tokens = tokenizer.convert_ids_to_tokens(
tf.math.top_k(tf_outputs[i], k=5)[1])
tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1])
tf_outputs_sentence += "[" + "||".join(tokens) + "]"
else:
tf_outputs_sentence += "".join(
tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
)
print(tf_outputs_sentence)
# pytorch: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。
# tf: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。

```

## 手动权重转换
Expand Down
23 changes: 0 additions & 23 deletions examples/test_mlm.py

This file was deleted.

39 changes: 39 additions & 0 deletions examples/test_mlm_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch
import tensorflow as tf
from transformers import RoFormerForMaskedLM, RoFormerTokenizer, TFRoFormerForMaskedLM

text = "今天[MASK]很好,我[MASK]去公园玩。"
tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
tf_model = TFRoFormerForMaskedLM.from_pretrained(
"junnyu/roformer_chinese_base", from_pt=True
)
pt_inputs = tokenizer(text, return_tensors="pt")
tf_inputs = tokenizer(text, return_tensors="tf")
# pytorch
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).logits[0]
pt_outputs_sentence = "pytorch: "
for i, id in enumerate(tokenizer.encode(text)):
if id == tokenizer.mask_token_id:
tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1])
pt_outputs_sentence += "[" + "||".join(tokens) + "]"
else:
pt_outputs_sentence += "".join(
tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
)
print(pt_outputs_sentence)
# tf
tf_outputs = tf_model(**tf_inputs, training=False).logits[0]
tf_outputs_sentence = "tf: "
for i, id in enumerate(tokenizer.encode(text)):
if id == tokenizer.mask_token_id:
tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1])
tf_outputs_sentence += "[" + "||".join(tokens) + "]"
else:
tf_outputs_sentence += "".join(
tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
)
print(tf_outputs_sentence)
# pytorch: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。
# tf: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。
47 changes: 47 additions & 0 deletions examples/test_mlm_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
import tensorflow as tf
from transformers import BertTokenizer
from roformer import RoFormerForMaskedLM, TFRoFormerForMaskedLM

text = "今天[MASK]很好,我[MASK]去公园玩。"
tokenizer = BertTokenizer.from_pretrained("junnyu/roformer_v2_chinese_char_base")
pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_v2_chinese_char_base")
tf_model = TFRoFormerForMaskedLM.from_pretrained(
"junnyu/roformer_v2_chinese_char_base", from_pt=True
)
pt_inputs = tokenizer(text, return_tensors="pt")
tf_inputs = tokenizer(text, return_tensors="tf")
# pytorch
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).logits[0]
pt_outputs_sentence = "pytorch: "
for i, id in enumerate(tokenizer.encode(text)):
if id == tokenizer.mask_token_id:
tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1])
pt_outputs_sentence += "[" + "||".join(tokens) + "]"
else:
pt_outputs_sentence += "".join(
tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
)
print(pt_outputs_sentence)
# tf
tf_outputs = tf_model(**tf_inputs, training=False).logits[0]
tf_outputs_sentence = "tf: "
for i, id in enumerate(tokenizer.encode(text)):
if id == tokenizer.mask_token_id:
tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1])
tf_outputs_sentence += "[" + "||".join(tokens) + "]"
else:
tf_outputs_sentence += "".join(
tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
)
print(tf_outputs_sentence)
# small
# pytorch: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。
# tf: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。
# base
# pytorch: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。
# tf: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。
# large
# pytorch: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。
# tf: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name="roformer",
package_dir={"": "src"},
packages=find_packages("src"),
version="0.3.1",
version="0.4.0",
license="Apache 2.0",
description="roformer_pytorch",
author="Jun Yu",
Expand Down
4 changes: 4 additions & 0 deletions src/roformer/configuration_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def __init__(
pad_token_id=0,
rotary_value=False,
use_cache=True,
use_bias=True,
norm_type="layer_norm",
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
Expand All @@ -124,3 +126,5 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.rotary_value = rotary_value
self.use_cache = use_cache
self.use_bias = use_bias
self.norm_type = norm_type
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,17 @@ def convert_tf_checkpoint_to_pytorch(
# Load weights from tf checkpoint
load_tf_weights_in_roformer(model, config, tf_checkpoint_path)

# ignore 不保存roformer.encoder.embed_positions.weight
_keys_to_ignore_on_save = ["roformer.encoder.embed_positions.weight"]
state_dict = model.state_dict()
for ignore_key in _keys_to_ignore_on_save:
if ignore_key in state_dict.keys():
del state_dict[ignore_key]

# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
torch.save(
model.state_dict(), pytorch_dump_path, _use_new_zipfile_serialization=False
state_dict, pytorch_dump_path, _use_new_zipfile_serialization=False
)


Expand Down
Loading

0 comments on commit 2ffcf6a

Please sign in to comment.