From 6d0bae6a15b2fa4e707faea8732effb4a84f0132 Mon Sep 17 00:00:00 2001 From: "eric.gc" Date: Fri, 11 Oct 2024 16:16:53 +0800 Subject: [PATCH] code style fix --- docs/source/models/dssm_derivatives.md | 10 +++--- easy_rec/python/layers/senet.py | 36 +++++++++---------- easy_rec/python/model/dssm_senet.py | 46 +++++++++++++------------ easy_rec/python/protos/dssm_senet.proto | 2 +- 4 files changed, 49 insertions(+), 45 deletions(-) diff --git a/docs/source/models/dssm_derivatives.md b/docs/source/models/dssm_derivatives.md index 747bf696f..d74aa9057 100644 --- a/docs/source/models/dssm_derivatives.md +++ b/docs/source/models/dssm_derivatives.md @@ -1,12 +1,12 @@ -# DSSM衍生扩展模型 +# DSSM衍生扩展模型 ## DSSM + SENet + ### 简介 在推荐场景中,往往存在多种用户特征和物品特征,特征类型各不相同,各种特征经过embedding层后进入双塔模型的DNN层进行训练,在部分场景中甚至还会引入多模态embedding特征, 如图像和文本的embedding。 然而各个特征对目标的影响不尽相同,有的特征重要性高,对模型整体表现影响大,有的特征则影响较小。因此当特征不断增多时,可以结合SENet自动学习每个特征的权重,增强重要信息到塔顶的能力。 - ![dssm+senet](../../images/models/dssm+senet.png) ### 配置说明 @@ -70,12 +70,14 @@ model_config:{ } ``` -- senet参数配置: +- senet参数配置: - num_squeeze_group: 每个特征embedding的分组个数, 默认为2 - reduction_ratio: 维度压缩比例, 默认为4 ### 示例Config + [dssm_senet_on_taobao.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dssm_senet_on_taobao.config) ### 参考论文 -[Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507) \ No newline at end of file + +[Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507) diff --git a/easy_rec/python/layers/senet.py b/easy_rec/python/layers/senet.py index 5715d189f..777079341 100644 --- a/easy_rec/python/layers/senet.py +++ b/easy_rec/python/layers/senet.py @@ -7,8 +7,7 @@ class SENet: - ''' - Squeeze and Excite Network + """Squeeze and Excite Network. Input shape - A list of 2D tensor with shape: ``(batch_size,embedding_size)``. @@ -20,15 +19,20 @@ class SENet: reduction_ratio: int, reduction ratio for squeeze. l2_reg: float, l2 regularizer for embedding. name: str, name of the layer. + """ - ''' - def __init__(self, num_fields, num_squeeze_group, reduction_ratio, l2_reg, name='SENet'): + def __init__(self, + num_fields, + num_squeeze_group, + reduction_ratio, + l2_reg, + name='SENet'): self.num_fields = num_fields self.num_squeeze_group = num_squeeze_group self.reduction_ratio = reduction_ratio self._l2_reg = l2_reg self._name = name - + def __call__(self, inputs): g = self.num_squeeze_group f = self.num_fields @@ -39,7 +43,6 @@ def __call__(self, inputs): for input in inputs: emb_size += int(input.shape[-1]) - group_embs = [ tf.reshape(emb, [-1, g, int(emb.shape[-1]) // g]) for emb in inputs ] @@ -50,24 +53,21 @@ def __call__(self, inputs): squeezed.append(tf.reduce_mean(emb, axis=-1)) # [B, g] z = tf.concat(squeezed, axis=1) # [bs, field_size * num_groups * 2] - - reduced = tf.layers.dense( - inputs=z, - units=reduction_size, - kernel_regularizer=self._l2_reg, - activation='relu', - name='%s/reduce' % self._name) - + inputs=z, + units=reduction_size, + kernel_regularizer=self._l2_reg, + activation='relu', + name='%s/reduce' % self._name) + excited_weights = tf.layers.dense( inputs=reduced, - units=emb_size, - kernel_initializer='glorot_normal', + units=emb_size, + kernel_initializer='glorot_normal', name='%s/excite' % self._name) - # Re-weight inputs = tf.concat(inputs, axis=-1) output = inputs * excited_weights - return output \ No newline at end of file + return output diff --git a/easy_rec/python/model/dssm_senet.py b/easy_rec/python/model/dssm_senet.py index f1c5446bb..406d3cbdf 100644 --- a/easy_rec/python/model/dssm_senet.py +++ b/easy_rec/python/model/dssm_senet.py @@ -3,13 +3,14 @@ import tensorflow as tf from easy_rec.python.layers import dnn +from easy_rec.python.layers import senet +from easy_rec.python.model.dssm import DSSM from easy_rec.python.model.match_model import MatchModel -from easy_rec.python.protos.dssm_senet_pb2 import DSSM_SENet as DSSM_SENet_Config from easy_rec.python.protos.loss_pb2 import LossType from easy_rec.python.protos.simi_pb2 import Similarity from easy_rec.python.utils.proto_util import copy_obj -from easy_rec.python.layers import senet -from easy_rec.python.model.dssm import DSSM + +from easy_rec.python.protos.dssm_senet_pb2 import DSSM_SENet as DSSM_SENet_Config if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -25,7 +26,8 @@ def __init__(self, labels=None, is_training=False): - MatchModel.__init__(self, model_config, feature_configs, features, labels, is_training) + MatchModel.__init__(self, model_config, feature_configs, features, labels, + is_training) assert self._model_config.WhichOneof('model') == 'dssm_senet', \ 'invalid model config: %s' % self._model_config.WhichOneof('model') @@ -35,13 +37,15 @@ def __init__(self, # copy_obj so that any modification will not affect original config self.user_tower = copy_obj(self._model_config.user_tower) - self.user_seq_features, self.user_plain_features, self.user_feature_list = self._input_layer(self._feature_dict, 'user', is_combine=False) + self.user_seq_features, self.user_plain_features, self.user_feature_list = self._input_layer( + self._feature_dict, 'user', is_combine=False) self.user_num_fields = len(self.user_feature_list) # copy_obj so that any modification will not affect original config self.item_tower = copy_obj(self._model_config.item_tower) - self.item_seq_features, self.item_plain_features, self.item_feature_list = self._input_layer(self._feature_dict, 'item', is_combine=False) + self.item_seq_features, self.item_plain_features, self.item_feature_list = self._input_layer( + self._feature_dict, 'item', is_combine=False) self.item_num_fields = len(self.item_feature_list) self._user_tower_emb = None @@ -49,12 +53,11 @@ def __init__(self, def build_predict_graph(self): user_senet = senet.SENet( - num_fields=self.user_num_fields, - num_squeeze_group=self.user_tower.senet.num_squeeze_group, - reduction_ratio=self.user_tower.senet.reduction_ratio, - l2_reg=self._l2_reg, - name='user_senet' - ) + num_fields=self.user_num_fields, + num_squeeze_group=self.user_tower.senet.num_squeeze_group, + reduction_ratio=self.user_tower.senet.reduction_ratio, + l2_reg=self._l2_reg, + name='user_senet') user_senet_output_list = user_senet(self.user_feature_list) user_senet_output = tf.concat(user_senet_output_list, axis=-1) @@ -70,15 +73,14 @@ def build_predict_graph(self): name='user_dnn/dnn_%d' % (num_user_dnn_layer - 1)) item_senet = senet.SENet( - num_fields=self.item_num_fields, - num_squeeze_group=self.item_tower.senet.num_squeeze_group, - reduction_ratio=self.item_tower.senet.reduction_ratio, - l2_reg=self._l2_reg, - name='item_senet' - ) - + num_fields=self.item_num_fields, + num_squeeze_group=self.item_tower.senet.num_squeeze_group, + reduction_ratio=self.item_tower.senet.reduction_ratio, + l2_reg=self._l2_reg, + name='item_senet') + item_senet_output_list = item_senet(self.item_feature_list) - item_senet_output = tf.concat(item_senet_output_list, axis=-1) + item_senet_output = tf.concat(item_senet_output_list, axis=-1) num_item_dnn_layer = len(self.item_tower.dnn.hidden_units) last_item_hidden = self.item_tower.dnn.hidden_units.pop() @@ -137,5 +139,5 @@ def build_predict_graph(self): def build_output_dict(self): output_dict = MatchModel.build_output_dict(self) - - return output_dict \ No newline at end of file + + return output_dict diff --git a/easy_rec/python/protos/dssm_senet.proto b/easy_rec/python/protos/dssm_senet.proto index fd49b9f76..ee941104f 100644 --- a/easy_rec/python/protos/dssm_senet.proto +++ b/easy_rec/python/protos/dssm_senet.proto @@ -9,7 +9,7 @@ message DSSM_SENet_Tower { required string id = 1; required SENet senet = 2; required DNN dnn = 3; - + };