diff --git a/easy_rec/python/layers/input_layer.py b/easy_rec/python/layers/input_layer.py
index 3085adc7a..7cdbcc669 100644
--- a/easy_rec/python/layers/input_layer.py
+++ b/easy_rec/python/layers/input_layer.py
@@ -8,8 +8,8 @@
 from easy_rec.python.compat.feature_column import feature_column
 from easy_rec.python.feature_column.feature_column import FeatureColumnParser
 from easy_rec.python.feature_column.feature_group import FeatureGroup
-from easy_rec.python.layers import dnn
 from easy_rec.python.layers import seq_input_layer
+from easy_rec.python.layers import seq_model
 from easy_rec.python.layers import variational_dropout_layer
 from easy_rec.python.layers.common_layers import text_cnn
 from easy_rec.python.protos.feature_config_pb2 import WideOrDeep
@@ -67,37 +67,6 @@ def __init__(self,
   def has_group(self, group_name):
     return group_name in self._feature_groups
 
-  def target_attention(self, dnn_config, deep_fea, name):
-    cur_id, hist_id_col, seq_len = deep_fea['key'], deep_fea[
-        'hist_seq_emb'], deep_fea['hist_seq_len']
-
-    seq_max_len = tf.shape(hist_id_col)[1]
-    emb_dim = hist_id_col.shape[2]
-
-    cur_ids = tf.tile(cur_id, [1, seq_max_len])
-    cur_ids = tf.reshape(cur_ids,
-                         tf.shape(hist_id_col))  # (B, seq_max_len, emb_dim)
-
-    din_net = tf.concat(
-        [cur_ids, hist_id_col, cur_ids - hist_id_col, cur_ids * hist_id_col],
-        axis=-1)  # (B, seq_max_len, emb_dim*4)
-
-    din_layer = dnn.DNN(dnn_config, None, name, self._is_training)
-    din_net = din_layer(din_net)
-    scores = tf.reshape(din_net, [-1, 1, seq_max_len])  # (B, 1, ?)
-
-    seq_len = tf.expand_dims(seq_len, 1)
-    mask = tf.sequence_mask(seq_len)
-    padding = tf.ones_like(scores) * (-2**32 + 1)
-    scores = tf.where(mask, scores, padding)  # [B, 1, seq_max_len]
-
-    # Scale
-    scores = tf.nn.softmax(scores)  # (B, 1, seq_max_len)
-    hist_din_emb = tf.matmul(scores, hist_id_col)  # [B, 1, emb_dim]
-    hist_din_emb = tf.reshape(hist_din_emb, [-1, emb_dim])  # [B, emb_dim]
-    din_output = tf.concat([hist_din_emb, cur_id], axis=1)
-    return din_output
-
   def call_seq_input_layer(self,
                            features,
                            seq_att_map_config,
@@ -121,8 +90,15 @@ def call_seq_input_layer(self,
       from easy_rec.python.protos.dnn_pb2 import DNN
       seq_dnn_config = DNN()
       seq_dnn_config.hidden_units.extend([128, 64, 32, 1])
-    seq_fea = self.target_attention(
-        seq_dnn_config, seq_features, name='seq_dnn')
+
+    seq_fea = None
+    if seq_att_map_config.seq_model == 'self_attention':
+      seq_fea = seq_model.self_attention(seq_features,
+                                         seq_att_map_config.seq_len,
+                                         seq_att_map_config.multi_head_size)
+    else:
+      seq_fea = seq_model.target_attention(seq_dnn_config, seq_features,
+                                           'seq_dnn', self._is_training)
     return seq_fea
 
   def __call__(self, features, group_name, is_combine=True):
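Note: the dispatch above keeps target attention as the default and only switches paths when seq_att_map_config.seq_model equals 'self_attention'; any other string silently falls through to the DIN-style branch. A minimal sketch of how the two seq_model entry points are driven, in TF1 graph-building style; the batch size, history length, and embedding dim are made up, and it assumes the EasyRec package and its compiled protos are importable:

    import tensorflow as tf

    from easy_rec.python.layers import seq_model
    from easy_rec.python.protos.dnn_pb2 import DNN

    if tf.__version__ >= '2.0':
      tf = tf.compat.v1

    with tf.Graph().as_default():
      # dict layout produced by call_seq_input_layer:
      # key (B, emb_dim), hist_seq_emb (B, T, emb_dim), hist_seq_len (B,)
      seq_features = {
          'key': tf.zeros([8, 16]),
          'hist_seq_emb': tf.zeros([8, 20, 16]),
          'hist_seq_len': tf.fill([8], 20),
      }
      seq_dnn_config = DNN()
      seq_dnn_config.hidden_units.extend([128, 64, 32, 1])

      din_out = seq_model.target_attention(
          seq_dnn_config, seq_features, 'seq_dnn', is_training=False)
      att_out = seq_model.self_attention(seq_features, 50, 4)
      # din_out is [B, emb_dim * 2]: pooled history concat target id;
      # att_out is [B, seq_len * emb_dim], i.e. 50 * 16 = 800 columns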
diff --git a/easy_rec/python/layers/seq_model.py b/easy_rec/python/layers/seq_model.py
new file mode 100644
index 000000000..7d95826b3
--- /dev/null
+++ b/easy_rec/python/layers/seq_model.py
@@ -0,0 +1,125 @@
+import math
+
+import tensorflow as tf
+
+from easy_rec.python.layers import dnn
+from easy_rec.python.layers import layer_norm
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+# target attention
+def target_attention(dnn_config, deep_fea, name, is_training):
+  cur_id, hist_id_col, seq_len = deep_fea['key'], deep_fea[
+      'hist_seq_emb'], deep_fea['hist_seq_len']
+
+  seq_max_len = tf.shape(hist_id_col)[1]
+  emb_dim = hist_id_col.shape[2]
+
+  cur_ids = tf.tile(cur_id, [1, seq_max_len])
+  cur_ids = tf.reshape(cur_ids,
+                       tf.shape(hist_id_col))  # (B, seq_max_len, emb_dim)
+
+  din_net = tf.concat(
+      [cur_ids, hist_id_col, cur_ids - hist_id_col, cur_ids * hist_id_col],
+      axis=-1)  # (B, seq_max_len, emb_dim*4)
+
+  din_layer = dnn.DNN(dnn_config, None, name, is_training)
+  din_net = din_layer(din_net)
+  scores = tf.reshape(din_net, [-1, 1, seq_max_len])  # (B, 1, seq_max_len)
+
+  seq_len = tf.expand_dims(seq_len, 1)
+  mask = tf.sequence_mask(seq_len, maxlen=seq_max_len)
+  padding = tf.ones_like(scores) * (-2**32 + 1)
+  scores = tf.where(mask, scores, padding)  # [B, 1, seq_max_len]
+
+  # Normalize over the valid history positions
+  scores = tf.nn.softmax(scores)  # (B, 1, seq_max_len)
+  hist_din_emb = tf.matmul(scores, hist_id_col)  # [B, 1, emb_dim]
+  hist_din_emb = tf.reshape(hist_din_emb, [-1, emb_dim])  # [B, emb_dim]
+  din_output = tf.concat([hist_din_emb, cur_id], axis=1)
+  return din_output
+
+
+def attention_net(net, dim, cur_seq_len, seq_size, name):
+  query_net = dnn_net(net, [dim], name + '_query')  # (B, seq_size, dim)
+  key_net = dnn_net(net, [dim], name + '_key')
+  value_net = dnn_net(net, [dim], name + '_value')
+  scores = tf.matmul(
+      query_net, key_net, transpose_b=True)  # [B, seq_size, seq_size]
+
+  hist_mask = tf.sequence_mask(
+      cur_seq_len, maxlen=seq_size - 1)  # [B, seq_size-1]
+  cur_id_mask = tf.ones([tf.shape(hist_mask)[0], 1], dtype=tf.bool)  # [B, 1]
+  mask = tf.concat([hist_mask, cur_id_mask], axis=1)  # [B, seq_size]
+  masks = tf.reshape(tf.tile(mask, [1, seq_size]),
+                     (-1, seq_size, seq_size))  # [B, seq_size, seq_size]
+  padding = tf.ones_like(scores) * (-2**32 + 1)
+  scores = tf.where(masks, scores, padding)  # [B, seq_size, seq_size]
+
+  # Normalize
+  scores = tf.nn.softmax(scores)  # (B, seq_size, seq_size)
+  att_res_net = tf.matmul(scores, value_net)  # [B, seq_size, dim]
+  return att_res_net
+
+
+def dnn_net(net, dnn_units, name):
+  with tf.variable_scope(name_or_scope=name, reuse=tf.AUTO_REUSE):
+    for idx, units in enumerate(dnn_units):
+      net = tf.layers.dense(
+          net, units=units, activation=tf.nn.relu, name='%s_%d' % (name, idx))
+  return net
+
+
+def add_and_norm(net_1, net_2, emb_dim):
+  net = tf.add(net_1, net_2)
+  layer = layer_norm.LayerNormalization(emb_dim)
+  net = layer(net)
+  return net
+
+
+def multi_head_att_net(id_cols, head_count, emb_dim, seq_len, seq_size):
+  multi_head_attention_res = []
+  part_cols_emb_dim = int(math.ceil(emb_dim / head_count))
+  for start_idx in range(0, emb_dim, part_cols_emb_dim):
+    if start_idx + part_cols_emb_dim > emb_dim:
+      part_cols_emb_dim = emb_dim - start_idx
+    part_id_col = tf.slice(id_cols, [0, 0, start_idx],
+                           [-1, -1, part_cols_emb_dim])
+    part_attention_net = attention_net(
+        part_id_col,
+        part_cols_emb_dim,
+        seq_len,
+        seq_size,
+        name='multi_head_%d' % start_idx)
+    multi_head_attention_res.append(part_attention_net)
+  multi_head_attention_res_net = tf.concat(multi_head_attention_res, axis=2)
+  multi_head_attention_res_net = dnn_net(
+      multi_head_attention_res_net, [emb_dim], name='multi_head_attention')
+  return multi_head_attention_res_net
+
+
+def self_attention(deep_fea, seq_size, head_count):
+  cur_id, hist_id_col, seq_len = deep_fea['key'], deep_fea[
+      'hist_seq_emb'], deep_fea['hist_seq_len']
+
+  cur_batch_max_seq_len = tf.shape(hist_id_col)[1]
+
+  hist_id_col = tf.cond(
+      tf.constant(seq_size) > cur_batch_max_seq_len, lambda: tf.pad(
+          hist_id_col, [[0, 0], [0, seq_size - cur_batch_max_seq_len - 1],
+                        [0, 0]], 'CONSTANT'),
+      lambda: tf.slice(hist_id_col, [0, 0, 0], [-1, seq_size - 1, -1]))
+  all_ids = tf.concat([hist_id_col, tf.expand_dims(cur_id, 1)],
+                      axis=1)  # (B, seq_size, emb_dim)
+
+  emb_dim = int(all_ids.shape[2])
+  att_net = multi_head_att_net(all_ids, head_count, emb_dim, seq_len,
+                               seq_size)
+
+  tmp_net = add_and_norm(all_ids, att_net, emb_dim)
+  feed_forward_net = dnn_net(tmp_net, [emb_dim], 'feed_forward_net')
+  net = add_and_norm(tmp_net, feed_forward_net, emb_dim)
+  atten_output = tf.reshape(net, [-1, seq_size * emb_dim])
+  return atten_output
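The one subtle piece of bookkeeping in self_attention is the seq_size - 1 rule: the history is padded (or clipped) to seq_size - 1 steps so that appending the target item always yields exactly seq_size positions, and attention_net rebuilds the same split when it masks padded steps. A standalone sketch of the rule with made-up sizes (B=2, T=3, D=1, seq_size=5):

    import tensorflow as tf

    if tf.__version__ >= '2.0':
      tf = tf.compat.v1

    with tf.Graph().as_default():
      seq_size = 5
      hist = tf.zeros([2, 3, 1])  # (B, T, D), T = 3 for this batch
      cur = tf.zeros([2, 1])      # (B, D)
      cur_max_len = tf.shape(hist)[1]
      # T < seq_size here, so the history is padded out to seq_size - 1 = 4
      # steps; a batch with T >= seq_size would instead be clipped to the
      # first seq_size - 1 steps.
      hist = tf.cond(
          tf.constant(seq_size) > cur_max_len,
          lambda: tf.pad(hist,
                         [[0, 0], [0, seq_size - cur_max_len - 1], [0, 0]]),
          lambda: tf.slice(hist, [0, 0, 0], [-1, seq_size - 1, -1]))
      all_ids = tf.concat([hist, tf.expand_dims(cur, 1)], axis=1)
      # all_ids always carries seq_size positions at run time, so the
      # flattened output has a fixed width of seq_size * D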
diff --git a/easy_rec/python/protos/feature_config.proto b/easy_rec/python/protos/feature_config.proto
index 18ef12ea1..29ca4179d 100644
--- a/easy_rec/python/protos/feature_config.proto
+++ b/easy_rec/python/protos/feature_config.proto
@@ -119,4 +119,7 @@ message SeqAttGroupConfig {
   optional bool tf_summary = 3 [default = false];
   optional DNN seq_dnn = 4;
   optional bool allow_key_search = 5 [default = false];
+  optional string seq_model = 6 [default = 'target_attention'];
+  optional uint32 multi_head_size = 7 [default = 4];
+  optional uint32 seq_len = 8 [default = 50];
 }
diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py
index b334ef506..da64a8574 100644
--- a/easy_rec/python/test/train_eval_test.py
+++ b/easy_rec/python/test/train_eval_test.py
@@ -581,6 +581,12 @@ def test_sequence_wide_and_deep(self):
         'samples/model_config/wide_and_deep_on_sequence_feature_taobao.config',
         self._test_dir)
 
+  def test_sequence_self_attention_dbmtl(self):
+    self._success = test_utils.test_single_train_eval(
+        'samples/model_config/dbmtl_on_sequence_feature_self_attention_taobao.config',
+        self._test_dir)
+    self.assertTrue(self._success)
+
 
 if __name__ == '__main__':
   tf.test.main()
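The new fields are backward compatible: a config that never sets seq_model keeps the existing target-attention behavior. A quick check of the defaults, assuming the protos have been recompiled:

    from easy_rec.python.protos.feature_config_pb2 import SeqAttGroupConfig

    cfg = SeqAttGroupConfig()
    print(cfg.seq_model)        # 'target_attention': old configs are unaffected
    cfg.seq_model = 'self_attention'
    print(cfg.multi_head_size)  # 4 attention heads by default
    print(cfg.seq_len)          # 50 total positions; history uses seq_len - 1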
diff --git a/samples/model_config/dbmtl_on_sequence_feature_self_attention_taobao.config b/samples/model_config/dbmtl_on_sequence_feature_self_attention_taobao.config
new file mode 100644
index 000000000..021e09128
--- /dev/null
+++ b/samples/model_config/dbmtl_on_sequence_feature_self_attention_taobao.config
@@ -0,0 +1,298 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/dbmtl_taobao_ckpt"
+
+train_config {
+  optimizer_config {
+    adam_optimizer {
+      learning_rate {
+        exponential_decay_learning_rate {
+          initial_learning_rate: 0.001
+          decay_steps: 1000
+          decay_factor: 0.5
+          min_learning_rate: 1e-07
+        }
+      }
+    }
+    use_moving_average: false
+  }
+  num_steps: 5000
+  sync_replicas: true
+  save_checkpoints_steps: 100
+  log_step_count_steps: 100
+}
+eval_config {
+  metrics_set {
+    auc {
+    }
+  }
+}
+data_config {
+  batch_size: 4096
+  label_fields: "clk"
+  label_fields: "buy"
+  prefetch_size: 32
+  input_type: CSVInput
+  input_fields {
+    input_name: "clk"
+    input_type: INT32
+  }
+  input_fields {
+    input_name: "buy"
+    input_type: INT32
+  }
+  input_fields {
+    input_name: "pid"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "adgroup_id"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "cate_id"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "campaign_id"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "customer"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "brand"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "user_id"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "cms_segid"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "cms_group_id"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "final_gender_code"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "age_level"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "pvalue_level"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "shopping_level"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "occupation"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "new_user_class_level"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "tag_category_list"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "tag_brand_list"
+    input_type: STRING
+  }
+  input_fields {
+    input_name: "price"
+    input_type: INT32
+  }
+}
+feature_configs {
+  input_names: "pid"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 10
+}
+feature_configs {
+  input_names: "adgroup_id"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 100000
+}
+feature_configs {
+  input_names: "cate_id"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 10000
+}
+feature_configs {
+  input_names: "campaign_id"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 100000
+}
+feature_configs {
+  input_names: "customer"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 100000
+}
+feature_configs {
+  input_names: "brand"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 100000
+}
+feature_configs {
+  input_names: "user_id"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 100000
+}
+feature_configs {
+  input_names: "cms_segid"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 100
+}
+feature_configs {
+  input_names: "cms_group_id"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 100
+}
+feature_configs {
+  input_names: "final_gender_code"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 10
+}
+feature_configs {
+  input_names: "age_level"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 10
+}
+feature_configs {
+  input_names: "pvalue_level"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 10
+}
+feature_configs {
+  input_names: "shopping_level"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 10
+}
+feature_configs {
+  input_names: "occupation"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 10
+}
+feature_configs {
+  input_names: "new_user_class_level"
+  feature_type: IdFeature
+  embedding_dim: 16
+  hash_bucket_size: 10
+}
+feature_configs {
+  input_names: "tag_category_list"
+  feature_type: SequenceFeature
+  embedding_dim: 16
+  hash_bucket_size: 100000
+  separator: "|"
+}
+feature_configs {
+  input_names: "tag_brand_list"
+  feature_type: SequenceFeature
+  embedding_dim: 16
+  hash_bucket_size: 100000
+  separator: "|"
+}
+feature_configs {
+  input_names: "price"
+  feature_type: IdFeature
+  embedding_dim: 16
+  num_buckets: 50
+}
+model_config {
+  model_class: "DBMTL"
+  feature_groups {
+    group_name: "all"
+    feature_names: "user_id"
+    feature_names: "cms_segid"
+    feature_names: "cms_group_id"
+    feature_names: "age_level"
+    feature_names: "pvalue_level"
+    feature_names: "shopping_level"
+    feature_names: "occupation"
+    feature_names: "new_user_class_level"
+    feature_names: "adgroup_id"
+    feature_names: "cate_id"
+    feature_names: "campaign_id"
+    feature_names: "customer"
+    feature_names: "brand"
+    feature_names: "price"
+    feature_names: "pid"
+    wide_deep: DEEP
+    sequence_features: {
+      group_name: "seq_fea"
+      tf_summary: false
+      seq_model: "self_attention"
+      seq_len: 50
+      seq_att_map: {
+        key: "brand"
+        key: "cate_id"
+        hist_seq: "tag_brand_list"
+        hist_seq: "tag_category_list"
+      }
+    }
+  }
+  dbmtl {
+    bottom_dnn {
+      hidden_units: [1024, 512, 256]
+    }
+    task_towers {
+      tower_name: "ctr"
+      label_name: "clk"
+      loss_type: CLASSIFICATION
+      metrics_set: {
+        auc {}
+      }
+      dnn {
+        hidden_units: [256, 128, 64, 32]
+      }
+      relation_dnn {
+        hidden_units: [32]
+      }
+      weight: 1.0
+    }
+    task_towers {
+      tower_name: "cvr"
+      label_name: "buy"
+      loss_type: CLASSIFICATION
+      metrics_set: {
+        auc {}
+      }
+      dnn {
+        hidden_units: [256, 128, 64, 32]
+      }
+      relation_tower_names: ["ctr"]
+      relation_dnn {
+        hidden_units: [32]
+      }
+      weight: 1.0
+    }
+    l2_regularization: 1e-6
+  }
+  embedding_regularization: 5e-6
+}
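The config above is exercised end to end by the new test_sequence_self_attention_dbmtl case in train_eval_test.py. One sizing note: with seq_len: 50 the flattened self-attention output contributes seq_len * emb_dim columns per sequence group to the DBMTL bottom_dnn input, so raising seq_len (or the sequence embedding dims) widens the input of the first 1024-unit layer accordingly.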