Skip to content

Commit

Permalink
Merge pull request #27 from JunnYu/clean_roformer
Browse files Browse the repository at this point in the history
add roformer-sim的例子,并更新rotary的实现方式
  • Loading branch information
JunnYu authored Apr 2, 2022
2 parents 0eb4fc7 + b8be240 commit 967242a
Show file tree
Hide file tree
Showing 8 changed files with 376 additions and 85 deletions.
94 changes: 86 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,94 @@
RoFormer模型和RoFormer-V2模型

## 更新
- 2022/03/21 添加`roformer-v2`的权重, 注:必须使用本仓库的代码,不能使用transformers仓库的代码!!!
- **2022/04/02**
(1)修改RoFormerForCausalLM,支持`roformer-sim`并提供相关的例子,请见`examples/test_sim.py`
(2)修改`apply_rotary`实现方式,看起来更简单。
```python
def apply_rotary(x, sinusoidal_pos):
    """Apply a rotary position rotation to the last dimension of ``x``.

    ``sinusoidal_pos`` is unpacked as a ``(sin, cos)`` pair. The even-indexed
    and odd-indexed channels of ``x`` are treated as the two components of a
    2-D vector and rotated; the rotated even part and the rotated odd part are
    then concatenated (not re-interleaved) along the last dimension.
    """
    sin, cos = sinusoidal_pos
    even = x[..., 0::2]
    odd = x[..., 1::2]
    rotated_even = even * cos - odd * sin
    rotated_odd = odd * cos + even * sin
    return torch.cat((rotated_even, rotated_odd), dim=-1)
```
- **2022/03/21** 添加`roformer-v2`的权重, 注:必须使用本仓库的代码,不能使用transformers仓库的代码!!!

## v2版本安装

## 安装
```bash
# v2版本
pip install "roformer>=0.4.0"
# 如果安装不了,说明清华镜像源没有同步,过一会就可以安装。
# v1版本(代码已经加入到huggingface仓库,请使用新版本的transformers)
pip install -U transformers
```

## v1版本安装(代码已经加入到huggingface仓库)
transformers v4.7版本已经发布,可以直接安装使用
```bash
pip install -U transformers
## roformer-sim测试例子
```python
import torch
import numpy as np
from roformer import RoFormerForCausalLM, RoFormerConfig
from transformers import BertTokenizer

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Any of the following checkpoints can be used:
# junnyu/roformer_chinese_sim_char_small, junnyu/roformer_chinese_sim_char_base
# junnyu/roformer_chinese_sim_char_ft_small, junnyu/roformer_chinese_sim_char_ft_base
pretrained_model = "junnyu/roformer_chinese_sim_char_base"
tokenizer = BertTokenizer.from_pretrained(pretrained_model)
config = RoFormerConfig.from_pretrained(pretrained_model)
# NOTE(review): presumably required so the causal-LM model can be used with
# generate() below - confirm against the model implementation.
config.is_decoder = True
# Use the tokenizer's separator token as the end-of-sequence id.
config.eos_token_id = tokenizer.sep_token_id
# Select the "linear" pooler activation for this checkpoint.
config.pooler_activation = "linear"
model = RoFormerForCausalLM.from_pretrained(pretrained_model, config=config)
model.to(device)
# Put the model in evaluation mode before running inference.
model.eval()

