speedup mask sampling, properly save masks in images pipeline #2346
base: main
Conversation
Added a change that speeds up mask sampling by caching the mask's non-zero indices; it speeds up sampling by a few orders of magnitude (before this, mask sampling was pretty much unusable).
LGTM
chosen_indices = random.sample(range(len(nonzero_indices)), k=batch_size)
indices = nonzero_indices[chosen_indices]
if not hasattr(self, "nonzero_indices"):
    self.nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
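A minimal, self-contained sketch of the caching pattern in this snippet (the class and method names here are hypothetical, not the actual nerfstudio sampler):

```python
import random
import torch


class PixelSampler:
    """Sketch of the PR's idea: cache non-zero mask indices on first use."""

    def sample(self, mask: torch.Tensor, batch_size: int) -> torch.Tensor:
        # torch.nonzero runs once; subsequent calls reuse the cached indices
        # instead of rescanning the mask every training step.
        if not hasattr(self, "nonzero_indices"):
            self.nonzero_indices = torch.nonzero(mask[..., 0], as_tuple=False)
        chosen = random.sample(range(len(self.nonzero_indices)), k=batch_size)
        return self.nonzero_indices[chosen]


sampler = PixelSampler()
mask = torch.zeros(2, 4, 4, 1)
mask[:, 1:3, 1:3, :] = 1  # 2 images x 2x2 valid region = 8 valid pixels
indices = sampler.sample(mask, batch_size=4)
print(indices.shape)  # one (image, row, col) triple per sample
```

Note the trade-off the reviewers discuss below: because the cache lives on `self`, a later call with a *different* mask silently reuses the stale indices.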
This seems a bit weird: if we call this function twice but with different mask args, the second will be ignored. Maybe we can refactor the method to take nonzero_indices as an argument instead?
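The stateless refactor being suggested could look something like this (a sketch with assumed names; the caller owns the cache, so different masks can never collide with a stale attribute):

```python
import random
import torch


def sample_from_nonzero(nonzero_indices: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Stateless variant: the caller precomputes and owns nonzero_indices."""
    chosen = random.sample(range(len(nonzero_indices)), k=batch_size)
    return nonzero_indices[chosen]


mask_a = torch.zeros(3, 3)
mask_a[0, 0] = 1
mask_b = torch.ones(3, 3)

# The caller computes the indices once per mask and passes them in.
idx_a = torch.nonzero(mask_a, as_tuple=False)
idx_b = torch.nonzero(mask_b, as_tuple=False)

sample_a = sample_from_nonzero(idx_a, batch_size=1)  # only (0, 0) is valid
sample_b = sample_from_nonzero(idx_b, batch_size=4)
```

This pushes the cache hit/miss question up a level, which is exactly the trade-off debated later in the thread.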
torch.nonzero may require a lot of memory :(
@kerrj does this implementation assume that each batch has the same set of masks? I think it'd be nicer / more correct if it handled the case where different images have different masks.
Also, it looks like the goal is to avoid redundant calls to nonzero. If each image has a different mask, would it make sense to use an img_to_nonzero_indices dict or something?
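The per-image dict being proposed could be sketched like this (hypothetical names; one cached index tensor per image id, so differing masks each get their own entry):

```python
import random
import torch


class PerImageSampler:
    """Sketch of the suggested img_to_nonzero_indices cache."""

    def __init__(self):
        self.img_to_nonzero_indices = {}  # image id -> cached non-zero indices

    def sample(self, image_id: int, mask: torch.Tensor, batch_size: int) -> torch.Tensor:
        # nonzero runs once per image, not once per step.
        if image_id not in self.img_to_nonzero_indices:
            self.img_to_nonzero_indices[image_id] = torch.nonzero(mask, as_tuple=False)
        cached = self.img_to_nonzero_indices[image_id]
        chosen = random.sample(range(len(cached)), k=batch_size)
        return cached[chosen]


sampler = PerImageSampler()
m0 = torch.zeros(4, 4)
m0[0] = 1       # image 0: only the top row is valid
m1 = torch.zeros(4, 4)
m1[:, 3] = 1    # image 1: only the last column is valid
s0 = sampler.sample(0, m0, batch_size=2)
s1 = sampler.sample(1, m1, batch_size=2)
```

One dict entry per image trades memory for speed, which ties into the memory concern raised above.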
@brentyi agreed; I wanted to keep the interface the same, though, since providing nonzero_indices kicks the can up to the caller, which would run into the same problem of changing masks.
@Ilyabasharov It requires the same amount of memory, since the indices are stored in CPU RAM. I tested with and without the cache and it's the same; intuitively, instantiating the nonzero_indices array each step allocates the memory anyway and then immediately deallocates it, which in practice is essentially the same as keeping it allocated, since this happens every single step.
@kevin-thankyou-lin the mask parameter is shape NxHxW, so it includes all the images in the dataset. If the masks are different for each image it will take that into account.
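The point about the NxHxW mask can be checked directly: torch.nonzero on a stacked mask returns one (image, row, col) triple per valid pixel, so per-image differences are preserved. A small illustration:

```python
import torch

# Two images with different masks, stacked along the batch dimension.
mask = torch.zeros(2, 3, 3)
mask[0, 0, 0] = 1  # image 0: only pixel (0, 0) is valid
mask[1, 2, 2] = 1  # image 1: only pixel (2, 2) is valid

# Each row of the result is (image index, row, col), in row-major order.
indices = torch.nonzero(mask, as_tuple=False)
print(indices.tolist())  # [[0, 0, 0], [1, 2, 2]]
```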
@kerrj thanks for the explanation! I've tested torch.nonzero on GPU with large images and ran into OOM :( but if we use the CPU it will be much better.
> @brentyi agreed; I wanted to keep the interface the same, though, since providing nonzero_indices kicks the can up to the caller, which would run into the same problem of changing masks.

Agreed that the cache hit / miss logic still needs to be solved, but kicking the can up seems nice. As a heuristic, it seems ideal to avoid statefulness in lower-level primitives like pixel sampling; it seems like a risk for memory leaks, etc.
Re: @kevin-thankyou-lin's point, the masks will only be of shape N,H,W if the entire dataset is cached, right? If someone uses the dataset without caching all of it, then this will cause issues, I think. Is that true @kerrj?
@ethanweber if that's true, I think the current behavior is also bugged, since it will only ever sample pixels from the provided masks.
Hi, not sure what the status of this PR is, but I wanted to suggest a speedup to mask sampling that has worked for me: #2585
Masks were not saved properly in the images processing pipeline; now they are