Skip to content

Commit

Permalink
fix sample weight calc for match model (#510)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-gecheng authored Dec 27, 2024
1 parent b07889a commit cd47816
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
4 changes: 3 additions & 1 deletion easy_rec/python/inference/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,9 @@ def _parse_value(all_vals):
]
for k in self._reserved_cols:
if k in all_vals and all_vals[k].dtype == np.object:
all_vals[k] = [val.decode('utf-8', errors='ignore') for val in all_vals[k]]
all_vals[k] = [
val.decode('utf-8', errors='ignore') for val in all_vals[k]
]

ts2 = time.time()
reserve_vals = self._get_reserve_vals(self._reserved_cols,
Expand Down
13 changes: 7 additions & 6 deletions easy_rec/python/model/dropoutnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def cosine_similarity(user_emb, item_emb):
tf.multiply(user_emb, item_emb), axis=1, name='cosine')
return user_item_sim


def bernoulli_dropout(x, rate, training=False):
if rate == 0.0 or not training:
return x
Expand Down Expand Up @@ -90,9 +91,9 @@ def build_predict_graph(self):
content_feature = user_content_dnn(self.user_content_feature)
user_features.append(content_feature)
if self.user_preference_feature is not None:
user_prefer_feature = bernoulli_dropout(self.user_preference_feature,
self._model_config.user_dropout_rate,
self._is_training)
user_prefer_feature = bernoulli_dropout(
self.user_preference_feature, self._model_config.user_dropout_rate,
self._is_training)
user_prefer_dnn = dnn.DNN(self.user_preference_layers, self._l2_reg,
'user_preference', self._is_training)
prefer_feature = user_prefer_dnn(user_prefer_feature)
Expand All @@ -118,9 +119,9 @@ def build_predict_graph(self):
content_feature = item_content_dnn(self.item_content_feature)
item_features.append(content_feature)
if self.item_preference_feature is not None:
item_prefer_feature = bernoulli_dropout(self.item_preference_feature,
self._model_config.item_dropout_rate,
self._is_training)
item_prefer_feature = bernoulli_dropout(
self.item_preference_feature, self._model_config.item_dropout_rate,
self._is_training)
item_prefer_dnn = dnn.DNN(self.item_preference_layers, self._l2_reg,
'item_preference', self._is_training)
prefer_feature = item_prefer_dnn(item_prefer_feature)
Expand Down
15 changes: 11 additions & 4 deletions easy_rec/python/model/match_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,16 +217,21 @@ def _build_list_wise_loss_graph(self):
indices = tf.concat([indices[:, None], indices[:, None]], axis=1)
hit_prob = tf.gather_nd(
self._prediction_dict['probs'][:batch_size, :batch_size], indices)

sample_weights = tf.cast(tf.squeeze(self._sample_weight), tf.float32)
self._loss_dict['cross_entropy_loss'] = -tf.reduce_mean(
tf.log(hit_prob + 1e-12) * tf.squeeze(self._sample_weight))
tf.log(hit_prob + 1e-12) *
sample_weights) / tf.reduce_mean(sample_weights)

logging.info('softmax cross entropy loss is used')

user_features = self._prediction_dict['user_tower_emb']
pos_item_emb = self._prediction_dict['item_tower_emb'][:batch_size]
pos_simi = tf.reduce_sum(user_features * pos_item_emb, axis=1)
# if pos_simi < 0, produce loss
reg_pos_loss = tf.nn.relu(-pos_simi)
self._loss_dict['reg_pos_loss'] = tf.reduce_mean(reg_pos_loss)
self._loss_dict['reg_pos_loss'] = tf.reduce_mean(
reg_pos_loss * sample_weights) / tf.reduce_mean(sample_weights)

# the AMM loss for DAT model
if all([
Expand All @@ -235,10 +240,12 @@ def _build_list_wise_loss_graph(self):
]):
self._loss_dict['amm_loss_u'] = tf.reduce_mean(
tf.square(self._prediction_dict['augmented_a_u'] -
self._prediction_dict['augmented_p_i'][:batch_size]))
self._prediction_dict['augmented_p_i'][:batch_size]) *
sample_weights) / tf.reduce_mean(sample_weights)
self._loss_dict['amm_loss_i'] = tf.reduce_mean(
tf.square(self._prediction_dict['augmented_a_i'][:batch_size] -
self._prediction_dict['augmented_p_u']))
self._prediction_dict['augmented_p_u']) *
sample_weights) / tf.reduce_mean(sample_weights)

else:
raise ValueError('invalid loss type: %s' % str(self._loss_type))
Expand Down

0 comments on commit cd47816

Please sign in to comment.