diff --git a/README.md b/README.md
index 64709b2..95af845 100644
--- a/README.md
+++ b/README.md
@@ -1,25 +1,26 @@
-# PyTorch RoFormer
-Original TensorFlow weights (https://github.com/ZhuiyiTechnology/roformer)
-- [chinese_roformer_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1fiss862YsGCwf2HvU_Jm-g) (download code: xy9x)
-- [chinese_roformer_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1iIXgZHHCgrYGXVRRSSCVPg) (download code: gy97)
-- [chinese_roformer-char_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1Q1pq8F4Fsl6bTipUAkqeDQ) (download code: bt94)
-- [chinese_roformer-char_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1cc281-M0Rsjlwws5phqzbQ) (download code: a44c)
-- [chinese_roformer-sim-char_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1f1FB288nv1a6jYjsNCordg) (download code: 2cgz)
-- [chinese_roformer-sim-char_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1r0eJ7shGwQ0RzV9BTFFW4g) (download code: h68q)
-
-Weights already converted to PyTorch
-- [chinese_roformer_small.zip](https://pan.baidu.com/s/1Cx7lhtojTyRF61IKHWXEHw) (download code: 8znw)
-- [chinese_roformer_base.zip](https://pan.baidu.com/s/10W5BYDQSeLyajTWjexZeoQ) (download code: bimr)
-- [chinese_roformer_char_base.zip](https://pan.baidu.com/s/18bgJ1t_1ke0BXq_Xg02qSQ) (download code: oqb5)
-
-## Installation (the code has been merged into the huggingface repository)
+# PyTorch RoFormer & RoFormer-V2
+RoFormer and RoFormer-V2 models
+
+## Updates
+- 2022/03/21 Added the `roformer-v2` weights. Note: they require the code in this repository and do not work with the code in the transformers repository.
+
+## Installation (the code has been merged into the huggingface repository; the V2 models require the code in this repository)
 transformers v4.7 has been released and can be installed directly:
 ```bash
 pip install -U transformers
 ```
+
 ## Model weight mapping
-### Chinese models
+### Chinese models: roformer-v2
+| huggingface.co | bert4keras |
+| ---------------------------------- | ------------------------------------------------ |
+| [roformer_v2_chinese_char_small](https://huggingface.co/junnyu/roformer_v2_chinese_char_small) | [chinese_roformer-v2-char_L-6_H-384_A-6.zip](https://pan.baidu.com/s/1huUrC9P60Afggo8AfiUcmA) (download code: ttn4) |
+| [roformer_v2_chinese_char_base](https://huggingface.co/junnyu/roformer_v2_chinese_char_base) | [chinese_roformer-v2-char_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1qcnN4LVKVe0-mnHlkN3-6Q) (download code: pfoh) |
+| [roformer_v2_chinese_char_large](https://huggingface.co/junnyu/roformer_v2_chinese_char_large) | [chinese_roformer-v2-char_L-24_H-1024_A-16.zip](https://pan.baidu.com/s/1QiJWSZrGxn8vek-8myvL6w) (download code: npfv) |
+
+
+### Chinese models: roformer-v1
 | huggingface.co | bert4keras |
 | ---------------------------------- | ------------------------------------------------ |
 | [roformer_chinese_base](https://huggingface.co/junnyu/roformer_chinese_base) | [chinese_roformer_L-12_H-768_A-12.zip](https://pan.baidu.com/s/1fiss862YsGCwf2HvU_Jm-g) (download code: xy9x) |
@@ -38,34 +39,69 @@ pip install -U transformers
 |[roformer_small_generator](https://huggingface.co/junnyu/roformer_small_generator)|
 |[roformer_small_discriminator](https://huggingface.co/junnyu/roformer_small_discriminator)|
-
-## Usage
+## roformer-v2 MLM test
 ```python
 import torch
-from transformers import RoFormerModel, RoFormerTokenizer, TFRoFormerModel
-tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
-pt_model = RoFormerModel.from_pretrained("junnyu/roformer_chinese_base")
-tf_model = TFRoFormerModel.from_pretrained("junnyu/roformer_chinese_base",
-    from_pt=True)
-text = "这里基本保留了唐宋遗留下来的坊巷格局和大量明清古建筑,其中各级文保单位29处,被誉为“里坊制度的活化石”“明清建筑博物馆”!"
+import tensorflow as tf
+from transformers import BertTokenizer
+from roformer import RoFormerForMaskedLM, TFRoFormerForMaskedLM
+
+text = "今天[MASK]很好,我[MASK]去公园玩。"
+tokenizer = BertTokenizer.from_pretrained("junnyu/roformer_v2_chinese_char_base")
+pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_v2_chinese_char_base")
+tf_model = TFRoFormerForMaskedLM.from_pretrained(
+    "junnyu/roformer_v2_chinese_char_base", from_pt=True
+)
 pt_inputs = tokenizer(text, return_tensors="pt")
 tf_inputs = tokenizer(text, return_tensors="tf")
+# pytorch
 with torch.no_grad():
-    pt_outputs = pt_model(**pt_inputs).last_hidden_state
-print(pt_outputs.shape)
-tf_outputs = tf_model(**tf_inputs, training=False).last_hidden_state
-print(tf_outputs.shape)
+    pt_outputs = pt_model(**pt_inputs).logits[0]
+pt_outputs_sentence = "pytorch: "
+for i, id in enumerate(tokenizer.encode(text)):
+    if id == tokenizer.mask_token_id:
+        tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1])
+        pt_outputs_sentence += "[" + "||".join(tokens) + "]"
+    else:
+        pt_outputs_sentence += "".join(
+            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
+        )
+print(pt_outputs_sentence)
+# tf
+tf_outputs = tf_model(**tf_inputs, training=False).logits[0]
+tf_outputs_sentence = "tf: "
+for i, id in enumerate(tokenizer.encode(text)):
+    if id == tokenizer.mask_token_id:
+        tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1])
+        tf_outputs_sentence += "[" + "||".join(tokens) + "]"
+    else:
+        tf_outputs_sentence += "".join(
+            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
+        )
+print(tf_outputs_sentence)
+# small
+# pytorch: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。
+# tf: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。
+# base
+# pytorch: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。
+# tf: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。
+# large
+# pytorch: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。
+# tf: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。
 ```
-## MLM test
+
+## roformer-v1 MLM test
 ```python
 import torch
 import tensorflow as tf
 from transformers import RoFormerForMaskedLM, RoFormerTokenizer, TFRoFormerForMaskedLM
+
 text = "今天[MASK]很好,我[MASK]去公园玩。"
 tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
 pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
 tf_model = TFRoFormerForMaskedLM.from_pretrained(
-    "junnyu/roformer_chinese_base", from_pt=True)
+    "junnyu/roformer_chinese_base", from_pt=True
+)
 pt_inputs = tokenizer(text, return_tensors="pt")
 tf_inputs = tokenizer(text, return_tensors="tf")
 # pytorch
@@ -78,22 +114,24 @@ for i, id in enumerate(tokenizer.encode(text)):
         pt_outputs_sentence += "[" + "||".join(tokens) + "]"
     else:
         pt_outputs_sentence += "".join(
-            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
+            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
+        )
 print(pt_outputs_sentence)
 # tf
 tf_outputs = tf_model(**tf_inputs, training=False).logits[0]
 tf_outputs_sentence = "tf: "
 for i, id in enumerate(tokenizer.encode(text)):
     if id == tokenizer.mask_token_id:
-        tokens = tokenizer.convert_ids_to_tokens(
-            tf.math.top_k(tf_outputs[i], k=5)[1])
+        tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1])
         tf_outputs_sentence += "[" + "||".join(tokens) + "]"
     else:
         tf_outputs_sentence += "".join(
-            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
+            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
+        )
 print(tf_outputs_sentence)
 # pytorch: 
今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。 # tf: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。 + ``` ## 手动权重转换 diff --git a/examples/test_mlm.py b/examples/test_mlm.py deleted file mode 100644 index 2ecdda0..0000000 --- a/examples/test_mlm.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch - -from roformer import RoFormerForMaskedLM, RoFormerTokenizerFast - -text = "今天[MASK]很好,我[MASK]去公园玩。" -tokenizer = RoFormerTokenizerFast.from_pretrained("junnyu/roformer_chinese_base") -model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base") - -inputs = tokenizer(text, return_tensors="pt") -with torch.no_grad(): - outputs = model(**inputs).logits[0] - -outputs_sentence = "" -for i, id in enumerate(tokenizer.encode(text)): - if id == tokenizer.mask_token_id: - tokens = tokenizer.convert_ids_to_tokens(outputs[i].topk(k=5)[1]) - outputs_sentence += "[" + "||".join(tokens) + "]" - else: - outputs_sentence += "".join( - tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) - ) - -print(outputs_sentence) diff --git a/examples/test_mlm_v1.py b/examples/test_mlm_v1.py new file mode 100644 index 0000000..6f95e00 --- /dev/null +++ b/examples/test_mlm_v1.py @@ -0,0 +1,39 @@ +import torch +import tensorflow as tf +from transformers import RoFormerForMaskedLM, RoFormerTokenizer, TFRoFormerForMaskedLM + +text = "今天[MASK]很好,我[MASK]去公园玩。" +tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base") +pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base") +tf_model = TFRoFormerForMaskedLM.from_pretrained( + "junnyu/roformer_chinese_base", from_pt=True +) +pt_inputs = tokenizer(text, return_tensors="pt") +tf_inputs = tokenizer(text, return_tensors="tf") +# pytorch +with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs).logits[0] +pt_outputs_sentence = "pytorch: " +for i, id in enumerate(tokenizer.encode(text)): + if id == tokenizer.mask_token_id: + tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1]) + pt_outputs_sentence += "[" + "||".join(tokens) + "]" + else: + pt_outputs_sentence += "".join( + tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) + ) +print(pt_outputs_sentence) +# tf +tf_outputs = tf_model(**tf_inputs, training=False).logits[0] +tf_outputs_sentence = "tf: " +for i, id in enumerate(tokenizer.encode(text)): + if id == tokenizer.mask_token_id: + tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1]) + tf_outputs_sentence += "[" + "||".join(tokens) + "]" + else: + tf_outputs_sentence += "".join( + tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) + ) +print(tf_outputs_sentence) +# pytorch: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。 +# tf: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。 diff --git a/examples/test_mlm_v2.py b/examples/test_mlm_v2.py new file mode 100644 index 0000000..e4b26e0 --- /dev/null +++ b/examples/test_mlm_v2.py @@ -0,0 +1,47 @@ +import torch +import tensorflow as tf +from transformers import BertTokenizer +from roformer import RoFormerForMaskedLM, TFRoFormerForMaskedLM + +text = "今天[MASK]很好,我[MASK]去公园玩。" +tokenizer = BertTokenizer.from_pretrained("junnyu/roformer_v2_chinese_char_base") +pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_v2_chinese_char_base") +tf_model = TFRoFormerForMaskedLM.from_pretrained( + "junnyu/roformer_v2_chinese_char_base", from_pt=True +) +pt_inputs = tokenizer(text, return_tensors="pt") +tf_inputs = tokenizer(text, return_tensors="tf") +# pytorch +with torch.no_grad(): + pt_outputs = 
pt_model(**pt_inputs).logits[0] +pt_outputs_sentence = "pytorch: " +for i, id in enumerate(tokenizer.encode(text)): + if id == tokenizer.mask_token_id: + tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1]) + pt_outputs_sentence += "[" + "||".join(tokens) + "]" + else: + pt_outputs_sentence += "".join( + tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) + ) +print(pt_outputs_sentence) +# tf +tf_outputs = tf_model(**tf_inputs, training=False).logits[0] +tf_outputs_sentence = "tf: " +for i, id in enumerate(tokenizer.encode(text)): + if id == tokenizer.mask_token_id: + tokens = tokenizer.convert_ids_to_tokens(tf.math.top_k(tf_outputs[i], k=5)[1]) + tf_outputs_sentence += "[" + "||".join(tokens) + "]" + else: + tf_outputs_sentence += "".join( + tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True) + ) +print(tf_outputs_sentence) +# small +# pytorch: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。 +# tf: 今天[的||,||是||很||也]很好,我[要||会||是||想||在]去公园玩。 +# base +# pytorch: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。 +# tf: 今天[我||天||晴||园||玩]很好,我[想||要||会||就||带]去公园玩。 +# large +# pytorch: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。 +# tf: 今天[天||气||我||空||阳]很好,我[又||想||会||就||爱]去公园玩。 diff --git a/setup.py b/setup.py index 3e5ae64..8b853d4 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ name="roformer", package_dir={"": "src"}, packages=find_packages("src"), - version="0.3.1", + version="0.4.0", license="Apache 2.0", description="roformer_pytorch", author="Jun Yu", diff --git a/src/roformer/configuration_roformer.py b/src/roformer/configuration_roformer.py index ce11d41..46b1b3a 100644 --- a/src/roformer/configuration_roformer.py +++ b/src/roformer/configuration_roformer.py @@ -105,6 +105,8 @@ def __init__( pad_token_id=0, rotary_value=False, use_cache=True, + use_bias=True, + norm_type="layer_norm", **kwargs ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -124,3 +126,5 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.rotary_value = rotary_value self.use_cache = use_cache + self.use_bias = use_bias + self.norm_type = norm_type diff --git a/src/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py b/src/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py index 8b91be5..32a107f 100644 --- a/src/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py +++ b/src/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py @@ -35,10 +35,17 @@ def convert_tf_checkpoint_to_pytorch( # Load weights from tf checkpoint load_tf_weights_in_roformer(model, config, tf_checkpoint_path) + # ignore 不保存roformer.encoder.embed_positions.weight + _keys_to_ignore_on_save = ["roformer.encoder.embed_positions.weight"] + state_dict = model.state_dict() + for ignore_key in _keys_to_ignore_on_save: + if ignore_key in state_dict.keys(): + del state_dict[ignore_key] + # Save pytorch-model print(f"Save PyTorch model to {pytorch_dump_path}") torch.save( - model.state_dict(), pytorch_dump_path, _use_new_zipfile_serialization=False + state_dict, pytorch_dump_path, _use_new_zipfile_serialization=False ) diff --git a/src/roformer/modeling_roformer.py b/src/roformer/modeling_roformer.py index 85d125a..a9eed7b 100644 --- a/src/roformer/modeling_roformer.py +++ b/src/roformer/modeling_roformer.py @@ -69,6 +69,15 @@ # See all RoFormer models at https://huggingface.co/models?filter=roformer ] +class Norm(nn.Module): + def __init__(self, eps = 1e-12): + super().__init__() + self.eps = eps + + def forward(self, x): + variance = torch.mean(torch.square(x), dim=-1, 
keepdim=True) + return x / torch.sqrt(variance + self.eps) + # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->RoFormer class RoFormerSinusoidalPositionalEmbedding(nn.Embedding): @@ -209,7 +218,7 @@ def __init__(self, config): # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file - self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) if config.norm_type=="layer_norm" else Norm(eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids=None, token_type_ids=None, inputs_embeds=None): @@ -250,9 +259,9 @@ def __init__(self, config): self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) @@ -394,8 +403,8 @@ def apply_rotary_position_embeddings( class RoFormerSelfOutput(nn.Module): def __init__(self, config): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.use_bias) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.norm_type=="layer_norm" else Norm(eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): @@ -469,7 +478,7 @@ def forward( class RoFormerIntermediate(nn.Module): def __init__(self, config): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.dense = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.use_bias) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: @@ -485,8 +494,8 @@ def forward(self, hidden_states): class RoFormerOutput(nn.Module): def __init__(self, config): super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.use_bias) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.norm_type=="layer_norm" else Norm(eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): @@ -726,6 +735,14 @@ def forward(self, hidden_states): hidden_states = self.LayerNorm(hidden_states) return hidden_states +class RoFormerV2LMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False) + + def forward(self, hidden_states): + return self.decoder(hidden_states) + class RoFormerLMPredictionHead(nn.Module): 
def __init__(self, config): @@ -751,7 +768,7 @@ def forward(self, hidden_states): class RoFormerOnlyMLMHead(nn.Module): def __init__(self, config): super().__init__() - self.predictions = RoFormerLMPredictionHead(config) + self.predictions = RoFormerLMPredictionHead(config) if config.norm_type=="layer_norm" else RoFormerV2LMPredictionHead(config) def forward(self, sequence_output): prediction_scores = self.predictions(sequence_output) @@ -789,6 +806,7 @@ class RoFormerPreTrainedModel(PreTrainedModel): r"roformer\.embeddings_project\.weight", r"roformer\.embeddings_project\.bias", ] + _keys_to_ignore_on_save = ["roformer.encoder.embed_positions.weight"] def _init_weights(self, module): """Initialize the weights""" @@ -1076,7 +1094,7 @@ def __init__(self, config): "bi-directional self-attention." ) - self.roformer = RoFormerModel(config, add_pooling_layer=True) + self.roformer = RoFormerModel(config, add_pooling_layer=False) self.cls = RoFormerOnlyMLMHead(config) # Initialize weights and apply final processing diff --git a/src/roformer/modeling_tf_roformer.py b/src/roformer/modeling_tf_roformer.py index 0ae4a68..24936a9 100644 --- a/src/roformer/modeling_tf_roformer.py +++ b/src/roformer/modeling_tf_roformer.py @@ -73,6 +73,15 @@ # See all RoFormer models at https://huggingface.co/models?filter=roformer ] +class Norm(tf.keras.layers.Layer): + def __init__(self, epsilon = 1e-12): + super().__init__() + self.epsilon = epsilon + + def call(self, inputs): + variance = tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True) + return inputs / tf.sqrt(variance + self.epsilon) + class TFRoFormerSinusoidalPositionalEmbedding(tf.keras.layers.Layer): """This module produces sinusoidal positional embeddings of any length.""" @@ -153,7 +162,7 @@ def __init__(self, config: RoFormerConfig, **kwargs): self.embeddings_sum = tf.keras.layers.Add() self.LayerNorm = tf.keras.layers.LayerNormalization( epsilon=config.layer_norm_eps, name="LayerNorm" - ) + ) if config.norm_type=="layer_norm" else Norm(epsilon=config.layer_norm_eps) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) def build(self, input_shape: tf.TensorShape): @@ -226,16 +235,19 @@ def __init__(self, config: RoFormerConfig, **kwargs): units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query", + use_bias=config.use_bias ) self.key = tf.keras.layers.Dense( units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key", + use_bias=config.use_bias ) self.value = tf.keras.layers.Dense( units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value", + use_bias=config.use_bias ) self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob) self.rotary_value = config.rotary_value @@ -365,10 +377,11 @@ def __init__(self, config: RoFormerConfig, **kwargs): units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense", + use_bias=config.use_bias ) self.LayerNorm = tf.keras.layers.LayerNormalization( epsilon=config.layer_norm_eps, name="LayerNorm" - ) + ) if config.norm_type=="layer_norm" else Norm(epsilon=config.layer_norm_eps) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) def call( @@ -427,6 +440,7 @@ def __init__(self, config: RoFormerConfig, **kwargs): units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense", + use_bias=config.use_bias ) if isinstance(config.hidden_act, str): @@ -450,10 
+464,11 @@ def __init__(self, config: RoFormerConfig, **kwargs): units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense", + use_bias=config.use_bias ) self.LayerNorm = tf.keras.layers.LayerNormalization( epsilon=config.layer_norm_eps, name="LayerNorm" - ) + ) if config.norm_type=="layer_norm" else Norm(epsilon=config.layer_norm_eps) self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) def call( @@ -608,7 +623,7 @@ def __init__(self, config: RoFormerConfig, **kwargs): self.LayerNorm = tf.keras.layers.LayerNormalization( epsilon=config.layer_norm_eps, name="LayerNorm" - ) + ) if config.norm_type=="layer_norm" else Norm(epsilon=config.layer_norm_eps) def call(self, hidden_states: tf.Tensor) -> tf.Tensor: hidden_states = self.dense(inputs=hidden_states) @@ -670,6 +685,38 @@ def call(self, hidden_states: tf.Tensor) -> tf.Tensor: return hidden_states +class TFRoFormerV2LMPredictionHead(tf.keras.layers.Layer): + def __init__( + self, config: RoFormerConfig, input_embeddings: tf.keras.layers.Layer, **kwargs + ): + super().__init__(**kwargs) + + self.vocab_size = config.vocab_size + self.embedding_size = config.embedding_size + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.input_embeddings = input_embeddings + + def get_output_embeddings(self) -> tf.keras.layers.Layer: + return self.input_embeddings + + def set_output_embeddings(self, value: tf.Variable): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + seq_length = shape_list(hidden_states)[1] + hidden_states = tf.reshape( + tensor=hidden_states, shape=[-1, self.embedding_size] + ) + hidden_states = tf.matmul( + a=hidden_states, b=self.input_embeddings.weight, transpose_b=True + ) + hidden_states = tf.reshape( + tensor=hidden_states, shape=[-1, seq_length, self.vocab_size] + ) + return hidden_states # Copied from transformers.models.bert.modeling_tf_bert.TFBertMLMHead with Bert->RoFormer class TFRoFormerMLMHead(tf.keras.layers.Layer): @@ -680,7 +727,7 @@ def __init__( self.predictions = TFRoFormerLMPredictionHead( config, input_embeddings, name="predictions" - ) + ) if config.norm_type == "layer_norm" else TFRoFormerV2LMPredictionHead(config, input_embeddings, name="predictions") def call(self, sequence_output: tf.Tensor) -> tf.Tensor: prediction_scores = self.predictions(hidden_states=sequence_output) @@ -703,7 +750,7 @@ def __init__( self.embeddings = TFRoFormerEmbeddings(config, name="embeddings") if config.embedding_size != config.hidden_size: self.embeddings_project = tf.keras.layers.Dense( - config.hidden_size, name="embeddings_project" + config.hidden_size, name="embeddings_project", use_bias=config.use_bias ) self.encoder = TFRoFormerEncoder(config, name="encoder") @@ -1022,7 +1069,7 @@ def __init__(self, config: RoFormerConfig, *inputs, **kwargs): ) self.roformer = TFRoFormerMainLayer( - config, add_pooling_layer=True, name="roformer" + config, add_pooling_layer=False, name="roformer" ) self.mlm = TFRoFormerMLMHead( config, input_embeddings=self.roformer.embeddings, name="mlm___cls"
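For readers skimming the diff: the `Norm` module introduced in `modeling_roformer.py` and `modeling_tf_roformer.py` (selected whenever `config.norm_type` is anything other than `"layer_norm"`) is a scale-free RMS normalization. Each hidden vector is divided by the root mean square of its components, with no mean subtraction and no learned weight or bias, and the new `use_bias` flag lets the linear layers drop their bias terms. Below is a minimal standalone sketch of that behaviour in plain PyTorch, independent of this repository's classes; the tensor shapes and tolerance are illustrative only.

```python
import torch
import torch.nn as nn


class Norm(nn.Module):
    """Scale-free RMS normalization, mirroring the `Norm` module added in this PR."""

    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # Mean of squares over the hidden dimension; no mean subtraction,
        # no learned gamma/beta.
        variance = torch.mean(torch.square(x), dim=-1, keepdim=True)
        return x / torch.sqrt(variance + self.eps)


x = torch.randn(2, 5, 8)  # (batch, sequence, hidden) — illustrative shape
out = Norm()(x)

# Each position is rescaled to (approximately) unit root mean square.
rms = out.pow(2).mean(dim=-1).sqrt()
print(torch.allclose(rms, torch.ones_like(rms), atol=1e-4))  # True

# Unlike nn.LayerNorm, the input mean is not removed, so the two differ
# whenever a position has a nonzero mean.
print(torch.allclose(out, nn.LayerNorm(8, elementwise_affine=False)(x), atol=1e-4))  # typically False
```

Since `norm_type="layer_norm"` and `use_bias=True` remain the config defaults, existing RoFormer-V1 checkpoints load with unchanged behaviour; only configs that override these flags pick up the bias-free, RMS-normalized V2 layers.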