
Commit

add transformer layer
yangxudong committed Sep 18, 2024
1 parent dfe285e commit 1659947
Showing 3 changed files with 9 additions and 8 deletions.
1 change: 1 addition & 0 deletions easy_rec/python/layers/keras/__init__.py
@@ -29,3 +29,4 @@
from .ppnet import PPNet
from .transformer import TransformerBlock
from .transformer import TransformerEncoder
+ from .transformer import TextEncoder
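
With the new re-export in place, the layer is importable straight from the package alongside the other transformer layers. A minimal sketch, not part of this commit:

from easy_rec.python.layers.keras import TextEncoder, TransformerEncoder
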
13 changes: 7 additions & 6 deletions easy_rec/python/layers/keras/transformer.py
@@ -100,13 +100,18 @@ def __init__(self, params, name='transformer_encoder', reuse=None, **kwargs):
super(TransformerEncoder, self).__init__(name=name, **kwargs)
d_model = params.hidden_size
dropout_rate = params.get_or_default('hidden_dropout_prob', 0.1)
- vocab_size = params.vocab_size
max_position = params.get_or_default('max_position_embeddings', 512)
num_layers = params.get_or_default('num_hidden_layers', 1)
+ vocab_size = params.vocab_size
self.output_all = params.get_or_default('output_all_token_embeddings', True)
self.pos_encoding = PositionalEmbedding(vocab_size, d_model, max_position)
self.dropout = Dropout(dropout_rate)
self.enc_layers = [TransformerBlock(params) for _ in range(num_layers)]
+ self._vocab_size = vocab_size
+
+ @property
+ def vocab_size(self):
+ return self._vocab_size

def call(self, inputs, training=None):
x, mask = inputs
@@ -125,9 +130,6 @@ def __init__(self, params, name='text_encoder', reuse=None, **kwargs):
self.separator = params.get_or_default('separator', ' ')
self.cls_token = '[CLS]' + self.separator
self.sep_token = self.separator + '[SEP]' + self.separator
- hash_bucket_size = params.get_or_default('hash_bucket_size', None)
- emb_dim = params.transformer.hidden_size
- self.emb_table = Embedding(hash_bucket_size, emb_dim)
trans_params = Parameter.make_from_pb(params.attention)
self.encoder = TransformerEncoder(trans_params)

@@ -147,6 +149,5 @@ def call(self, inputs, training=None):
tokens = tf.sparse.to_dense(tokens, default_value='')
mask = tf.cast(tf.not_equal(tokens, ''), tf.bool)
token_ids = tf.string_to_hash_bucket_fast(tokens, self.hash_bucket_size)
- embedding = self.emb_table(token_ids)
- encoding = self.encoder([embedding, mask], training)
+ encoding = self.encoder([token_ids, mask], training)
return encoding
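
Taken together, these hunks drop TextEncoder's own Embedding table: raw tokens are hashed to integer ids, and the ids plus the padding mask go straight to the TransformerEncoder, which already builds a PositionalEmbedding from vocab_size and hidden_size. A minimal sketch of the resulting forward path (using the same TF 1.x-style op as the file); the built TransformerEncoder instance encoder and the bucket count num_buckets are assumptions, since hash_bucket_size is removed from the TextEncoder config below:

import tensorflow as tf

def encode_tokens(encoder, tokens, num_buckets, training=False):
  # tokens: dense [batch, seq_len] string tensor, '' marks padding
  mask = tf.cast(tf.not_equal(tokens, ''), tf.bool)
  # hash raw tokens to integer ids; the encoder now embeds the ids itself
  token_ids = tf.string_to_hash_bucket_fast(tokens, num_buckets)
  return encoder([token_ids, mask], training=training)
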
3 changes: 1 addition & 2 deletions easy_rec/python/protos/seq_encoder.proto
@@ -33,15 +33,14 @@ message Transformer {
// The maximum sequence length that this model might ever be used with
required uint32 max_position_embeddings = 8 [default = 512];
// Whether to add position embeddings for the position of each token in the text sequence
- required bool use_position_embeddings = 9 [default = true];
+ required bool use_position_embeddings = 9 [default = false];
// Whether to output all token embedding, if set to false, then only output the first token embedding
required bool output_all_token_embeddings = 10 [default = true];
}

message TextEncoder {
required Transformer transformer = 1;
required string separator = 2 [default = ' '];
- optional uint32 hash_bucket_size = 3;
}

message BSTEncoder {
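
On the output_all_token_embeddings flag kept as context above (read as self.output_all in TransformerEncoder.__init__): per its comment, the encoder returns an encoding for every token when it is true and only the first ([CLS]) token's encoding when it is false. A rough sketch of the expected shapes; these shapes are assumptions, not code from this commit:

encodings = encoder([token_ids, mask], training=False)
# output_all_token_embeddings = true  -> per-token encodings, roughly [batch, seq_len, hidden_size]
# output_all_token_embeddings = false -> only the first ([CLS]) token, roughly [batch, hidden_size]
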
