[feature] Dual Augmented Two-Towers #503

Merged 5 commits, Nov 15, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -117,7 +117,7 @@ If EasyRec is useful for your research, please cite:

### Join Us

- DingDing Group: 32260796. click [this url](https://page.dingtalk.com/wow/z/dingtalk/simple/ddhomedownload?action=joingroup&code=v1,k1,MwaiOIY1Tb2W+onmBBumO7sQsdDOYjBmv6FXC6wTGns=&_dt_no_comment=1&origin=11#/ ) or scan QrCode to join![dinggroup1.png](docs/images/qrcode/dinggroup1.png)
- DingDing Group: 32260796. click [this url](https://page.dingtalk.com/wow/z/dingtalk/simple/ddhomedownload?action=joingroup&code=v1,k1,MwaiOIY1Tb2W+onmBBumO7sQsdDOYjBmv6FXC6wTGns=&_dt_no_comment=1&origin=11#/) or scan QrCode to join![dinggroup1.png](docs/images/qrcode/dinggroup1.png)
- DingDing Group2: 37930014162, click [this url](https://page.dingtalk.com/wow/z/dingtalk/simple/ddhomedownload?action=joingroup&code=v1,k1,1ppFWEXXNPyxUClHh77gCmpfB+JcPhbFv6FXC6wTGns=&_dt_no_comment=1&origin=11#/) or scan QrCode to join![dinggroup2.png](docs/images/qrcode/dinggroup2.png)
- Email Group: [email protected].

Binary file added docs/images/models/DAT.png
55 changes: 53 additions & 2 deletions docs/source/models/dssm_derivatives.md
@@ -1,6 +1,6 @@
# DSSM Derivative Models

## DSSM + SENet
## 1. DSSM + SENet

### Introduction

@@ -84,7 +84,7 @@ model_config:{

[Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507)

## Parallel DSSM
## 2. Parallel DSSM

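In retrieval, we want to cross and fuse different features as much as possible in order to extract hidden information. Different feature extractors have different emphases: an MLP performs implicit feature crossing, FM and DCN both perform explicit, bounded-order feature crossing, and CIN performs explicit vector-wise crossing. Information can therefore flow toward the top of the tower through several channels, each with its own strengths, so that the channels complement one another. The embeddings produced by the channels are finally aggregated into a single embedding that interacts with the opposite tower, which improves retrieval quality.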

@@ -93,3 +93,54 @@ model_config:{
### Example Config

[parallel_dssm_on_taobao_backbone.config](https://github.com/alibaba/EasyRec/tree/master/samples/model_config/parallel_dssm_on_taobao_backbone.config)

## 3. Dual Augmented Two-Tower

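A two-tower model encodes user and item features separately and lets the two sides interact only after each has been integrated by a multi-layer neural network. Because that integration may lose some information, such late user/item interaction hampers learning, which is a major drawback of DSSM. In the Dual Augmented Two-Tower algorithm, the authors introduce an auxiliary vector that is learned to augment the user and the item, making the user/item interaction more effective.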

![dat](../../images/models/dat.png)
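
As a rough illustration of the idea (not the EasyRec implementation), the sketch below shows how an augment vector might be concatenated with a tower's feature embedding before entering the tower DNN. The function name `augment_tower_input` and the toy values are hypothetical.

```python
def augment_tower_input(feature_emb, augment_vec):
    """Concatenate a tower's feature embedding with its augment vector."""
    return list(feature_emb) + list(augment_vec)


# Toy values, chosen only for illustration.
user_features = [0.2, -0.1, 0.7]  # embedding of the 'user' feature group
user_augment = [0.5, 0.5]         # embedding of 'user_id_augment'

# The concatenated vector is what the user tower DNN would consume.
user_tower_input = augment_tower_input(user_features, user_augment)
print(len(user_tower_input))
```

The item tower would be handled symmetrically, with the `item` features and the `item_id_augment` vector.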

### Configuration

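As a derivative of DSSM, DAT is configured much like DSSM. In model_config, in addition to the user and item feature_groups, you also need to add a user_id_augment feature_group and an item_id_augment feature_group, which serve as the augment vectors fed into the model.
The output dimension of the last DNN layer in each tower must match the embedding dimension of user_id_augment so that the AMM (Adaptive-Mimic Mechanism) loss can be constructed.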

```
feature_groups: {
  group_name: 'user_id_augment'
  feature_names: 'user_id'
  wide_deep:DEEP
}
feature_groups: {
  group_name: 'item_id_augment'
  feature_names: 'adgroup_id'
  wide_deep:DEEP
}

dat {
  user_tower {
    id: "user_id"
    dnn {
      hidden_units: [ 128, 32]
      # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
    }
  }
  item_tower {
    id: "adgroup_id"
    dnn {
      hidden_units: [ 128, 32]
    }
  }
  simi_func: COSINE
  temperature: 0.01
  l2_regularization: 1e-6
}
```
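
To give a feel for what `simi_func: COSINE` combined with `temperature: 0.01` does, here is a minimal pure-Python sketch: cosine similarities are divided by the temperature and passed through a softmax over candidate items. The helper names (`l2_normalize`, `scaled_cosine_logits`, `softmax`) and the toy vectors are assumptions for illustration and are not taken from EasyRec.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]


def scaled_cosine_logits(user_emb, item_embs, temperature):
    """Cosine similarity of one user against each item, divided by temperature."""
    u = l2_normalize(user_emb)
    return [
        sum(a * b for a, b in zip(u, l2_normalize(item))) / temperature
        for item in item_embs
    ]


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


user = [1.0, 0.0]
items = [[0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]

# A small temperature sharpens the distribution toward the closest item.
sharp = softmax(scaled_cosine_logits(user, items, temperature=0.01))
flat = softmax(scaled_cosine_logits(user, items, temperature=1.0))
print(sharp[0] > flat[0])
```

Because cosine similarity is bounded in [-1, 1], a small temperature is one common way to widen the logit range before the softmax.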

### Example Config

[dat_on_taobao.config](https://github.com/alibaba/EasyRec/tree/master/samples/model_config/dat_on_taobao.config)

### Reference Paper

[A Dual Augmented Two-tower Model for Online Large-scale Recommendation](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_4.pdf)
2 changes: 1 addition & 1 deletion easy_rec/python/compat/early_stopping.py
@@ -21,9 +21,9 @@
import os
import threading
import time
from distutils.version import LooseVersion

import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
138 changes: 138 additions & 0 deletions easy_rec/python/model/dat.py
@@ -0,0 +1,138 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf

from easy_rec.python.layers import dnn
from easy_rec.python.model.match_model import MatchModel
from easy_rec.python.protos.dat_pb2 import DAT as DATConfig
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.utils.proto_util import copy_obj

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class DAT(MatchModel):
  """Dual Augmented Two-tower Model."""

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(DAT, self).__init__(model_config, feature_configs, features, labels,
                              is_training)
    assert self._model_config.WhichOneof('model') == 'dat', \
        'invalid model config: %s' % self._model_config.WhichOneof('model')

    feature_group_names = [
        fg.group_name for fg in self._model_config.feature_groups
    ]
    assert 'user' in feature_group_names, 'user feature group not found'
    assert 'item' in feature_group_names, 'item feature group not found'
    assert 'user_id_augment' in feature_group_names, 'user_id_augment feature group not found'
    assert 'item_id_augment' in feature_group_names, 'item_id_augment feature group not found'

    self._model_config = self._model_config.dat
    assert isinstance(self._model_config, DATConfig)

    self.user_tower = copy_obj(self._model_config.user_tower)
    self.user_deep_feature, _ = self._input_layer(self._feature_dict, 'user')
    self.user_augmented_vec, _ = self._input_layer(self._feature_dict,
                                                   'user_id_augment')

    self.item_tower = copy_obj(self._model_config.item_tower)
    self.item_deep_feature, _ = self._input_layer(self._feature_dict, 'item')
    self.item_augmented_vec, _ = self._input_layer(self._feature_dict,
                                                   'item_id_augment')

    self._user_tower_emb = None
    self._item_tower_emb = None

  def build_predict_graph(self):
    num_user_dnn_layer = len(self.user_tower.dnn.hidden_units)
    last_user_hidden = self.user_tower.dnn.hidden_units.pop()
    user_dnn = dnn.DNN(self.user_tower.dnn, self._l2_reg, 'user_dnn',
                       self._is_training)

    user_tower_feature = tf.concat(
        [self.user_deep_feature, self.user_augmented_vec], axis=-1)
    user_tower_emb = user_dnn(user_tower_feature)
    user_tower_emb = tf.layers.dense(
        inputs=user_tower_emb,
        units=last_user_hidden,
        kernel_regularizer=self._l2_reg,
        name='user_dnn/dnn_%d' % (num_user_dnn_layer - 1))

    num_item_dnn_layer = len(self.item_tower.dnn.hidden_units)
    last_item_hidden = self.item_tower.dnn.hidden_units.pop()
    item_dnn = dnn.DNN(self.item_tower.dnn, self._l2_reg, 'item_dnn',
                       self._is_training)

    item_tower_feature = tf.concat(
        [self.item_deep_feature, self.item_augmented_vec], axis=-1)
    item_tower_emb = item_dnn(item_tower_feature)
    item_tower_emb = tf.layers.dense(
        inputs=item_tower_emb,
        units=last_item_hidden,
        kernel_regularizer=self._l2_reg,
        name='item_dnn/dnn_%d' % (num_item_dnn_layer - 1))

    user_tower_emb = self.norm(user_tower_emb)
    item_tower_emb = self.norm(item_tower_emb)
    temperature = self._model_config.temperature

    y_pred = self.sim(user_tower_emb, item_tower_emb) / temperature

    if self._is_point_wise:
      raise ValueError('Currently DAT model only supports list wise mode.')

    if self._loss_type == LossType.CLASSIFICATION:
      raise ValueError(
          'Currently DAT model only supports SOFTMAX_CROSS_ENTROPY loss.')
    elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
      y_pred = self._mask_in_batch(y_pred)
      self._prediction_dict['logits'] = y_pred
      self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
    else:
      self._prediction_dict['y'] = y_pred

    self._prediction_dict['user_tower_emb'] = user_tower_emb
    self._prediction_dict['item_tower_emb'] = item_tower_emb
    self._prediction_dict['user_emb'] = tf.reduce_join(
        tf.as_string(user_tower_emb), axis=-1, separator=',')
    self._prediction_dict['item_emb'] = tf.reduce_join(
        tf.as_string(item_tower_emb), axis=-1, separator=',')

    augmented_p_u = tf.stop_gradient(user_tower_emb)
    augmented_p_i = tf.stop_gradient(item_tower_emb)

    self._prediction_dict['augmented_p_u'] = augmented_p_u
    self._prediction_dict['augmented_p_i'] = augmented_p_i

    self._prediction_dict['augmented_a_u'] = self.user_augmented_vec
    self._prediction_dict['augmented_a_i'] = self.item_augmented_vec

    return self._prediction_dict

  def get_outputs(self):
    if self._loss_type == LossType.CLASSIFICATION:
      raise ValueError(
          'Currently DAT model only supports SOFTMAX_CROSS_ENTROPY loss.')
    elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
      self._prediction_dict['logits'] = tf.squeeze(
          self._prediction_dict['logits'], axis=-1)
      self._prediction_dict['probs'] = tf.nn.sigmoid(
          self._prediction_dict['logits'])
      return [
          'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
          'item_tower_emb', 'augmented_p_u', 'augmented_p_i', 'augmented_a_u',
          'augmented_a_i'
      ]
    else:
      raise ValueError('invalid loss type: %s' % str(self._loss_type))

  def build_output_dict(self):
    output_dict = super(DAT, self).build_output_dict()
    return output_dict
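
A side note on the `tf.__version__ >= '2.0'` guard near the top of dat.py: it compares version strings lexicographically, which behaves as intended for current TensorFlow 1.x/2.x releases but would misorder a hypothetical `'10.0'`. A minimal sketch of a numeric comparison follows; `version_tuple` is a hypothetical helper, not part of EasyRec or TensorFlow.

```python
import re


def version_tuple(version):
    """Parse leading numeric components, e.g. '2.15.0-rc1' -> (2, 15, 0)."""
    parts = []
    for piece in version.split('.'):
        match = re.match(r'\d+', piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


# String comparison orders '10.0' before '2.0' because '1' < '2'.
string_result = '10.0' >= '2.0'
# Tuple comparison orders the versions numerically.
tuple_result = version_tuple('10.0') >= version_tuple('2.0')
print(string_result, tuple_result)
```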
13 changes: 13 additions & 0 deletions easy_rec/python/model/match_model.py
@@ -227,6 +227,19 @@ def _build_list_wise_loss_graph(self):
      # if pos_simi < 0, produce loss
      reg_pos_loss = tf.nn.relu(-pos_simi)
      self._loss_dict['reg_pos_loss'] = tf.reduce_mean(reg_pos_loss)

      # the AMM loss for DAT model
      if all([
          k in self._prediction_dict.keys() for k in
          ['augmented_p_u', 'augmented_p_i', 'augmented_a_u', 'augmented_a_i']
      ]):
        self._loss_dict['amm_loss_u'] = tf.reduce_mean(
            tf.square(self._prediction_dict['augmented_a_u'] -
                      self._prediction_dict['augmented_p_i'][:batch_size]))
        self._loss_dict['amm_loss_i'] = tf.reduce_mean(
            tf.square(self._prediction_dict['augmented_a_i'][:batch_size] -
                      self._prediction_dict['augmented_p_u']))

    else:
      raise ValueError('invalid loss type: %s' % str(self._loss_type))
    return self._loss_dict
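
The AMM terms in this hunk reduce to a mean squared difference between each side's augment vector and the opposite tower's (gradient-stopped) embedding for the positive pair. A minimal pure-Python sketch of that computation, under the assumption that the first `batch_size` item rows are the positives; the function name `amm_loss` and the toy values are hypothetical.

```python
def amm_loss(augment_vecs, target_embs):
    """Mean squared difference between augment vectors and target embeddings.

    target_embs play the role of the stop-gradient tower outputs, so they are
    treated as constants here.
    """
    total, count = 0.0, 0
    for a, p in zip(augment_vecs, target_embs):
        if len(a) != len(p):
            raise ValueError('augment and tower embedding dims must match')
        for x, y in zip(a, p):
            total += (x - y) ** 2
            count += 1
    return total / count


batch_size = 2
a_u = [[1.0, 0.0], [0.0, 1.0]]              # user augment vectors
p_i = [[0.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # item tower outputs (last row: a sampled negative)

# Only the first batch_size item rows (the positives) are compared.
loss_u = amm_loss(a_u, p_i[:batch_size])
print(loss_u)
```

The dimension check mirrors the documentation's requirement that the tower output size match the augment embedding size; without it the element-wise subtraction would not be defined.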
16 changes: 12 additions & 4 deletions easy_rec/python/model/rank_model.py
@@ -158,7 +158,9 @@ def build_rtp_output_dict(self):
'failed to build RTP rank_predict output: classification model ' +
"expect 'probs' prediction, which is not found. Please check if" +
' build_predict_graph() is called.')
elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
elif loss_types & {
LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
}:
if 'y' in self._prediction_dict:
forwarded = self._prediction_dict['y']
else:
@@ -379,7 +381,9 @@ def _build_metric_impl(self,
metric.recall_at_topk.topk)
elif metric.WhichOneof('metric') == 'mean_absolute_error':
label = tf.to_float(self._labels[label_name])
if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
if loss_type & {
LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
}:
metric_dict['mean_absolute_error' +
suffix] = metrics_tf.mean_absolute_error(
label, self._prediction_dict['y' + suffix])
@@ -391,7 +395,9 @@ def _build_metric_impl(self,
assert False, 'mean_absolute_error is not supported for this model'
elif metric.WhichOneof('metric') == 'mean_squared_error':
label = tf.to_float(self._labels[label_name])
if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
if loss_type & {
LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
}:
metric_dict['mean_squared_error' +
suffix] = metrics_tf.mean_squared_error(
label, self._prediction_dict['y' + suffix])
@@ -403,7 +409,9 @@ def _build_metric_impl(self,
assert False, 'mean_squared_error is not supported for this model'
elif metric.WhichOneof('metric') == 'root_mean_squared_error':
label = tf.to_float(self._labels[label_name])
if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
if loss_type & {
LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS
}:
metric_dict['root_mean_squared_error' +
suffix] = metrics_tf.root_mean_squared_error(
label, self._prediction_dict['y' + suffix])
21 changes: 21 additions & 0 deletions easy_rec/python/protos/dat.proto
@@ -0,0 +1,21 @@
syntax = "proto2";
package protos;

import "easy_rec/python/protos/dnn.proto";
import "easy_rec/python/protos/simi.proto";


message DATTower {
  required string id = 1;
  required DNN dnn = 2;
};


message DAT {
  required DATTower user_tower = 1;
  required DATTower item_tower = 2;
  required float l2_regularization = 3 [default = 1e-4];
  optional Similarity simi_func = 4 [default=COSINE];
  required bool ignore_in_batch_neg_sam = 5 [default = false];
  optional float temperature = 6 [default = 1.0];
}
2 changes: 2 additions & 0 deletions easy_rec/python/protos/easy_rec_model.proto
@@ -29,6 +29,7 @@ import "easy_rec/python/protos/tower.proto";
import "easy_rec/python/protos/pdn.proto";
import "easy_rec/python/protos/dssm_senet.proto";
import "easy_rec/python/protos/simi.proto";
import "easy_rec/python/protos/dat.proto";
// for input performance test
message DummyModel {
}
@@ -114,6 +115,7 @@ message EasyRecModel {
CoMetricLearningI2I metric_learning = 204;
PDN pdn = 205;
DSSM_SENet dssm_senet = 206;
DAT dat = 207;

MMoE mmoe = 301;
ESMM esmm = 302;
8 changes: 7 additions & 1 deletion easy_rec/python/test/train_eval_test.py
@@ -7,11 +7,11 @@
import threading
import time
import unittest
from distutils.version import LooseVersion

import numpy as np
import six
import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.platform import gfile

from easy_rec.python.main import predict
@@ -1286,6 +1286,12 @@ def test_xdeefm_backbone_on_taobao(self):
        self._test_dir)
    self.assertTrue(self._success)

  @unittest.skipIf(gl is None, 'graphlearn is not installed')
  def test_dat_on_taobao(self):
    self._success = test_utils.test_single_train_eval(
        'samples/model_config/dat_on_taobao.config', self._test_dir)
    self.assertTrue(self._success)


if __name__ == '__main__':
  tf.test.main()
2 changes: 1 addition & 1 deletion requirements/docs.txt
@@ -5,4 +5,4 @@ recommonmark==0.6.0
sphinx==5.1.1
sphinx_markdown_tables==0.0.17
sphinx_rtd_theme
tensorflow-probability==0.11.0
tensorflow-probability==0.11.0