Speed up mask sampling with rejection sampling #2585

Merged

Conversation

anc2001 (Contributor) commented Nov 6, 2023

This PR speeds up mask sampling and avoids OOM errors for large images and masks by replacing torch.nonzero with rejection sampling (see the sketch after the screenshots below). I found that this solves many of the issues related to inefficient sampling with masks.

Example throughput comparison (screenshots):
[screenshot: rays / sec without masks]
[screenshot: rays / sec with the current mask sampling]
[screenshot: rays / sec with the new mask sampling]
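
For readers following along, here is a minimal sketch of the rejection-sampling idea described in this PR. It is not the exact code that was merged: the function name and retry budget are illustrative, and `mask` is assumed to be a `(num_images, H, W, 1)` boolean tensor.

```python
import torch


def sample_masked_indices_rejection(
    mask: torch.Tensor, batch_size: int, max_iterations: int = 100
) -> torch.Tensor:
    """Draw batch_size (image, row, col) indices that land inside the mask."""
    num_images, height, width = mask.shape[:3]
    device = mask.device
    bounds = torch.tensor([num_images, height, width], device=device)

    indices = torch.zeros((batch_size, 3), dtype=torch.long, device=device)
    valid = torch.zeros(batch_size, dtype=torch.bool, device=device)

    for _ in range(max_iterations):
        if bool(valid.all()):
            break
        # Slots that still need a valid sample. This torch.nonzero call is cheap:
        # it runs over a batch_size-long vector, not over the full mask stack.
        slots = torch.nonzero(~valid).squeeze(-1)
        # Draw fresh uniform candidates for those slots only.
        candidates = torch.floor(
            torch.rand((slots.shape[0], 3), device=device) * bounds
        ).long()
        # Keep only candidates whose pixel lies inside the mask.
        hits = mask[candidates[:, 0], candidates[:, 1], candidates[:, 2], 0].bool()
        indices[slots[hits]] = candidates[hits]
        valid[slots[hits]] = True

    if not bool(valid.all()):
        # Very sparse masks can exhaust the retry budget; a caller would fall back
        # to the original torch.nonzero-based sampling here.
        raise RuntimeError("Rejection sampling did not fill the batch.")
    return indices
```

The key point is that the only torch.nonzero call operates on a batch-sized boolean vector rather than on the full stack of image masks, which is what causes the slowdown and OOM errors for large images.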

KevinXu02 (Contributor) commented Nov 7, 2023

The speedup is really significant, but with sparse data such as lidar it will mostly fail. I think it should be added as an optional method instead of replacing the original code.

blacksino (Contributor) commented

Alternatively, you could cache the nonzero indices as a member variable, for example:
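
For illustration, a tiny sketch of that caching idea (the class and attribute names are hypothetical, not nerfstudio API): torch.nonzero runs once per mask and its result is reused for every batch.

```python
import torch


class CachedMaskSampler:
    """Sketch: pay the torch.nonzero cost once, then sample from the cached indices."""

    def __init__(self, mask: torch.Tensor):
        # Expensive for large masks, but computed a single time instead of every batch.
        self._nonzero_indices = torch.nonzero(mask[..., 0])  # (num_valid, 3)

    def sample(self, batch_size: int) -> torch.Tensor:
        choice = torch.randint(
            0,
            self._nonzero_indices.shape[0],
            (batch_size,),
            device=self._nonzero_indices.device,
        )
        return self._nonzero_indices[choice]
```

This trades memory for speed: the cached index tensor can itself be large for dense masks, which is part of what the rejection-sampling approach avoids.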

akristoffersen (Contributor) commented

I recently relied on this PR to dramatically speed up training with masks. @anc2001, would you be interested in getting this PR over the finish line? Perhaps with an optional flag as @KevinXu02 suggests, though I think it may make more sense for this to be the default behavior and to revert to the slower behavior with an optional flag.

anc2001 (Contributor, Author) commented Jan 8, 2024

Yes @akristoffersen, I'd love to push this PR over the finish line! I've just been a little busy recently and have had some issues setting up the dev environment.

machenmusik (Contributor) commented Jan 8, 2024

Do we have any usage data on how often this is used with sparse data such as lidar? My impression is that sparse data usage is less common, and so enabling this option by default would be better for the majority.

anc2001 (Contributor, Author) commented Jan 10, 2024

A better solution than an optional flag might be some kind of adaptive thresholding (sketched below): if the fraction of valid pixels in the mask is below a certain threshold, revert to the original behavior; if it is above, use rejection sampling. This might be better for datasets with a mix of sparse and dense masks, but I'm not sure that's necessary, so I'll leave it to future PRs if people want that kind of functionality.
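
A hypothetical sketch of that adaptive-threshold idea; the threshold value and function names are illustrative and not part of this PR, and `sample_masked_indices_rejection` refers to the helper sketched earlier in this thread.

```python
import torch


def sample_masked_indices_adaptive(
    mask: torch.Tensor, batch_size: int, valid_fraction_threshold: float = 0.01
) -> torch.Tensor:
    # Fraction of pixels that are valid across the whole mask stack.
    valid_fraction = float(mask.float().mean())
    if valid_fraction < valid_fraction_threshold:
        # Sparse mask (e.g. lidar): enumerate the valid pixels once and sample from them.
        nonzero = torch.nonzero(mask[..., 0])  # (num_valid, 3) of (image, row, col)
        choice = torch.randint(0, nonzero.shape[0], (batch_size,), device=mask.device)
        return nonzero[choice]
    # Dense mask: rejection sampling converges quickly and avoids running
    # torch.nonzero over the full mask entirely.
    return sample_masked_indices_rejection(mask, batch_size)
```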

akristoffersen self-requested a review on January 11, 2024.
akristoffersen (Contributor) left a review:

LGTM, small comment but otherwise looks good to merge :D

Inline comment on the diff, at the excerpt:

    ).long()
    indices[~chosen_indices_validity] = replacement_indices

    if num_valid != batch_size:

Instead of just raising a warning, I think it would make sense to default back to the slow non-rejection sampling if this occurs. I would still issue a warning, but it would suck for training to fail on the off chance that not enough valid indices are generated in time.

akristoffersen (Contributor) left a follow-up review:

Small nit: with the most recent change, if the first iteration fails with rejection sampling, a warning will be generated, but the returned indices will not necessarily contain only valid locations within the mask. Every other sampling call will be fine, but the first one will still have invalid indices.

I think instead we should throw away the currently generated indices and start from scratch with the non-rejection sampling method, as well as change the config flag to use the non-rejection sampling method for all future iterations (see the sketch below).
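
A minimal sketch of that suggested fallback, assuming a boolean flag on the sampler config; the flag and helper names here are illustrative rather than the exact nerfstudio API, and `sample_masked_indices_rejection` is the helper sketched earlier in this thread.

```python
import warnings

import torch


def sample_with_fallback(config, mask: torch.Tensor, batch_size: int) -> torch.Tensor:
    if config.rejection_sample_mask:  # assumed config flag enabling rejection sampling
        try:
            return sample_masked_indices_rejection(mask, batch_size)
        except RuntimeError:
            warnings.warn(
                "Rejection sampling could not fill the batch; "
                "falling back to torch.nonzero-based sampling from now on."
            )
            # Flip the flag so all future iterations skip rejection sampling.
            config.rejection_sample_mask = False
    # Non-rejection path: discard any partial result, enumerate valid pixels,
    # and sample uniformly from them.
    nonzero = torch.nonzero(mask[..., 0])
    choice = torch.randint(0, nonzero.shape[0], (batch_size,), device=mask.device)
    return nonzero[choice]
```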

akristoffersen (Contributor) commented

LGTM, thanks again @anc2001! feel free to merge :D

anc2001 (Contributor, Author) commented Jan 18, 2024

Hey @akristoffersen, I don't have write access to merge this PR! Do you know who we would need to @ to get this merged?

akristoffersen enabled auto-merge (squash) on January 18, 2024.
akristoffersen (Contributor) commented

That would be me! It should merge once all checks are done.

akristoffersen merged commit a78ca29 into nerfstudio-project:main on Jan 18, 2024 (4 checks passed).
ArpegorPSGH pushed a commit to ArpegorPSGH/nerfstudio that referenced this pull request on Jun 22, 2024:
* change masked pixel sampling to use rejection sampling instead of torch.nonzero

* black reformat code

* pyright unbound variable num_valid

* pyright type issues with num_valid

* add configuration settings for rejection sampling masks

* black reformat

* maybe this fixes it?

* revert behavior if mask sampling failed, still raise warning

* on iteration failure, use non-rejection sampling to generate indices

* ruff

---------

Co-authored-by: adrian_chang <[email protected]>
Co-authored-by: Alexander Kristoffersen <[email protected]>
nepfaff (Contributor) commented Sep 26, 2024

@anc2001, did you get these timing results with --pipeline.datamanager.masks-on-gpu True --pipeline.datamanager.images-on-gpu True, or without these options? Not including them still seems to be slow for me, but maybe I have a bug somewhere.

I only seem to have these problems when using depth with depth-nerfacto
