CEWithChunkedOutputLoss does not check division by zero #2341

Open
pocca2048 opened this issue Feb 4, 2025 · 6 comments

Comments

@pocca2048

Similar to #2225, CEWithChunkedOutputLoss does not check for division by zero either.
This makes the loss NaN.

total_elements = (labels != self.ignore_index).sum()
...
return total_loss / total_elements
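
For reference, a minimal sketch (plain tensors, not the torchtune module itself) of how that denominator reaches zero when every label equals ignore_index:

import torch

ignore_index = -100

# Every label masked out, e.g. the whole sequence is prompt tokens.
labels = torch.full((4,), ignore_index, dtype=torch.long)

total_loss = torch.tensor(0.0)
total_elements = (labels != ignore_index).sum()  # tensor(0)

# 0.0 / 0 -> NaN, which then propagates through the training step.
print(total_loss / total_elements)  # tensor(nan)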
@joecummings
Contributor

This is an interesting point. If you pass a bunch of masked out labels to PyTorch's regular Cross Entropy calculation, you actually get nan, not 0.0.

import torch

ignore_index = -100
batch_size = 2
num_tokens = 10
vocab_size = 10
logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16)
# Labels are all set to ignore_index
labels = torch.full((batch_size, num_tokens), ignore_index, dtype=torch.long)
logits = logits.reshape(-1, logits.size(-1))
labels = labels.reshape(-1)
standard_loss = torch.nn.functional.cross_entropy(
    logits.float(), labels, reduction="mean", ignore_index=ignore_index
)
print(standard_loss)  # tensor(nan) -- every target is ignored, so the mean is 0/0

I would lean towards keeping our calculation the same as PyTorch core's regular cross-entropy implementation, but would like to hear from @felipemello1.

@joecummings joecummings added the discussion Start a discussion label Feb 4, 2025
@joecummings joecummings self-assigned this Feb 4, 2025
@joecummings joecummings added the triaged This issue has been assigned an owner and appropriate label label Feb 4, 2025
@joecummings
Contributor

After discussing offline with @felipemello1, we agreed to stick with the current implementation, which matches what you would expect when using torch.nn.CrossEntropyLoss.

@felipemello1
Contributor

@pocca2048, out of curiosity, why would your dataset have no labels? If we change the loss, wouldn't it be a silent error?

@felipemello1
Contributor

@joecummings, should we change the KL loss back so it's consistent?

@pocca2048
Author

@felipemello1
It happens when we use train_on_input=False and the message is so long that the output gets truncated.
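
A minimal sketch of that failure mode with hypothetical lengths (plain tensors, not the torchtune tokenizer/dataset code): with train_on_input=False the prompt tokens are labeled ignore_index, so if truncation cuts the sequence before any response tokens survive, every remaining label is ignored.

import torch

ignore_index = -100
max_seq_len = 8  # hypothetical truncation length

# train_on_input=False: prompt tokens are masked out of the loss.
prompt_len, response_len = 10, 5
labels = [ignore_index] * prompt_len + list(range(response_len))

# Truncation keeps only masked prompt tokens -> no valid targets remain.
labels = torch.tensor(labels[:max_seq_len])
print((labels != ignore_index).sum())  # tensor(0) -> zero denominator -> NaN loss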

@felipemello1
Contributor

What do you think about #2344?
