From 27a862215dfd4795c52de912631c2a656c19bad2 Mon Sep 17 00:00:00 2001 From: yangxudong Date: Tue, 10 Dec 2024 10:58:44 +0800 Subject: [PATCH] modify dropoutnet in case of batch size mismatch (#505) * modify dropoutnet in case of batch size mismatch --- easy_rec/python/model/dropoutnet.py | 35 ++++++++++++----------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/easy_rec/python/model/dropoutnet.py b/easy_rec/python/model/dropoutnet.py index 8d84ad341..b4c30e77c 100644 --- a/easy_rec/python/model/dropoutnet.py +++ b/easy_rec/python/model/dropoutnet.py @@ -7,7 +7,6 @@ from easy_rec.python.model.easy_rec_model import EasyRecModel from easy_rec.python.protos.loss_pb2 import LossType from easy_rec.python.utils.proto_util import copy_obj -from easy_rec.python.utils.shape_utils import get_shape_list from easy_rec.python.protos.dropoutnet_pb2 import DropoutNet as DropoutNetConfig # NOQA from easy_rec.python.loss.softmax_loss_with_negative_mining import softmax_loss_with_negative_mining # NOQA @@ -22,6 +21,14 @@ def cosine_similarity(user_emb, item_emb): tf.multiply(user_emb, item_emb), axis=1, name='cosine') return user_item_sim +def bernoulli_dropout(x, rate, training=False): + if rate == 0.0 or not training: + return x + keep_rate = 1.0 - rate + dist = tf.distributions.Bernoulli(probs=keep_rate, dtype=x.dtype) + mask = dist.sample(sample_shape=tf.stack([tf.shape(x)[0], 1])) + return x * mask / keep_rate + class DropoutNet(EasyRecModel): @@ -68,8 +75,6 @@ def __init__(self, assert self.item_content_feature is not None or self.item_preference_feature is not None, 'no item feature' def build_predict_graph(self): - batch_size = get_shape_list(self.item_content_feature)[0] - num_user_dnn_layer = len(self.user_tower_layers.hidden_units) last_user_hidden = self.user_tower_layers.hidden_units.pop() num_item_dnn_layer = len(self.item_tower_layers.hidden_units) @@ -85,15 +90,9 @@ def build_predict_graph(self): content_feature = user_content_dnn(self.user_content_feature) user_features.append(content_feature) if self.user_preference_feature is not None: - if self._is_training: - prob = tf.random.uniform([batch_size]) - user_prefer_feature = tf.where( - tf.less(prob, self._model_config.user_dropout_rate), - tf.zeros_like(self.user_preference_feature), - self.user_preference_feature) - else: - user_prefer_feature = self.user_preference_feature - + user_prefer_feature = bernoulli_dropout(self.user_preference_feature, + self._model_config.user_dropout_rate, + self._is_training) user_prefer_dnn = dnn.DNN(self.user_preference_layers, self._l2_reg, 'user_preference', self._is_training) prefer_feature = user_prefer_dnn(user_prefer_feature) @@ -119,15 +118,9 @@ def build_predict_graph(self): content_feature = item_content_dnn(self.item_content_feature) item_features.append(content_feature) if self.item_preference_feature is not None: - if self._is_training: - prob = tf.random.uniform([batch_size]) - item_prefer_feature = tf.where( - tf.less(prob, self._model_config.item_dropout_rate), - tf.zeros_like(self.item_preference_feature), - self.item_preference_feature) - else: - item_prefer_feature = self.item_preference_feature - + item_prefer_feature = bernoulli_dropout(self.item_preference_feature, + self._model_config.item_dropout_rate, + self._is_training) item_prefer_dnn = dnn.DNN(self.item_preference_layers, self._l2_reg, 'item_preference', self._is_training) prefer_feature = item_prefer_dnn(item_prefer_feature)