diff --git a/easy_rec/python/inference/predictor.py b/easy_rec/python/inference/predictor.py index fa0853ea5..c57092ab5 100644 --- a/easy_rec/python/inference/predictor.py +++ b/easy_rec/python/inference/predictor.py @@ -534,7 +534,9 @@ def _parse_value(all_vals): ] for k in self._reserved_cols: if k in all_vals and all_vals[k].dtype == np.object: - all_vals[k] = [val.decode('utf-8', errors='ignore') for val in all_vals[k]] + all_vals[k] = [ + val.decode('utf-8', errors='ignore') for val in all_vals[k] + ] ts2 = time.time() reserve_vals = self._get_reserve_vals(self._reserved_cols, diff --git a/easy_rec/python/model/dropoutnet.py b/easy_rec/python/model/dropoutnet.py index b4c30e77c..683677531 100644 --- a/easy_rec/python/model/dropoutnet.py +++ b/easy_rec/python/model/dropoutnet.py @@ -21,6 +21,7 @@ def cosine_similarity(user_emb, item_emb): tf.multiply(user_emb, item_emb), axis=1, name='cosine') return user_item_sim + def bernoulli_dropout(x, rate, training=False): if rate == 0.0 or not training: return x @@ -90,9 +91,9 @@ def build_predict_graph(self): content_feature = user_content_dnn(self.user_content_feature) user_features.append(content_feature) if self.user_preference_feature is not None: - user_prefer_feature = bernoulli_dropout(self.user_preference_feature, - self._model_config.user_dropout_rate, - self._is_training) + user_prefer_feature = bernoulli_dropout( + self.user_preference_feature, self._model_config.user_dropout_rate, + self._is_training) user_prefer_dnn = dnn.DNN(self.user_preference_layers, self._l2_reg, 'user_preference', self._is_training) prefer_feature = user_prefer_dnn(user_prefer_feature) @@ -118,9 +119,9 @@ def build_predict_graph(self): content_feature = item_content_dnn(self.item_content_feature) item_features.append(content_feature) if self.item_preference_feature is not None: - item_prefer_feature = bernoulli_dropout(self.item_preference_feature, - self._model_config.item_dropout_rate, - self._is_training) + item_prefer_feature = bernoulli_dropout( + self.item_preference_feature, self._model_config.item_dropout_rate, + self._is_training) item_prefer_dnn = dnn.DNN(self.item_preference_layers, self._l2_reg, 'item_preference', self._is_training) prefer_feature = item_prefer_dnn(item_prefer_feature) diff --git a/easy_rec/python/model/match_model.py b/easy_rec/python/model/match_model.py index 24cd1bcbe..2347f1aef 100644 --- a/easy_rec/python/model/match_model.py +++ b/easy_rec/python/model/match_model.py @@ -217,8 +217,12 @@ def _build_list_wise_loss_graph(self): indices = tf.concat([indices[:, None], indices[:, None]], axis=1) hit_prob = tf.gather_nd( self._prediction_dict['probs'][:batch_size, :batch_size], indices) + + sample_weights = tf.cast(tf.squeeze(self._sample_weight), tf.float32) self._loss_dict['cross_entropy_loss'] = -tf.reduce_mean( - tf.log(hit_prob + 1e-12) * tf.squeeze(self._sample_weight)) + tf.log(hit_prob + 1e-12) * + sample_weights) / tf.reduce_mean(sample_weights) + logging.info('softmax cross entropy loss is used') user_features = self._prediction_dict['user_tower_emb'] @@ -226,7 +230,8 @@ def _build_list_wise_loss_graph(self): pos_simi = tf.reduce_sum(user_features * pos_item_emb, axis=1) # if pos_simi < 0, produce loss reg_pos_loss = tf.nn.relu(-pos_simi) - self._loss_dict['reg_pos_loss'] = tf.reduce_mean(reg_pos_loss) + self._loss_dict['reg_pos_loss'] = tf.reduce_mean( + reg_pos_loss * sample_weights) / tf.reduce_mean(sample_weights) # the AMM loss for DAT model if all([ @@ -235,10 +240,12 @@ def _build_list_wise_loss_graph(self): ]): self._loss_dict['amm_loss_u'] = tf.reduce_mean( tf.square(self._prediction_dict['augmented_a_u'] - - self._prediction_dict['augmented_p_i'][:batch_size])) + self._prediction_dict['augmented_p_i'][:batch_size]) * + sample_weights) / tf.reduce_mean(sample_weights) self._loss_dict['amm_loss_i'] = tf.reduce_mean( tf.square(self._prediction_dict['augmented_a_i'][:batch_size] - - self._prediction_dict['augmented_p_u'])) + self._prediction_dict['augmented_p_u']) * + sample_weights) / tf.reduce_mean(sample_weights) else: raise ValueError('invalid loss type: %s' % str(self._loss_type))