From b8be240c28cf53734a1fca177f8e9c48ca400c22 Mon Sep 17 00:00:00 2001 From: YuuuW <573009727@qq.com> Date: Sat, 2 Apr 2022 23:29:14 +0800 Subject: [PATCH] =?UTF-8?q?add=20roformer-sim=E7=9A=84=E4=BE=8B=E5=AD=90?= =?UTF-8?q?=EF=BC=8C=E5=B9=B6=E6=9B=B4=E6=96=B0rotary=E7=9A=84=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 94 ++++++++- examples/test_sim.py | 63 ++++++ setup.py | 2 +- src/roformer/configuration_roformer.py | 10 + src/roformer/modeling_roformer.py | 223 ++++++++++++++------- src/roformer/modeling_tf_roformer.py | 12 +- src/roformer/tokenization_roformer.py | 32 ++- src/roformer/tokenization_roformer_fast.py | 25 +++ 8 files changed, 376 insertions(+), 85 deletions(-) create mode 100644 examples/test_sim.py diff --git a/README.md b/README.md index dde96b5..dbc55c3 100644 --- a/README.md +++ b/README.md @@ -2,18 +2,94 @@ RoFormer模型和RoFormer-V2模型 ## 更新 -- 2022/03/21 添加`roformer-v2`的权重, 注:必须使用本仓库的代码,不能使用transformers仓库的代码!!! +- **2022/04/02** +(1)修改RoFormerForCausalLM,支持`roformer-sim`并提供相关的例子,请见`examples/test_sim.py`。 +(2)修改`apply_rotary`实现方式,看起来更简单。 +```python +def apply_rotary(x, sinusoidal_pos): + sin, cos = sinusoidal_pos + x1, x2 = x[..., 0::2], x[..., 1::2] + return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) +``` +- **2022/03/21** 添加`roformer-v2`的权重, 注:必须使用本仓库的代码,不能使用transformers仓库的代码!!! -## v2版本安装 + +## 安装 ```bash +# v2版本 pip install roformer>=0.4.0 -# 如果安装不了,说明清华镜像源没有同步,过一会就可以安装。 +# v1版本(代码已经加入到huggingface仓库,请使用新版本的transformers) +pip install -U transformers ``` -## v1版本安装(代码已经加入到huggingface仓库) -transformers v4.7版本已经发布,可以直接安装使用 -```bash -pip install -U transformers +## roformer-sim测试例子 +```python +import torch +import numpy as np +from roformer import RoFormerForCausalLM, RoFormerConfig +from transformers import BertTokenizer + +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') +# 可选以下几个。 +# junnyu/roformer_chinese_sim_char_small, junnyu/roformer_chinese_sim_char_base +# junnyu/roformer_chinese_sim_char_ft_small, roformer_chinese_sim_char_ft_base +pretrained_model = "junnyu/roformer_chinese_sim_char_base" +tokenizer = BertTokenizer.from_pretrained(pretrained_model) +config = RoFormerConfig.from_pretrained(pretrained_model) +config.is_decoder = True +config.eos_token_id = tokenizer.sep_token_id +config.pooler_activation = "linear" +model = RoFormerForCausalLM.from_pretrained(pretrained_model, config=config) +model.to(device) +model.eval() + +def gen_synonyms(text, n=100, k=20): + ''''含义: 产生sent的n个相似句,然后返回最相似的k个。 + 做法:用seq2seq生成,并用encoder算相似度并排序。 + ''' + # 寻找所有相似的句子 + r = [] + inputs1 = tokenizer(text, return_tensors="pt") + for _ in range(n): + inputs1.to(device) + output = tokenizer.batch_decode(model.generate(**inputs1, top_p=0.95, do_sample=True, max_length=128), skip_special_tokens=True)[0].replace(" ","").replace(text, "") # 去除空格,去除原始text文本。 + r.append(output) + + # 对相似的句子进行排序 + r = [i for i in set(r) if i != text and len(i) > 0] + r = [text] + r + inputs2 = tokenizer(r, padding=True, return_tensors="pt") + with torch.no_grad(): + inputs2.to(device) + outputs = model(**inputs2) + Z = outputs.pooler_output.cpu().numpy() + Z /= (Z**2).sum(axis=1, keepdims=True)**0.5 + argsort = np.dot(Z[1:], -Z[0]).argsort() + + return [r[i + 1] for i in argsort[:k]] + +out = gen_synonyms("广州和深圳哪个好?") +print(out) +# ['深圳和广州哪个好?', +# '广州和深圳哪个好', +# '深圳和广州哪个好', +# '深圳和广州哪个比较好。', +# '深圳和广州哪个最好?', +# '深圳和广州哪个比较好', +# '广州和深圳那个比较好', +# '深圳和广州哪个更好?', +# '深圳与广州哪个好', +# '深圳和广州,哪个比较好', +# '广州与深圳比较哪个好', +# '深圳和广州哪里比较好', +# '深圳还是广州比较好?', +# '广州和深圳哪个地方好一些?', +# '广州好还是深圳好?', +# '广州好还是深圳好呢?', +# '广州与深圳哪个地方好点?', +# '深圳好还是广州好', +# '广州好还是深圳好', +# '广州和深圳哪个城市好?'] ``` ## 模型权重对照表 @@ -39,6 +115,8 @@ pip install -U transformers | [roformer_chinese_sim_char_ft_small](https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_small) | [chinese_roformer-sim-char-ft_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1G36x7YQF1b6nzW0OzyJS_Q) (download code:gty5) | + + ### 英文模型(使用electra的训练方法在openwebtext上训练的small模型(rotary value = True)) | huggingface.co | | ---------------------------------- | @@ -139,7 +217,7 @@ print(tf_outputs_sentence) # tf: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。 ``` - + ## 手动权重转换 ```bash python convert_roformer_original_tf_checkpoint_to_pytorch.py \ diff --git a/examples/test_sim.py b/examples/test_sim.py new file mode 100644 index 0000000..5afecce --- /dev/null +++ b/examples/test_sim.py @@ -0,0 +1,63 @@ +import torch +import numpy as np +from roformer import RoFormerForCausalLM, RoFormerConfig +from transformers import BertTokenizer + +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') +pretrained_model = "junnyu/roformer_chinese_sim_char_base" +tokenizer = BertTokenizer.from_pretrained(pretrained_model) +config = RoFormerConfig.from_pretrained(pretrained_model) +config.is_decoder = True +config.eos_token_id = tokenizer.sep_token_id +config.pooler_activation = "linear" +model = RoFormerForCausalLM.from_pretrained(pretrained_model, config=config) +model.to(device) +model.eval() + +def gen_synonyms(text, n=100, k=20): + ''''含义: 产生sent的n个相似句,然后返回最相似的k个。 + 做法:用seq2seq生成,并用encoder算相似度并排序。 + ''' + # 寻找所有相似的句子 + r = [] + inputs1 = tokenizer(text, return_tensors="pt") + for _ in range(n): + inputs1.to(device) + output = tokenizer.batch_decode(model.generate(**inputs1, top_p=0.95, do_sample=True, max_length=128), skip_special_tokens=True)[0].replace(" ","").replace(text, "") # 去除空格,去除原始text文本。 + r.append(output) + + # 对相似的句子进行排序 + r = [i for i in set(r) if i != text and len(i) > 0] + r = [text] + r + inputs2 = tokenizer(r, padding=True, return_tensors="pt") + with torch.no_grad(): + inputs2.to(device) + outputs = model(**inputs2) + Z = outputs.pooler_output.cpu().numpy() + Z /= (Z**2).sum(axis=1, keepdims=True)**0.5 + argsort = np.dot(Z[1:], -Z[0]).argsort() + + return [r[i + 1] for i in argsort[:k]] + +out = gen_synonyms("广州和深圳哪个好?") +print(out) +# ['深圳和广州哪个好?', +# '广州和深圳哪个好', +# '深圳和广州哪个好', +# '深圳和广州哪个比较好。', +# '深圳和广州哪个最好?', +# '深圳和广州哪个比较好', +# '广州和深圳那个比较好', +# '深圳和广州哪个更好?', +# '深圳与广州哪个好', +# '深圳和广州,哪个比较好', +# '广州与深圳比较哪个好', +# '深圳和广州哪里比较好', +# '深圳还是广州比较好?', +# '广州和深圳哪个地方好一些?', +# '广州好还是深圳好?', +# '广州好还是深圳好呢?', +# '广州与深圳哪个地方好点?', +# '深圳好还是广州好', +# '广州好还是深圳好', +# '广州和深圳哪个城市好?'] diff --git a/setup.py b/setup.py index 8b853d4..fa10744 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ name="roformer", package_dir={"": "src"}, packages=find_packages("src"), - version="0.4.0", + version="0.4.1", license="Apache 2.0", description="roformer_pytorch", author="Jun Yu", diff --git a/src/roformer/configuration_roformer.py b/src/roformer/configuration_roformer.py index 46b1b3a..6449488 100644 --- a/src/roformer/configuration_roformer.py +++ b/src/roformer/configuration_roformer.py @@ -24,8 +24,16 @@ "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json", "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json", "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json", + "junnyu/roformer_chinese_sim_char_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_small/resolve/main/config.json", + "junnyu/roformer_chinese_sim_char_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_base/resolve/main/config.json", + "junnyu/roformer_chinese_sim_char_ft_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_base/resolve/main/config.json", + "junnyu/roformer_chinese_sim_char_ft_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_small/resolve/main/config.json", "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json", "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json", + "junnyu/roformer_base_wwm_cluecorpussmall": "https://huggingface.co/junnyu/roformer_base_wwm_cluecorpussmall/resolve/main/config.json", + "junnyu/roformer_v2_chinese_char_small": "https://huggingface.co/junnyu/roformer_v2_chinese_char_small/resolve/main/config.json", + "junnyu/roformer_v2_chinese_char_base": "https://huggingface.co/junnyu/roformer_v2_chinese_char_base/resolve/main/config.json", + "junnyu/roformer_v2_chinese_char_large": "https://huggingface.co/junnyu/roformer_v2_chinese_char_large/resolve/main/config.json", # See all RoFormer models at https://huggingface.co/models?filter=roformer } @@ -107,6 +115,7 @@ def __init__( use_cache=True, use_bias=True, norm_type="layer_norm", + pooler_activation="tanh", **kwargs ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -128,3 +137,4 @@ def __init__( self.use_cache = use_cache self.use_bias = use_bias self.norm_type = norm_type + self.pooler_activation = pooler_activation diff --git a/src/roformer/modeling_roformer.py b/src/roformer/modeling_roformer.py index 67699e6..8adef11 100644 --- a/src/roformer/modeling_roformer.py +++ b/src/roformer/modeling_roformer.py @@ -31,11 +31,11 @@ add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings, + ModelOutput ) from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, MaskedLMOutput, MultipleChoiceModelOutput, QuestionAnsweringModelOutput, @@ -49,6 +49,9 @@ find_pruneable_heads_and_indices, prune_linear_layer, ) +from dataclasses import dataclass +from typing import Optional, Tuple + from transformers.utils import logging from .configuration_roformer import RoFormerConfig @@ -64,8 +67,16 @@ "junnyu/roformer_chinese_base", "junnyu/roformer_chinese_char_small", "junnyu/roformer_chinese_char_base", + "junnyu/roformer_chinese_sim_char_small", + "junnyu/roformer_chinese_sim_char_base", + "junnyu/roformer_chinese_sim_char_ft_small", + "junnyu/roformer_chinese_sim_char_ft_base", "junnyu/roformer_small_discriminator", - "junnyu/roformer_small_generator" + "junnyu/roformer_small_generator", + "junnyu/roformer_base_wwm_cluecorpussmall", + "junnyu/roformer_v2_chinese_char_small", + "junnyu/roformer_v2_chinese_char_base", + "junnyu/roformer_v2_chinese_char_large", # See all RoFormer models at https://huggingface.co/models?filter=roformer ] @@ -90,6 +101,18 @@ def initializer(tensor, gain=1.0): std = 1.13684723 / hidden_size**0.5 * gain return nn.init.trunc_normal_(tensor, std=std) + +@dataclass +class CausalLMOutputWithPoolingAndCrossAttentions(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->RoFormer class RoFormerSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" @@ -121,9 +144,8 @@ def _init_weight(out: nn.Parameter): return out @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): + def forward(self, seq_len: int, past_key_values_length: int = 0): """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] positions = torch.arange( past_key_values_length, past_key_values_length + seq_len, @@ -300,6 +322,8 @@ def forward( ): mixed_query_layer = self.query(hidden_states) query_layer = self.transpose_for_scores(mixed_query_layer) + # rotary query + query_layer = self.apply_rotary(query_layer, sinusoidal_pos) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. @@ -317,24 +341,20 @@ def forward( elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) + # rotary key_layer & value_layer + key_layer = self.apply_rotary(key_layer, sinusoidal_pos) + if self.rotary_value: + value_layer = self.apply_rotary(value_layer, sinusoidal_pos) key_layer = torch.cat([past_key_value[0], key_layer], dim=2) value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - if sinusoidal_pos is not None: - if self.rotary_value: - ( - query_layer, - key_layer, - value_layer, - ) = self.apply_rotary_position_embeddings( - sinusoidal_pos, query_layer, key_layer, value_layer - ) - else: - query_layer, key_layer = self.apply_rotary_position_embeddings( - sinusoidal_pos, query_layer, key_layer - ) + # rotary key_layer & value_layer + key_layer = self.apply_rotary(key_layer, sinusoidal_pos) + if self.rotary_value: + value_layer = self.apply_rotary(value_layer, sinusoidal_pos) + if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention @@ -379,36 +399,10 @@ def forward( return outputs @staticmethod - def apply_rotary_position_embeddings( - sinusoidal_pos, query_layer, key_layer, value_layer=None - ): - # https://kexue.fm/archives/8265 - # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2] - # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2] - sin, cos = sinusoidal_pos.chunk(2, dim=-1) - # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] - sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos) - # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1] - cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos) - # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2] - rotate_half_query_layer = torch.stack( - [-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1 - ).reshape_as(query_layer) - query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos - # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2] - rotate_half_key_layer = torch.stack( - [-key_layer[..., 1::2], key_layer[..., ::2]], dim=-1 - ).reshape_as(key_layer) - key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos - if value_layer is not None: - # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2] - rotate_half_value_layer = torch.stack( - [-value_layer[..., 1::2], value_layer[..., ::2]], dim=-1 - ).reshape_as(value_layer) - value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos - return query_layer, key_layer, value_layer - return query_layer, key_layer - + def apply_rotary(x, sinusoidal_pos): + sin, cos = sinusoidal_pos + x1, x2 = x[..., 0::2], x[..., 1::2] + return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RoFormer class RoFormerSelfOutput(nn.Module): @@ -650,10 +644,13 @@ def forward( () if output_attentions and self.config.add_cross_attention else None ) - # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head] - sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1])[ + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + # [sequence_length, embed_size_per_head] -> sin & cos [batch_size, num_heads, sequence_length, embed_size_per_head // 2] + sinusoidal_pos = self.embed_positions(hidden_states.shape[1], past_key_values_length)[ None, None, :, : - ] + ].chunk(2, dim=-1) next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): @@ -791,7 +788,7 @@ class RoFormerPooler(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() + self.activation = ACT2FN[config.pooler_activation] def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding @@ -913,7 +910,7 @@ class RoFormerModel(RoFormerPreTrainedModel): input to the forward pass. """ - def __init__(self, config, add_pooling_layer=True): + def __init__(self, config, add_pooling_layer=False): super().__init__(config) self.config = config self.add_pooling_layer = add_pooling_layer @@ -1034,7 +1031,7 @@ def forward( # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( - attention_mask, input_shape, device + attention_mask, input_shape, device, past_key_values_length ) # If a 2D or 3D attention mask is provided for the cross-attention @@ -1096,6 +1093,31 @@ def forward( cross_attentions=encoder_outputs.cross_attentions, ) + # 添加了个past_key_values_length + def get_extended_attention_mask(self, attention_mask, input_shape, device, past_key_values_length): + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + if self.config.is_decoder and past_key_values_length > 0: # 第一次编码的时候不需要使用decoder mask,之后的需要decoder mask。 + extended_attention_mask = self.create_extended_attention_mask_for_decoder( + input_shape, attention_mask, device + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + @add_start_docstrings( """RoFormer Model with a `language modeling` head on top. """, @@ -1111,7 +1133,7 @@ def __init__(self, config): "bi-directional self-attention." ) - self.roformer = RoFormerModel(config, add_pooling_layer=False) + self.roformer = RoFormerModel(config) self.cls = RoFormerOnlyMLMHead(config) # Initialize weights and apply final processing @@ -1233,7 +1255,7 @@ def __init__(self, config): "If you want to use `RoFormerForCausalLM` as a standalone, add `is_decoder=True.`" ) - self.roformer = RoFormerModel(config, add_pooling_layer=False) + self.roformer = RoFormerModel(config, add_pooling_layer=True) self.cls = RoFormerOnlyMLMHead(config) # Initialize weights and apply final processing @@ -1249,7 +1271,7 @@ def set_output_embeddings(self, new_embeddings): ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length") ) @replace_return_docstrings( - output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC + output_type=CausalLMOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC ) def forward( self, @@ -1291,15 +1313,68 @@ def forward( decoding (see :obj:`past_key_values`). Returns: Example:: - >>> from transformers import RoFormerTokenizer, RoFormerForCausalLM, RoFormerConfig >>> import torch - >>> tokenizer = RoFormerTokenizer.from_pretrained('junnyu/roformer_chinese_base') - >>> config = RoFormerConfig.from_pretrained("junnyu/roformer_chinese_base") + >>> import numpy as np + >>> from roformer import RoFormerForCausalLM, RoFormerConfig + >>> from transformers import BertTokenizer + >>> device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + >>> pretrained_model = "junnyu/roformer_chinese_sim_char_base" + >>> tokenizer = BertTokenizer.from_pretrained(pretrained_model) + >>> config = RoFormerConfig.from_pretrained(pretrained_model) >>> config.is_decoder = True - >>> model = RoFormerForCausalLM.from_pretrained('junnyu/roformer_chinese_base', config=config) - >>> inputs = tokenizer("今天天气非常好。", return_tensors="pt") - >>> outputs = model(**inputs) - >>> prediction_logits = outputs.logits + >>> config.eos_token_id = tokenizer.sep_token_id + >>> config.pooler_activation = "linear" + >>> model = RoFormerForCausalLM.from_pretrained(pretrained_model, config=config) + >>> model.to(device) + >>> model.eval() + + >>> def gen_synonyms(text, n=100, k=20): + >>> ''''含义: 产生sent的n个相似句,然后返回最相似的k个。 + >>> 做法:用seq2seq生成,并用encoder算相似度并排序。 + >>> ''' + >>> # 寻找所有相似的句子 + >>> r = [] + >>> inputs1 = tokenizer(text, return_tensors="pt") + >>> for _ in range(n): + >>> inputs1.to(device) + >>> output = tokenizer.batch_decode(model.generate(**inputs1, top_p=0.95, do_sample=True, max_length=128), skip_special_tokens=True)[0].replace(" ","").replace(text, "") # 去除空格,去除原始text文本。 + >>> r.append(output) + + >>> # 对相似的句子进行排序 + >>> r = [i for i in set(r) if i != text and len(i) > 0] + >>> r = [text] + r + >>> inputs2 = tokenizer(r, padding=True, return_tensors="pt") + >>> with torch.no_grad(): + >>> inputs2.to(device) + >>> outputs = model(**inputs2) + >>> Z = outputs.pooler_output.cpu().numpy() + >>> Z /= (Z**2).sum(axis=1, keepdims=True)**0.5 + >>> argsort = np.dot(Z[1:], -Z[0]).argsort() + + >>> return [r[i + 1] for i in argsort[:k]] + + >>> out = gen_synonyms("广州和深圳哪个好?") + >>> print(out) + >>> # ['深圳和广州哪个好?', + >>> # '广州和深圳哪个好', + >>> # '深圳和广州哪个好', + >>> # '深圳和广州哪个比较好。', + >>> # '深圳和广州哪个最好?', + >>> # '深圳和广州哪个比较好', + >>> # '广州和深圳那个比较好', + >>> # '深圳和广州哪个更好?', + >>> # '深圳与广州哪个好', + >>> # '深圳和广州,哪个比较好', + >>> # '广州与深圳比较哪个好', + >>> # '深圳和广州哪里比较好', + >>> # '深圳还是广州比较好?', + >>> # '广州和深圳哪个地方好一些?', + >>> # '广州好还是深圳好?', + >>> # '广州好还是深圳好呢?', + >>> # '广州与深圳哪个地方好点?', + >>> # '深圳好还是广州好', + >>> # '广州好还是深圳好', + >>> # '广州和深圳哪个城市好?'] """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict @@ -1335,12 +1410,13 @@ def forward( ) if not return_dict: - output = (prediction_scores,) + outputs[1:] + output = (prediction_scores,) + outputs[1:] # with pooler return ((lm_loss,) + output) if lm_loss is not None else output - return CausalLMOutputWithCrossAttentions( + return CausalLMOutputWithPoolingAndCrossAttentions( loss=lm_loss, logits=prediction_scores, + pooler_output=outputs.pooler_output, # with pooler_output past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -1348,7 +1424,7 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past=None, attention_mask=None, **model_kwargs + self, input_ids, past=None, attention_mask=None, token_type_ids=None, **model_kwargs ): input_shape = input_ids.shape @@ -1356,13 +1432,20 @@ def prepare_inputs_for_generation( if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) + # 第一次编码的时候token_type_ids等于0 + if token_type_ids is None: + token_type_ids = input_ids.new_zeros(input_shape) + # cut decoder_input_ids if past is used if past is not None: input_ids = input_ids[:, -1:] + # 然后有past之后,token_type_ids等于1 + token_type_ids = torch.ones_like(input_ids[:, -1:]) return { "input_ids": input_ids, "attention_mask": attention_mask, + "token_type_ids": token_type_ids, "past_key_values": past, } @@ -1411,7 +1494,7 @@ class RoFormerForSequenceClassification(RoFormerPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.roformer = RoFormerModel(config, add_pooling_layer=False) + self.roformer = RoFormerModel(config) self.classifier = RoFormerClassificationHead(config) # Initialize weights and apply final processing @@ -1509,7 +1592,7 @@ class RoFormerForMultipleChoice(RoFormerPreTrainedModel): def __init__(self, config): super().__init__(config) - self.roformer = RoFormerModel(config, add_pooling_layer=False) + self.roformer = RoFormerModel(config) self.sequence_summary = SequenceSummary(config) self.classifier = nn.Linear(config.hidden_size, 1) @@ -1616,7 +1699,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.roformer = RoFormerModel(config, add_pooling_layer=False) + self.roformer = RoFormerModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, config.num_labels) @@ -1711,7 +1794,7 @@ def __init__(self, config): config.num_labels = 2 self.num_labels = config.num_labels - self.roformer = RoFormerModel(config, add_pooling_layer=False) + self.roformer = RoFormerModel(config) self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing diff --git a/src/roformer/modeling_tf_roformer.py b/src/roformer/modeling_tf_roformer.py index 24936a9..74b69c0 100644 --- a/src/roformer/modeling_tf_roformer.py +++ b/src/roformer/modeling_tf_roformer.py @@ -63,13 +63,21 @@ _CONFIG_FOR_DOC = "RoFormerConfig" _TOKENIZER_FOR_DOC = "RoFormerTokenizer" -TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ +ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ "junnyu/roformer_chinese_small", "junnyu/roformer_chinese_base", "junnyu/roformer_chinese_char_small", "junnyu/roformer_chinese_char_base", + "junnyu/roformer_chinese_sim_char_small", + "junnyu/roformer_chinese_sim_char_base", + "junnyu/roformer_chinese_sim_char_ft_small", + "junnyu/roformer_chinese_sim_char_ft_base", "junnyu/roformer_small_discriminator", - "junnyu/roformer_small_generator" + "junnyu/roformer_small_generator", + "junnyu/roformer_base_wwm_cluecorpussmall", + "junnyu/roformer_v2_chinese_char_small", + "junnyu/roformer_v2_chinese_char_base", + "junnyu/roformer_v2_chinese_char_large", # See all RoFormer models at https://huggingface.co/models?filter=roformer ] diff --git a/src/roformer/tokenization_roformer.py b/src/roformer/tokenization_roformer.py index 68d2ad0..c608313 100644 --- a/src/roformer/tokenization_roformer.py +++ b/src/roformer/tokenization_roformer.py @@ -18,14 +18,13 @@ import os from typing import List, Optional, Tuple -from transformers.tokenization_utils import PreTrainedTokenizer -from transformers.utils import logging from transformers.models.bert.tokenization_bert import ( BasicTokenizer, WordpieceTokenizer, load_vocab, ) - +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging logger = logging.get_logger(__name__) @@ -37,8 +36,17 @@ "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt", "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt", "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt", + "junnyu/roformer_chinese_sim_char_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_small/resolve/main/vocab.txt", + "junnyu/roformer_chinese_sim_char_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_base/resolve/main/vocab.txt", + "junnyu/roformer_chinese_sim_char_ft_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_small/resolve/main/vocab.txt", + "junnyu/roformer_chinese_sim_char_ft_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_base/resolve/main/vocab.txt", "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt", "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt", + "junnyu/roformer_base_wwm_cluecorpussmall": "https://huggingface.co/junnyu/roformer_base_wwm_cluecorpussmall/resolve/main/vocab.txt", + "junnyu/roformer_v2_chinese_char_small": "https://huggingface.co/junnyu/roformer_v2_chinese_char_small/resolve/main/vocab.txt", + "junnyu/roformer_v2_chinese_char_base": "https://huggingface.co/junnyu/roformer_v2_chinese_char_base/resolve/main/vocab.txt", + "junnyu/roformer_v2_chinese_char_large": "https://huggingface.co/junnyu/roformer_v2_chinese_char_large/resolve/main/vocab.txt", + # See all RoFormer models at https://huggingface.co/models?filter=roformer } } @@ -47,8 +55,16 @@ "junnyu/roformer_chinese_base": 1536, "junnyu/roformer_chinese_char_small": 512, "junnyu/roformer_chinese_char_base": 512, + "junnyu/roformer_chinese_sim_char_small": 512, + "junnyu/roformer_chinese_sim_char_base": 512, + "junnyu/roformer_chinese_sim_char_ft_small": 512, + "junnyu/roformer_chinese_sim_char_ft_base": 512, "junnyu/roformer_small_discriminator": 128, "junnyu/roformer_small_generator": 128, + "junnyu/roformer_base_wwm_cluecorpussmall": 512, + "junnyu/roformer_v2_chinese_char_small": 512, + "junnyu/roformer_v2_chinese_char_base": 512, + "junnyu/roformer_v2_chinese_char_large": 512, } @@ -57,8 +73,16 @@ "junnyu/roformer_chinese_base": {"do_lower_case": True}, "junnyu/roformer_chinese_char_small": {"do_lower_case": True}, "junnyu/roformer_chinese_char_base": {"do_lower_case": True}, + "junnyu/roformer_chinese_sim_char_small": {"do_lower_case": True}, + "junnyu/roformer_chinese_sim_char_base": {"do_lower_case": True}, + "junnyu/roformer_chinese_sim_char_ft_small": {"do_lower_case": True}, + "junnyu/roformer_chinese_sim_char_ft_base": {"do_lower_case": True}, "junnyu/roformer_small_discriminator": {"do_lower_case": True}, "junnyu/roformer_small_generator": {"do_lower_case": True}, + "junnyu/roformer_base_wwm_cluecorpussmall": {"do_lower_case": True}, + "junnyu/roformer_v2_chinese_char_small": {"do_lower_case": True}, + "junnyu/roformer_v2_chinese_char_base": {"do_lower_case": True}, + "junnyu/roformer_v2_chinese_char_large": {"do_lower_case": True}, } @@ -329,4 +353,4 @@ def save_vocabulary( index = token_index writer.write(token + "\n") index += 1 - return (vocab_file,) + return (vocab_file,) \ No newline at end of file diff --git a/src/roformer/tokenization_roformer_fast.py b/src/roformer/tokenization_roformer_fast.py index b055380..a2a1c7e 100644 --- a/src/roformer/tokenization_roformer_fast.py +++ b/src/roformer/tokenization_roformer_fast.py @@ -35,8 +35,17 @@ "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt", "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt", "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt", + "junnyu/roformer_chinese_sim_char_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_small/resolve/main/vocab.txt", + "junnyu/roformer_chinese_sim_char_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_base/resolve/main/vocab.txt", + "junnyu/roformer_chinese_sim_char_ft_small": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_small/resolve/main/vocab.txt", + "junnyu/roformer_chinese_sim_char_ft_base": "https://huggingface.co/junnyu/roformer_chinese_sim_char_ft_base/resolve/main/vocab.txt", "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt", "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt", + "junnyu/roformer_base_wwm_cluecorpussmall": "https://huggingface.co/junnyu/roformer_base_wwm_cluecorpussmall/resolve/main/vocab.txt", + "junnyu/roformer_v2_chinese_char_small": "https://huggingface.co/junnyu/roformer_v2_chinese_char_small/resolve/main/vocab.txt", + "junnyu/roformer_v2_chinese_char_base": "https://huggingface.co/junnyu/roformer_v2_chinese_char_base/resolve/main/vocab.txt", + "junnyu/roformer_v2_chinese_char_large": "https://huggingface.co/junnyu/roformer_v2_chinese_char_large/resolve/main/vocab.txt", + # See all RoFormer models at https://huggingface.co/models?filter=roformer } } @@ -45,8 +54,16 @@ "junnyu/roformer_chinese_base": 1536, "junnyu/roformer_chinese_char_small": 512, "junnyu/roformer_chinese_char_base": 512, + "junnyu/roformer_chinese_sim_char_small": 512, + "junnyu/roformer_chinese_sim_char_base": 512, + "junnyu/roformer_chinese_sim_char_ft_small": 512, + "junnyu/roformer_chinese_sim_char_ft_base": 512, "junnyu/roformer_small_discriminator": 128, "junnyu/roformer_small_generator": 128, + "junnyu/roformer_base_wwm_cluecorpussmall": 512, + "junnyu/roformer_v2_chinese_char_small": 512, + "junnyu/roformer_v2_chinese_char_base": 512, + "junnyu/roformer_v2_chinese_char_large": 512, } @@ -55,8 +72,16 @@ "junnyu/roformer_chinese_base": {"do_lower_case": True}, "junnyu/roformer_chinese_char_small": {"do_lower_case": True}, "junnyu/roformer_chinese_char_base": {"do_lower_case": True}, + "junnyu/roformer_chinese_sim_char_small": {"do_lower_case": True}, + "junnyu/roformer_chinese_sim_char_base": {"do_lower_case": True}, + "junnyu/roformer_chinese_sim_char_ft_small": {"do_lower_case": True}, + "junnyu/roformer_chinese_sim_char_ft_base": {"do_lower_case": True}, "junnyu/roformer_small_discriminator": {"do_lower_case": True}, "junnyu/roformer_small_generator": {"do_lower_case": True}, + "junnyu/roformer_base_wwm_cluecorpussmall": {"do_lower_case": True}, + "junnyu/roformer_v2_chinese_char_small": {"do_lower_case": True}, + "junnyu/roformer_v2_chinese_char_base": {"do_lower_case": True}, + "junnyu/roformer_v2_chinese_char_large": {"do_lower_case": True}, }