Speed up mask sampling with rejection sampling #2585
The speedup is really significant, but on sparse data such as lidar it will mostly fail. I think it could be added as an optional method instead of changing the original code.
You could also save the nonzero indices as a member.
I recently relied on this PR to dramatically speed up training with masks. @anc2001 would you be interested in getting this PR over the finish line? Perhaps with an optional flag as @KevinXu02 suggests, though I think it may make more sense for this to be the default behavior, and revert to the slower behavior with an optional flag.
Yes @akristoffersen I'd love to push this PR over the finish line! Just been a little busy recently and have had some issues setting up the dev environment.
Do we have any usage data on how often this is used with sparse data such as lidar? My impression is that sparse data usage is less common, and so enabling this option by default would be better for the majority.
A better solution than the optional flag might be some kind of adaptive thresholding: if the fraction of valid pixels in the mask is below a certain threshold, revert to the original behavior; if it is above, use rejection sampling. This might be better for datasets with a mix of sparse and dense masks, but I'm not sure it's necessary, so I'll leave it to future PRs if people want that kind of functionality.
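A minimal sketch of the adaptive thresholding idea, written in dependency-free Python rather than torch so it runs anywhere. The function names, the list-of-lists mask, and the 0.1 threshold are all assumptions made for illustration; none of them come from this PR.

```python
def valid_fraction(mask):
    """Fraction of True pixels in a 2D boolean mask (list of lists)."""
    total = sum(len(row) for row in mask)
    valid = sum(1 for row in mask for px in row if px)
    return valid / total if total else 0.0


def choose_sampling_method(mask, min_valid_fraction=0.1):
    """Pick a sampler per mask.

    The 0.1 default is an arbitrary placeholder, not a tuned value.
    """
    if valid_fraction(mask) >= min_valid_fraction:
        return "rejection"  # dense mask: random draws are usually valid
    return "nonzero"        # sparse mask: enumerate valid pixels instead


dense_mask = [[True] * 8 for _ in range(8)]
sparse_mask = [[False] * 8 for _ in range(8)]
sparse_mask[3][5] = True  # 1 valid pixel out of 64

print(choose_sampling_method(dense_mask))
print(choose_sampling_method(sparse_mask))
```

Deciding per mask costs one pass over the mask to count valid pixels, so in practice the fraction would probably be computed once per image and cached.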
LGTM, small comment but otherwise looks good to merge :D
).long()
indices[~chosen_indices_validity] = replacement_indices

if num_valid != batch_size:
Instead of just raising a warning, I think it would make sense to fall back to the slow non-rejection sampling if this occurs. I would still issue a warning, but it would suck for training to fail on the off chance that not enough valid indices are generated in time.
small nit: With the most recent change, if the first iteration fails with rejection sampling, a warning will be generated, but the indices that are returned will not necessarily contain only valid locations within the mask. Every later sampling will be fine, but the first one will still have invalid indices.
I think instead, we should throw away the current generated indices and start from scratch with the non-rejection sampling method, as well as change the config flag to use the non-rejection sampling method for all future iterations.
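The fallback described above could look roughly like the sketch below. It uses plain Python with the standard `random` module instead of torch, and every name in it (`MaskSampler`, `use_rejection`, `max_iters`, the two helper functions) is hypothetical, chosen for illustration rather than taken from the PR's code.

```python
import random
import warnings


def sample_nonzero(mask, batch_size, rng):
    """Slow path: enumerate every valid pixel, then draw from that list."""
    valid = [(y, x) for y, row in enumerate(mask) for x, px in enumerate(row) if px]
    if not valid:
        raise ValueError("mask has no valid pixels")
    return [rng.choice(valid) for _ in range(batch_size)]


def sample_rejection(mask, batch_size, rng, max_iters=10):
    """Fast path: draw uniformly and redraw only the invalid picks.

    Returns None if invalid picks remain after max_iters redraw rounds.
    """
    h, w = len(mask), len(mask[0])
    picks = [(rng.randrange(h), rng.randrange(w)) for _ in range(batch_size)]
    for _ in range(max_iters):
        bad = [i for i, (y, x) in enumerate(picks) if not mask[y][x]]
        if not bad:
            return picks
        for i in bad:
            picks[i] = (rng.randrange(h), rng.randrange(w))
    return picks if all(mask[y][x] for y, x in picks) else None


class MaskSampler:
    def __init__(self, use_rejection=True, seed=0):
        self.use_rejection = use_rejection
        self.rng = random.Random(seed)

    def sample(self, mask, batch_size):
        if self.use_rejection:
            picks = sample_rejection(mask, batch_size, self.rng)
            if picks is not None:
                return picks
            # Throw away the partial result and stop trying rejection sampling.
            warnings.warn("rejection sampling failed; using non-rejection sampling from now on")
            self.use_rejection = False
        return sample_nonzero(mask, batch_size, self.rng)


# A 64x64 mask with a single valid pixel makes rejection sampling give up.
sparse_mask = [[False] * 64 for _ in range(64)]
sparse_mask[10][20] = True
sampler = MaskSampler()
picks = sampler.sample(sparse_mask, batch_size=16)
```

After the failed attempt, `picks` comes entirely from the slow path, so every index is valid, and `sampler.use_rejection` stays off for later calls.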
LGTM, thanks again @anc2001! feel free to merge :D
… into rejection_mask_sampling
Hey @akristoffersen, I don't have write access to merge this PR! Do you know who we would need to @ to get this merged?
That would be me! Should merge once all checks are done
* change masked pixel sampling to use rejection sampling instead of torch.nonzero
* black reformat code
* pyright unbound variable num_valid
* pyright type issues with num_valid
* add configuration settings for rejection sampling masks
* black reformat
* maybe this fixes it?
* revert behavior if mask sampling failed, still raise warning
* on iteration failure, use non-rejection sampling to generate indices
* ruff
---------
Co-authored-by: adrian_chang <[email protected]>
Co-authored-by: Alexander Kristoffersen <[email protected]>
I only seem to have these problems when using depth with |
Speeds up mask sampling and avoids OOM errors for large images / masks by replacing torch.nonzero with rejection sampling. Found that this solves many of the issues related to inefficient sampling with masks. Ex.:

rays / sec without masks
rays / sec with current mask sampling

rays / sec with new mask sampling
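A rough way to see why rejection sampling is fast on dense masks yet struggles on sparse ones such as lidar, assuming each draw is uniform over pixels and independent of the others. The symbols are introduced here and are not from the PR: $p$ is the fraction of valid pixels in the mask, $B$ the batch size, and $k$ the number of redraw rounds after the initial draw.

$$
\mathbb{E}[\text{draws per valid sample}] = \frac{1}{p},
\qquad
\mathbb{E}[\text{invalid samples after } k \text{ redraw rounds}] = B\,(1-p)^{k+1}
$$

For example, with $p = 0.5$ and $k = 10$ the expected invalid fraction is $0.5^{11} \approx 0.0005$, while with $p = 0.01$ it is $0.99^{11} \approx 0.90$, so roughly 90% of the batch would still be invalid.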