def gen_synonyms(text, n=100, k=20):
    """Generate paraphrases of ``text`` and return the ``k`` most similar.

    Candidates are sampled ``n`` times with ``model.generate``; the source
    sentence and the unique candidates are then encoded together and the
    candidates are ranked by cosine similarity of their pooled outputs.

    Args:
        text: Source sentence.
        n: Number of sampling rounds. Duplicates and empty results are
            dropped, so fewer than ``n`` candidates may remain.
        k: Maximum number of candidates to return.

    Returns:
        A list of up to ``k`` candidate strings, most similar first.
    """
    # Spaces are stripped from the decoded output below, so the prompt has to
    # be removed in the same space-free form or it would never match when
    # ``text`` itself contains spaces.
    prompt = text.replace(" ", "")

    # Tokenize and move to the device once; the inputs are identical for
    # every sampling round.
    inputs1 = tokenizer(text, return_tensors="pt")
    inputs1.to(device)

    candidates = []
    for _ in range(n):
        generated = model.generate(
            **inputs1, top_p=0.95, do_sample=True, max_length=128
        )
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        # Drop the spaces inserted by decoding, then drop the echoed prompt.
        candidates.append(decoded.replace(" ", "").replace(prompt, ""))

    # dict.fromkeys de-duplicates while keeping first-seen order, which keeps
    # the ranking of tied candidates reproducible (a set would not).
    unique = [c for c in dict.fromkeys(candidates) if c and c not in (text, prompt)]

    # Row 0 is the source sentence; rows 1.. are the candidates.
    r = [text] + unique
    inputs2 = tokenizer(r, padding=True, return_tensors="pt")
    inputs2.to(device)
    with torch.no_grad():
        outputs = model(**inputs2)
        Z = outputs.pooler_output.cpu().numpy()

    # L2-normalize each row so the dot product is the cosine similarity.
    Z /= np.sqrt((Z ** 2).sum(axis=1, keepdims=True))
    # Negate so that argsort (ascending) yields the most similar first.
    argsort = np.dot(Z[1:], -Z[0]).argsort()

    return [r[i + 1] for i in argsort[:k]]

out = gen_synonyms("广州和深圳哪个好?")
print(out)
# Example output only: generation uses sampling (do_sample=True), so the
# actual result differs between runs.
# ['深圳和广州哪个好?',
# '广州和深圳哪个好',
# '深圳和广州哪个好',
# '深圳和广州哪个比较好。',
# '深圳和广州哪个最好?',
# '深圳和广州哪个比较好',
# '广州和深圳那个比较好',
# '深圳和广州哪个更好?',
# '深圳与广州哪个好',
# '深圳和广州,哪个比较好',
# '广州与深圳比较哪个好',
# '深圳和广州哪里比较好',
# '深圳还是广州比较好?',
# '广州和深圳哪个地方好一些?',
# '广州好还是深圳好?',
# '广州好还是深圳好呢?',
# '广州与深圳哪个地方好点?',
# '深圳好还是广州好',
# '广州好还是深圳好',
# '广州和深圳哪个城市好?']
```

## 模型权重对照表
Expand All @@ -39,6 +115,8 @@ pip install -U transformers
| [roformer_chinese_sim_char_ft_small](https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_small) | [chinese_roformer-sim-char-ft_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1G36x7YQF1b6nzW0OzyJS_Q) (download code:gty5) |




### 英文模型(使用electra的训练方法在openwebtext上训练的small模型(rotary value = True))
| huggingface.co |
| ---------------------------------- |
Expand Down Expand Up @@ -139,7 +217,7 @@ print(tf_outputs_sentence)
# tf: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。

```

