Skip to content

Commit

Permalink
add ZILN loss for ltv prediction task
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Nov 11, 2024
1 parent 5259480 commit 38dfe80
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions easy_rec/python/loss/zero_inflated_lognormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def zero_inflated_lognormal_loss(labels, logits, name=''):
"""
loss_name = name if name else 'ziln_loss'
labels = tf.cast(labels, dtype=tf.float32)
if labels.shape.ndims == 1:
labels = tf.expand_dims(labels, 1) # [B, 1]
positive = tf.cast(labels > 0, tf.float32)

logits = tf.convert_to_tensor(logits, dtype=tf.float32)
Expand Down

0 comments on commit 38dfe80

Please sign in to comment.