From 16599476d913b8ae2e23db2f24759cb60fadd8f3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=8D=AB=E8=8B=8F?=
Date: Wed, 18 Sep 2024 11:16:39 +0800
Subject: [PATCH] add transformer layer

---
 easy_rec/python/layers/keras/__init__.py    |  1 +
 easy_rec/python/layers/keras/transformer.py | 13 +++++++------
 easy_rec/python/protos/seq_encoder.proto    |  3 +--
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/easy_rec/python/layers/keras/__init__.py b/easy_rec/python/layers/keras/__init__.py
index fcf6be817..bc2d7a7f4 100644
--- a/easy_rec/python/layers/keras/__init__.py
+++ b/easy_rec/python/layers/keras/__init__.py
@@ -29,3 +29,4 @@
 from .ppnet import PPNet
 from .transformer import TransformerBlock
 from .transformer import TransformerEncoder
+from .transformer import TextEncoder
diff --git a/easy_rec/python/layers/keras/transformer.py b/easy_rec/python/layers/keras/transformer.py
index 004bb3c35..518c26fb5 100644
--- a/easy_rec/python/layers/keras/transformer.py
+++ b/easy_rec/python/layers/keras/transformer.py
@@ -100,13 +100,18 @@ def __init__(self, params, name='transformer_encoder', reuse=None, **kwargs):
     super(TransformerEncoder, self).__init__(name=name, **kwargs)
     d_model = params.hidden_size
     dropout_rate = params.get_or_default('hidden_dropout_prob', 0.1)
-    vocab_size = params.vocab_size
     max_position = params.get_or_default('max_position_embeddings', 512)
     num_layers = params.get_or_default('num_hidden_layers', 1)
+    vocab_size = params.vocab_size
     self.output_all = params.get_or_default('output_all_token_embeddings', True)
     self.pos_encoding = PositionalEmbedding(vocab_size, d_model, max_position)
     self.dropout = Dropout(dropout_rate)
     self.enc_layers = [TransformerBlock(params) for _ in range(num_layers)]
+    self._vocab_size = vocab_size
+
+  @property
+  def vocab_size(self):
+    return self._vocab_size
 
   def call(self, inputs, training=None):
     x, mask = inputs
@@ -125,9 +130,6 @@ def __init__(self, params, name='text_encoder', reuse=None, **kwargs):
     self.separator = params.get_or_default('separator', ' ')
     self.cls_token = '[CLS]' + self.separator
     self.sep_token = self.separator + '[SEP]' + self.separator
-    hash_bucket_size = params.get_or_default('hash_bucket_size', None)
-    emb_dim = params.transformer.hidden_size
-    self.emb_table = Embedding(hash_bucket_size, emb_dim)
     trans_params = Parameter.make_from_pb(params.attention)
     self.encoder = TransformerEncoder(trans_params)
 
@@ -147,6 +149,5 @@ def call(self, inputs, training=None):
     tokens = tf.sparse.to_dense(tokens, default_value='')
     mask = tf.cast(tf.not_equal(tokens, ''), tf.bool)
     token_ids = tf.string_to_hash_bucket_fast(tokens, self.hash_bucket_size)
-    embedding = self.emb_table(token_ids)
-    encoding = self.encoder([embedding, mask], training)
+    encoding = self.encoder([token_ids, mask], training)
     return encoding
diff --git a/easy_rec/python/protos/seq_encoder.proto b/easy_rec/python/protos/seq_encoder.proto
index ab89b7ab4..cee230acd 100644
--- a/easy_rec/python/protos/seq_encoder.proto
+++ b/easy_rec/python/protos/seq_encoder.proto
@@ -33,7 +33,7 @@ message Transformer {
   // The maximum sequence length that this model might ever be used with
   required uint32 max_position_embeddings = 8 [default = 512];
   // Whether to add position embeddings for the position of each token in the text sequence
-  required bool use_position_embeddings = 9 [default = true];
+  required bool use_position_embeddings = 9 [default = false];
   // Whether to output all token embedding, if set to false, then only output the first token embedding
   required bool output_all_token_embeddings = 10 [default = true];
 }
@@ -41,7 +41,6 @@ message Transformer {
 message TextEncoder {
   required Transformer transformer = 1;
   required string separator = 2 [default = ' '];
-  optional uint32 hash_bucket_size = 3;
 }
 
 message BSTEncoder {
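
Note (not part of the patch): a minimal sketch of the reworked call path, assuming the TF 1.x-style ops already used in this file; the toy tensors and the bucket size of 10000 are illustrative assumptions, not values from the patch. After this change the embedding lookup lives inside TransformerEncoder (via PositionalEmbedding), so TextEncoder only hashes the raw tokens and forwards the integer ids together with the padding mask.

    import tensorflow as tf

    # Dense string tokens, as produced by tf.sparse.to_dense(..., default_value='')
    tokens = tf.constant([['[CLS]', 'hello', 'world', ''],
                          ['[CLS]', 'foo', '', '']])
    # Padding mask: True for real tokens, False for padding positions
    mask = tf.cast(tf.not_equal(tokens, ''), tf.bool)
    # Hash tokens to integer ids (10000 buckets chosen only for illustration)
    token_ids = tf.string_to_hash_bucket_fast(tokens, 10000)
    # The encoder is now fed ids + mask directly, e.g. inside TextEncoder.call:
    # encoding = self.encoder([token_ids, mask], training)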