Skip to content

Commit

Permalink
code style fix
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-gecheng committed Oct 11, 2024
1 parent 66dcbac commit 6d0bae6
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 45 deletions.
10 changes: 6 additions & 4 deletions docs/source/models/dssm_derivatives.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# DSSM衍生扩展模型
# DSSM衍生扩展模型

## DSSM + SENet

### 简介

在推荐场景中,往往存在多种用户特征和物品特征,特征类型各不相同,各种特征经过embedding层后进入双塔模型的DNN层进行训练,在部分场景中甚至还会引入多模态embedding特征, 如图像和文本的embedding。
然而各个特征对目标的影响不尽相同,有的特征重要性高,对模型整体表现影响大,有的特征则影响较小。因此当特征不断增多时,可以结合SENet自动学习每个特征的权重,增强重要信息传递到塔顶的能力。


![dssm+senet](../../images/models/dssm+senet.png)

### 配置说明
Expand Down Expand Up @@ -70,12 +70,14 @@ model_config:{
}
```

- senet参数配置:
- senet参数配置:
- num_squeeze_group: 每个特征embedding的分组个数, 默认为2
- reduction_ratio: 维度压缩比例, 默认为4

### 示例Config

[dssm_senet_on_taobao.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dssm_senet_on_taobao.config)

### 参考论文
[Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507)

[Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507)
36 changes: 18 additions & 18 deletions easy_rec/python/layers/senet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@


class SENet:
'''
Squeeze and Excite Network
"""Squeeze and Excite Network.
Input shape
- A list of 2D tensor with shape: ``(batch_size,embedding_size)``.
Expand All @@ -20,15 +19,20 @@ class SENet:
reduction_ratio: int, reduction ratio for squeeze.
l2_reg: float, l2 regularizer for embedding.
name: str, name of the layer.
"""

'''
def __init__(self, num_fields, num_squeeze_group, reduction_ratio, l2_reg, name='SENet'):
def __init__(self,
num_fields,
num_squeeze_group,
reduction_ratio,
l2_reg,
name='SENet'):
self.num_fields = num_fields
self.num_squeeze_group = num_squeeze_group
self.reduction_ratio = reduction_ratio
self._l2_reg = l2_reg
self._name = name

def __call__(self, inputs):
g = self.num_squeeze_group
f = self.num_fields
Expand All @@ -39,7 +43,6 @@ def __call__(self, inputs):
for input in inputs:
emb_size += int(input.shape[-1])


group_embs = [
tf.reshape(emb, [-1, g, int(emb.shape[-1]) // g]) for emb in inputs
]
Expand All @@ -50,24 +53,21 @@ def __call__(self, inputs):
squeezed.append(tf.reduce_mean(emb, axis=-1)) # [B, g]
z = tf.concat(squeezed, axis=1) # [bs, field_size * num_groups * 2]



reduced = tf.layers.dense(
inputs=z,
units=reduction_size,
kernel_regularizer=self._l2_reg,
activation='relu',
name='%s/reduce' % self._name)
inputs=z,
units=reduction_size,
kernel_regularizer=self._l2_reg,
activation='relu',
name='%s/reduce' % self._name)

excited_weights = tf.layers.dense(
inputs=reduced,
units=emb_size,
kernel_initializer='glorot_normal',
units=emb_size,
kernel_initializer='glorot_normal',
name='%s/excite' % self._name)


# Re-weight
inputs = tf.concat(inputs, axis=-1)
output = inputs * excited_weights

return output
return output
46 changes: 24 additions & 22 deletions easy_rec/python/model/dssm_senet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import tensorflow as tf

from easy_rec.python.layers import dnn
from easy_rec.python.layers import senet
from easy_rec.python.model.dssm import DSSM
from easy_rec.python.model.match_model import MatchModel
from easy_rec.python.protos.dssm_senet_pb2 import DSSM_SENet as DSSM_SENet_Config
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.protos.simi_pb2 import Similarity
from easy_rec.python.utils.proto_util import copy_obj
from easy_rec.python.layers import senet
from easy_rec.python.model.dssm import DSSM

from easy_rec.python.protos.dssm_senet_pb2 import DSSM_SENet as DSSM_SENet_Config

if tf.__version__ >= '2.0':
tf = tf.compat.v1
Expand All @@ -25,7 +26,8 @@ def __init__(self,
labels=None,
is_training=False):

MatchModel.__init__(self, model_config, feature_configs, features, labels, is_training)
MatchModel.__init__(self, model_config, feature_configs, features, labels,
is_training)

assert self._model_config.WhichOneof('model') == 'dssm_senet', \
'invalid model config: %s' % self._model_config.WhichOneof('model')
Expand All @@ -35,26 +37,27 @@ def __init__(self,
# copy_obj so that any modification will not affect original config
self.user_tower = copy_obj(self._model_config.user_tower)

self.user_seq_features, self.user_plain_features, self.user_feature_list = self._input_layer(self._feature_dict, 'user', is_combine=False)
self.user_seq_features, self.user_plain_features, self.user_feature_list = self._input_layer(
self._feature_dict, 'user', is_combine=False)
self.user_num_fields = len(self.user_feature_list)

# copy_obj so that any modification will not affect original config
self.item_tower = copy_obj(self._model_config.item_tower)

self.item_seq_features, self.item_plain_features, self.item_feature_list = self._input_layer(self._feature_dict, 'item', is_combine=False)
self.item_seq_features, self.item_plain_features, self.item_feature_list = self._input_layer(
self._feature_dict, 'item', is_combine=False)
self.item_num_fields = len(self.item_feature_list)

self._user_tower_emb = None
self._item_tower_emb = None

def build_predict_graph(self):
user_senet = senet.SENet(
num_fields=self.user_num_fields,
num_squeeze_group=self.user_tower.senet.num_squeeze_group,
reduction_ratio=self.user_tower.senet.reduction_ratio,
l2_reg=self._l2_reg,
name='user_senet'
)
num_fields=self.user_num_fields,
num_squeeze_group=self.user_tower.senet.num_squeeze_group,
reduction_ratio=self.user_tower.senet.reduction_ratio,
l2_reg=self._l2_reg,
name='user_senet')
user_senet_output_list = user_senet(self.user_feature_list)
user_senet_output = tf.concat(user_senet_output_list, axis=-1)

Expand All @@ -70,15 +73,14 @@ def build_predict_graph(self):
name='user_dnn/dnn_%d' % (num_user_dnn_layer - 1))

item_senet = senet.SENet(
num_fields=self.item_num_fields,
num_squeeze_group=self.item_tower.senet.num_squeeze_group,
reduction_ratio=self.item_tower.senet.reduction_ratio,
l2_reg=self._l2_reg,
name='item_senet'
)

num_fields=self.item_num_fields,
num_squeeze_group=self.item_tower.senet.num_squeeze_group,
reduction_ratio=self.item_tower.senet.reduction_ratio,
l2_reg=self._l2_reg,
name='item_senet')

item_senet_output_list = item_senet(self.item_feature_list)
item_senet_output = tf.concat(item_senet_output_list, axis=-1)
item_senet_output = tf.concat(item_senet_output_list, axis=-1)

num_item_dnn_layer = len(self.item_tower.dnn.hidden_units)
last_item_hidden = self.item_tower.dnn.hidden_units.pop()
Expand Down Expand Up @@ -137,5 +139,5 @@ def build_predict_graph(self):

def build_output_dict(self):
output_dict = MatchModel.build_output_dict(self)
return output_dict

return output_dict
2 changes: 1 addition & 1 deletion easy_rec/python/protos/dssm_senet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ message DSSM_SENet_Tower {
required string id = 1;
required SENet senet = 2;
required DNN dnn = 3;

};


Expand Down

0 comments on commit 6d0bae6

Please sign in to comment.