diff --git a/examples/dailymail/analyze.py b/examples/dailymail/analyze.py index bd214ac..60ba64d 100644 --- a/examples/dailymail/analyze.py +++ b/examples/dailymail/analyze.py @@ -121,7 +121,7 @@ def compute_train_loss( ).flatten() masks = batch["labels"].view(-1) == -100 sampled_labels[masks] = -100 - return F.cross_entropy(logits, sampled_labels, reduction="sum") + return F.cross_entropy(logits.view(-1, logits.size(-1)), sampled_labels, reduction="sum") def compute_measurement( self,