diff --git a/easy_rec/python/loss/jrc_loss.py b/easy_rec/python/loss/jrc_loss.py index 30c019a77..9ffe5b518 100644 --- a/easy_rec/python/loss/jrc_loss.py +++ b/easy_rec/python/loss/jrc_loss.py @@ -1,7 +1,7 @@ # -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import logging - +import numpy as np import tensorflow as tf if tf.__version__ >= '2.0': @@ -66,8 +66,6 @@ def jrc_loss(labels, pairwise_weights = tf.tile(weights, tf.stack([batch_size, 1])) y_pos *= pairwise_weights y_neg *= pairwise_weights - else: - assert sample_weights == 1.0, 'invalid sample_weight %d' % sample_weights # Compute list-wise generative loss -log p(x|y, z) if same_label_loss: @@ -124,4 +122,6 @@ def jrc_loss(labels, else: raise ValueError('Unsupported loss weight strategy `%s` for jrc loss' % loss_weight_strategy) + if np.isscalar(sample_weights): + return loss * sample_weights return loss