From 38dfe80bb246820485266ecac069efa0a1cf1cd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=AB=E8=8B=8F?= Date: Mon, 11 Nov 2024 19:57:34 +0800 Subject: [PATCH] add ZILN loss for ltv prediction task --- easy_rec/python/loss/zero_inflated_lognormal.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/easy_rec/python/loss/zero_inflated_lognormal.py b/easy_rec/python/loss/zero_inflated_lognormal.py index 8124edc0c..da1e03d25 100644 --- a/easy_rec/python/loss/zero_inflated_lognormal.py +++ b/easy_rec/python/loss/zero_inflated_lognormal.py @@ -49,6 +49,8 @@ def zero_inflated_lognormal_loss(labels, logits, name=''): """ loss_name = name if name else 'ziln_loss' labels = tf.cast(labels, dtype=tf.float32) + if labels.shape.ndims == 1: + labels = tf.expand_dims(labels, 1) # [B, 1] positive = tf.cast(labels > 0, tf.float32) logits = tf.convert_to_tensor(logits, dtype=tf.float32)