Skip to content

Commit

Permalink
modify dropoutnet in case of batch size mismatch (#505)
Browse files Browse the repository at this point in the history
* modify dropoutnet in case of batch size mismatch
  • Loading branch information
yangxudong authored Dec 10, 2024
1 parent 4468723 commit 27a8622
Showing 1 changed file with 14 additions and 21 deletions.
35 changes: 14 additions & 21 deletions easy_rec/python/model/dropoutnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from easy_rec.python.model.easy_rec_model import EasyRecModel
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.utils.proto_util import copy_obj
from easy_rec.python.utils.shape_utils import get_shape_list

from easy_rec.python.protos.dropoutnet_pb2 import DropoutNet as DropoutNetConfig # NOQA
from easy_rec.python.loss.softmax_loss_with_negative_mining import softmax_loss_with_negative_mining # NOQA
Expand All @@ -22,6 +21,14 @@ def cosine_similarity(user_emb, item_emb):
tf.multiply(user_emb, item_emb), axis=1, name='cosine')
return user_item_sim

def bernoulli_dropout(x, rate, training=False):
if rate == 0.0 or not training:
return x
keep_rate = 1.0 - rate
dist = tf.distributions.Bernoulli(probs=keep_rate, dtype=x.dtype)
mask = dist.sample(sample_shape=tf.stack([tf.shape(x)[0], 1]))
return x * mask / keep_rate


class DropoutNet(EasyRecModel):

Expand Down Expand Up @@ -68,8 +75,6 @@ def __init__(self,
assert self.item_content_feature is not None or self.item_preference_feature is not None, 'no item feature'

def build_predict_graph(self):
batch_size = get_shape_list(self.item_content_feature)[0]

num_user_dnn_layer = len(self.user_tower_layers.hidden_units)
last_user_hidden = self.user_tower_layers.hidden_units.pop()
num_item_dnn_layer = len(self.item_tower_layers.hidden_units)
Expand All @@ -85,15 +90,9 @@ def build_predict_graph(self):
content_feature = user_content_dnn(self.user_content_feature)
user_features.append(content_feature)
if self.user_preference_feature is not None:
if self._is_training:
prob = tf.random.uniform([batch_size])
user_prefer_feature = tf.where(
tf.less(prob, self._model_config.user_dropout_rate),
tf.zeros_like(self.user_preference_feature),
self.user_preference_feature)
else:
user_prefer_feature = self.user_preference_feature

user_prefer_feature = bernoulli_dropout(self.user_preference_feature,
self._model_config.user_dropout_rate,
self._is_training)
user_prefer_dnn = dnn.DNN(self.user_preference_layers, self._l2_reg,
'user_preference', self._is_training)
prefer_feature = user_prefer_dnn(user_prefer_feature)
Expand All @@ -119,15 +118,9 @@ def build_predict_graph(self):
content_feature = item_content_dnn(self.item_content_feature)
item_features.append(content_feature)
if self.item_preference_feature is not None:
if self._is_training:
prob = tf.random.uniform([batch_size])
item_prefer_feature = tf.where(
tf.less(prob, self._model_config.item_dropout_rate),
tf.zeros_like(self.item_preference_feature),
self.item_preference_feature)
else:
item_prefer_feature = self.item_preference_feature

item_prefer_feature = bernoulli_dropout(self.item_preference_feature,
self._model_config.item_dropout_rate,
self._is_training)
item_prefer_dnn = dnn.DNN(self.item_preference_layers, self._l2_reg,
'item_preference', self._is_training)
prefer_feature = item_prefer_dnn(item_prefer_feature)
Expand Down

0 comments on commit 27a8622

Please sign in to comment.