diff --git a/docs/images/models/DAT.png b/docs/images/models/DAT.png new file mode 100644 index 000000000..7b3b56538 Binary files /dev/null and b/docs/images/models/DAT.png differ diff --git a/docs/source/models/dssm_derivatives.md b/docs/source/models/dssm_derivatives.md index 9271beb82..986bdbcb4 100644 --- a/docs/source/models/dssm_derivatives.md +++ b/docs/source/models/dssm_derivatives.md @@ -1,6 +1,6 @@ # DSSM衍生扩展模型 -## DSSM + SENet +## 1. DSSM + SENet ### 简介 @@ -84,7 +84,7 @@ model_config:{ [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507) -## 并行DSSM +## 2. 并行DSSM 在召回中,我们希望尽可能把不同的特征进行交叉融合,以便提取到隐藏的信息。而不同的特征提取器侧重点不尽相同,比如MLP是隐式特征交叉,FM和DCN都属于显式、有限阶特征交叉, CIN可以实现vector-wise显式交叉。因此可以让信息经由不同的通道向塔顶流动,每种通道各有所长,相互取长补短。最终将各通道得到的Embedding聚合成最终的Embedding,与对侧交互,从而提升召回的效果。 @@ -93,3 +93,54 @@ model_config:{ ### 示例Config [parallel_dssm_on_taobao_backbone.config](https://github.com/alibaba/EasyRec/tree/master/samples/model_config/parallel_dssm_on_taobao_backbone.config) + +## 3. 对偶增强双塔 Dual Augmented Two-Tower + +双塔模型对用户和物品的特征分开进行建模,在对特征进行了多层神经网络的整合后进行交互。由于网络的整合可能会损失一部分信息,因此过晚的user/item交互不利于模型的学习,这也是DSSM的一个主要的弊端。在对偶增强双塔算法中,作者设计了一个辅助向量,通过学习对user和item进行增强,使得user和item的交互更加有效。 + +![dat](../../images/models/dat.png) + +### 配置说明 + +作为DSSM的衍生模型,DAT的配置与DSSM类似,在model_config中除了user和item的feature_group外,还需要增加user_id_augment feature_group和item_id_augment feature_group, 作为模型输入的增强向量。 +两塔各自的DNN最后一层输出维度需要和user_id_augment的embedding维度保持一致,以便构造AMM损失(Adaptive-Mimic Mechanism)。 + +``` + feature_groups: { + group_name: 'user_id_augment' + feature_names: 'user_id' + wide_deep:DEEP + } + feature_groups: { + group_name: 'item_id_augment' + feature_names: 'adgroup_id' + wide_deep:DEEP + } + + dat { + user_tower { + id: "user_id" + dnn { + hidden_units: [ 128, 32] + # dropout_ratio : [0.1, 0.1, 0.1, 0.1] + } + } + item_tower { + id: "adgroup_id" + dnn { + hidden_units: [ 128, 32] + } + } + simi_func: COSINE + temperature: 0.01 + l2_regularization: 1e-6 + } +``` + +### 示例Config + +[dat_on_taobao.config](https://github.com/alibaba/EasyRec/tree/master/samples/model_config/dat_on_taobao.config) + +### 参考论文 + +[A Dual Augmented Two-tower Model for Online Large-scale Recommendation](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_4.pdf) diff --git a/easy_rec/python/model/dat.py b/easy_rec/python/model/dat.py index fceb8cee8..5c312299c 100644 --- a/easy_rec/python/model/dat.py +++ b/easy_rec/python/model/dat.py @@ -25,6 +25,15 @@ def __init__(self, is_training) assert self._model_config.WhichOneof('model') == 'dat', \ 'invalid model config: %s' % self._model_config.WhichOneof('model') + + feature_group_names = [ + fg.group_name for fg in self._model_config.feature_groups + ] + assert 'user' in feature_group_names, 'user feature group not found' + assert 'item' in feature_group_names, 'item feature group not found' + assert 'user_id_augment' in feature_group_names, 'user_id_augment feature group not found' + assert 'item_id_augment' in feature_group_names, 'item_id_augment feature group not found' + self._model_config = self._model_config.dat assert isinstance(self._model_config, DATConfig) diff --git a/samples/model_config/dat_on_taobao.config b/samples/model_config/dat_on_taobao.config new file mode 100644 index 000000000..e972149d7 --- /dev/null +++ b/samples/model_config/dat_on_taobao.config @@ -0,0 +1,322 @@ +train_input_path: "data/test/tb_data/taobao_train_data" +eval_input_path: "data/test/tb_data/taobao_test_data" +model_dir: "experiments/dat_taobao_ckpt" + +train_config { + log_step_count_steps: 200 + optimizer_config: { + adam_optimizer: { + learning_rate: { + exponential_decay_learning_rate { + initial_learning_rate: 0.001 + decay_steps: 1000 + decay_factor: 0.5 + min_learning_rate: 0.00001 + } + } + } + use_moving_average: false + } + save_checkpoints_steps: 4000 + sync_replicas: false + num_steps: 100000 +} + +eval_config { + metrics_set: { + recall_at_topk { + topk: 50 + } + } + metrics_set: { + recall_at_topk { + topk: 10 + } + } + metrics_set: { + recall_at_topk { + topk: 5 + } + } + metrics_set: { + recall_at_topk { + topk: 1 + } + } +} + +data_config { + input_fields { + input_name:'clk' + input_type: INT32 + } + input_fields { + input_name:'buy' + input_type: INT32 + } + input_fields { + input_name: 'pid' + input_type: STRING + } + input_fields { + input_name: 'adgroup_id' + input_type: STRING + } + input_fields { + input_name: 'cate_id' + input_type: STRING + } + input_fields { + input_name: 'campaign_id' + input_type: STRING + } + input_fields { + input_name: 'customer' + input_type: STRING + } + input_fields { + input_name: 'brand' + input_type: STRING + } + input_fields { + input_name: 'user_id' + input_type: STRING + } + input_fields { + input_name: 'cms_segid' + input_type: STRING + } + input_fields { + input_name: 'cms_group_id' + input_type: STRING + } + input_fields { + input_name: 'final_gender_code' + input_type: STRING + } + input_fields { + input_name: 'age_level' + input_type: STRING + } + input_fields { + input_name: 'pvalue_level' + input_type: STRING + } + input_fields { + input_name: 'shopping_level' + input_type: STRING + } + input_fields { + input_name: 'occupation' + input_type: STRING + } + input_fields { + input_name: 'new_user_class_level' + input_type: STRING + } + input_fields { + input_name: 'tag_category_list' + input_type: STRING + } + input_fields { + input_name: 'tag_brand_list' + input_type: STRING + } + input_fields { + input_name: 'price' + input_type: INT32 + } + + label_fields: 'clk' + batch_size: 512 + num_epochs: 10000 + prefetch_size: 32 + input_type: CSVInput + + negative_sampler { + input_path: 'data/test/tb_data/taobao_ad_feature_gl' + num_sample: 2048 + num_eval_sample: 2048 + attr_fields: 'adgroup_id' + attr_fields: 'cate_id' + attr_fields: 'campaign_id' + attr_fields: 'customer' + attr_fields: 'brand' + item_id_field: 'adgroup_id' + } +} + +feature_config: { + features: { + input_names: 'pid' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'adgroup_id' + feature_type: IdFeature + embedding_dim: 32 + hash_bucket_size: 100000 + } + features: { + input_names: 'cate_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10000 + } + features: { + input_names: 'campaign_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: 'customer' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: 'brand' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 + } + features: { + input_names: 'user_id' + feature_type: IdFeature + embedding_dim: 32 + hash_bucket_size: 100000 + } + features: { + input_names: 'cms_segid' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 + } + features: { + input_names: 'cms_group_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 + } + features: { + input_names: 'final_gender_code' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'age_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'pvalue_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'shopping_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'occupation' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'new_user_class_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 + } + features: { + input_names: 'tag_category_list' + feature_type: TagFeature + separator: '|' + hash_bucket_size: 100000 + embedding_dim: 16 + } + features: { + input_names: 'tag_brand_list' + feature_type: TagFeature + separator: '|' + hash_bucket_size: 100000 + embedding_dim: 16 + } + features: { + input_names: 'price' + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 50 + } +} +model_config:{ + model_class: "DAT" + feature_groups: { + group_name: 'user' + feature_names: 'user_id' + feature_names: 'cms_segid' + feature_names: 'cms_group_id' + feature_names: 'age_level' + feature_names: 'pvalue_level' + feature_names: 'shopping_level' + feature_names: 'occupation' + feature_names: 'new_user_class_level' + feature_names: 'tag_category_list' + feature_names: 'tag_brand_list' + wide_deep:DEEP + } + feature_groups: { + group_name: "item" + feature_names: 'adgroup_id' + feature_names: 'cate_id' + feature_names: 'campaign_id' + feature_names: 'customer' + feature_names: 'brand' + #feature_names: 'price' + #feature_names: 'pid' + wide_deep:DEEP + } + feature_groups: { + group_name: 'user_id_augment' + feature_names: 'user_id' + wide_deep:DEEP + } + feature_groups: { + group_name: 'item_id_augment' + feature_names: 'adgroup_id' + wide_deep:DEEP + } + + dat { + user_tower { + id: "user_id" + dnn { + hidden_units: [ 128, 32] + # dropout_ratio : [0.1, 0.1, 0.1, 0.1] + } + } + item_tower { + id: "adgroup_id" + dnn { + hidden_units: [ 128, 32] + } + } + simi_func: COSINE + temperature: 0.01 + l2_regularization: 1e-6 + } + embedding_regularization: 5e-5 + loss_type: SOFTMAX_CROSS_ENTROPY +} + +export_config { +}