
Fix for NaNs when training DASM with ambiguous sequences #93

Merged: 7 commits from wd-nan-ambig-fix into main, Dec 10, 2024

Conversation

willdumm (Contributor):

The fix is one line: applying a mask that should have been applied from the very beginning, but that previously didn't matter much because the data contained no ambiguities.
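
For intuition, here is a minimal PyTorch sketch of the failure mode (hypothetical shapes and variable names, not netam's actual code): ambiguous or padded sites carry meaningless logits, and if the substitution mask selects them, the cross-entropy loss goes to NaN.

    import torch
    import torch.nn.functional as F

    # Hypothetical setup: 5 sites, 20 amino-acid classes. Site 3 is
    # ambiguous/padded, so its logits are meaningless (-inf here).
    logits = torch.randn(5, 20)
    logits[3] = float("-inf")
    targets = torch.tensor([2, 7, 0, 4, 9])
    aa_subs_indicator = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])
    mask = torch.tensor([True, True, True, False, True])  # False = ambiguous

    # Before the fix: ambiguous site 3 is selected and poisons the loss.
    bad = aa_subs_indicator == 1
    print(F.cross_entropy(logits[bad], targets[bad]))              # tensor(nan)

    # After the fix: '& mask' drops ambiguous sites before the loss.
    subs_mask = (aa_subs_indicator == 1) & mask
    print(F.cross_entropy(logits[subs_mask], targets[subs_mask]))  # finite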

This PR also adds parallelized neutral model application, moved to the CPU. This makes some notebooks run much faster and speeds up setup for training Snakemake runs.
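
As a rough sketch of the shape of that change (hypothetical function names, not the actual netam implementation): per-sequence neutral-model application is embarrassingly parallel, and keeping it on the CPU lets it fan out across worker processes without CUDA-context complications.

    from concurrent.futures import ProcessPoolExecutor

    import torch

    def apply_neutral_model(seq):
        # Hypothetical stand-in for the per-sequence neutral-model
        # computation (e.g., per-site mutation rates). Runs on CPU with
        # no gradient tracking, since this is a fixed preprocessing step.
        with torch.no_grad():
            return torch.zeros(len(seq))

    def apply_neutral_model_parallel(sequences, max_workers=None):
        # Each sequence is independent, so an order-preserving process-pool
        # map parallelizes cleanly; CPU tensors pickle across process
        # boundaries, which GPU-resident state does not do well.
        with ProcessPoolExecutor(max_workers=max_workers) as pool:
            return list(pool.map(apply_neutral_model, sequences))

    if __name__ == "__main__":
        rates = apply_neutral_model_parallel(["ACDEFG", "HIKLMNPQ"])
        print([r.shape for r in rates])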

@@ -195,11 +196,10 @@ def loss_of_batch(self, batch):
         # logit space, so we are set up for using the cross entropy loss.
         # However we have to mask out the sites that are not substituted, i.e.
         # the sites for which aa_subs_indicator is 0.
-        subs_mask = aa_subs_indicator == 1
+        subs_mask = (aa_subs_indicator == 1) & mask
willdumm (Contributor, Author):

This is it, the whole NaN/inf fix! Could you have a look around here and make sure this looks reasonable to you, too?

matsen (Contributor):

Nice sleuthing! It seems good to me... am I missing something?

willdumm (Contributor, Author):

I don't think so, just wanted extra eyes on it! Thanks!

willdumm requested a review from matsen on December 10, 2024, 19:06
matsen (Contributor) left a comment:

Awesome!


willdumm merged commit 22c8873 into main on Dec 10, 2024 (2 checks passed).
willdumm deleted the wd-nan-ambig-fix branch on December 10, 2024, 19:46.