## 手动权重转换
```bash
python convert_roformer_original_tf_checkpoint_to_pytorch.py \
Expand Down
63 changes: 63 additions & 0 deletions examples/test_sim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
import numpy as np
from roformer import RoFormerForCausalLM, RoFormerConfig
from transformers import BertTokenizer

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
pretrained_model = "junnyu/roformer_chinese_sim_char_base"
tokenizer = BertTokenizer.from_pretrained(pretrained_model)
config = RoFormerConfig.from_pretrained(pretrained_model)
# NOTE(review): presumably required so the causal-LM model can be used with
# generate() below - confirm against the model implementation.
config.is_decoder = True
# Use the tokenizer's separator token as the end-of-sequence id.
config.eos_token_id = tokenizer.sep_token_id
# Select the "linear" pooler activation for this checkpoint.
config.pooler_activation = "linear"
model = RoFormerForCausalLM.from_pretrained(pretrained_model, config=config)
model.to(device)
# Put the model in evaluation mode before running inference.
model.eval()

def gen_synonyms(text, n=100, k=20):
    """Generate paraphrases of ``text`` and return the ``k`` most similar.

    Candidates are sampled ``n`` times with ``model.generate``; the source
    sentence and the unique candidates are then encoded together and the
    candidates are ranked by cosine similarity of their pooled outputs.

    Args:
        text: Source sentence.
        n: Number of sampling rounds. Duplicates and empty results are
            dropped, so fewer than ``n`` candidates may remain.
        k: Maximum number of candidates to return.

    Returns:
        A list of up to ``k`` candidate strings, most similar first.
    """
    # Spaces are stripped from the decoded output below, so the prompt has to
    # be removed in the same space-free form or it would never match when
    # ``text`` itself contains spaces.
    prompt = text.replace(" ", "")

    # Tokenize and move to the device once; the inputs are identical for
    # every sampling round.
    inputs1 = tokenizer(text, return_tensors="pt")
    inputs1.to(device)

    candidates = []
    for _ in range(n):
        generated = model.generate(
            **inputs1, top_p=0.95, do_sample=True, max_length=128
        )
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        # Drop the spaces inserted by decoding, then drop the echoed prompt.
        candidates.append(decoded.replace(" ", "").replace(prompt, ""))

    # dict.fromkeys de-duplicates while keeping first-seen order, which keeps
    # the ranking of tied candidates reproducible (a set would not).
    unique = [c for c in dict.fromkeys(candidates) if c and c not in (text, prompt)]

    # Row 0 is the source sentence; rows 1.. are the candidates.
    r = [text] + unique
    inputs2 = tokenizer(r, padding=True, return_tensors="pt")
    inputs2.to(device)
    with torch.no_grad():
        outputs = model(**inputs2)
        Z = outputs.pooler_output.cpu().numpy()

    # L2-normalize each row so the dot product is the cosine similarity.
    Z /= np.sqrt((Z ** 2).sum(axis=1, keepdims=True))
    # Negate so that argsort (ascending) yields the most similar first.
    argsort = np.dot(Z[1:], -Z[0]).argsort()

    return [r[i + 1] for i in argsort[:k]]

out = gen_synonyms("广州和深圳哪个好?")
print(out)
# Example output only: generation uses sampling (do_sample=True), so the
# actual result differs between runs.
# ['深圳和广州哪个好?',
# '广州和深圳哪个好',
# '深圳和广州哪个好',
# '深圳和广州哪个比较好。',
# '深圳和广州哪个最好?',
# '深圳和广州哪个比较好',
# '广州和深圳那个比较好',
# '深圳和广州哪个更好?',
# '深圳与广州哪个好',
# '深圳和广州,哪个比较好',
# '广州与深圳比较哪个好',
# '深圳和广州哪里比较好',
# '深圳还是广州比较好?',
# '广州和深圳哪个地方好一些?',
# '广州好还是深圳好?',
# '广州好还是深圳好呢?',
# '广州与深圳哪个地方好点?',
# '深圳好还是广州好',
# '广州好还是深圳好',
# '广州和深圳哪个城市好?']
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name="roformer",
package_dir={"": "src"},
packages=find_packages("src"),
version="0.4.0",
version="0.4.1",
license="Apache 2.0",
description="roformer_pytorch",
author="Jun Yu",
Expand Down
10 changes: 10 additions & 0 deletions src/roformer/configuration_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,16 @@
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json",
"junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json",
"junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json",
"junnyu/roformer_chinese_sim_char_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_small/resolve/main/config.json",
"junnyu/roformer_chinese_sim_char_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_base/resolve/main/config.json",
"junnyu/roformer_chinese_sim_char_ft_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_base/resolve/main/config.json",
"junnyu/roformer_chinese_sim_char_ft_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_small/resolve/main/config.json",
"junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json",
"junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json",
"junnyu/roformer_base_wwm_cluecorpussmall": "https://huggingface.co/junnyu/roformer_base_wwm_cluecorpussmall/resolve/main/config.json",
"junnyu/roformer_v2_chinese_char_small": "https://huggingface.co/junnyu/roformer_v2_chinese_char_small/resolve/main/config.json",
"junnyu/roformer_v2_chinese_char_base": "https://huggingface.co/junnyu/roformer_v2_chinese_char_base/resolve/main/config.json",
"junnyu/roformer_v2_chinese_char_large": "https://huggingface.co/junnyu/roformer_v2_chinese_char_large/resolve/main/config.json",
# See all RoFormer models at https://huggingface.co/models?filter=roformer
}

Expand Down Expand Up @@ -107,6 +115,7 @@ def __init__(
use_cache=True,
use_bias=True,
norm_type="layer_norm",
pooler_activation="tanh",
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
Expand All @@ -128,3 +137,4 @@ def __init__(
self.use_cache = use_cache
self.use_bias = use_bias
self.norm_type = norm_type
self.pooler_activation = pooler_activation
Loading

0 comments on commit 967242a

Please sign in to comment.