Difference between latents requires_grad=True and torch.no_grad() #34

Open

gvalvano opened this issue May 4, 2023 · 2 comments

gvalvano commented May 4, 2023

Thanks for sharing such amazing work :)

In the last section of the notebook Stable Diffusion Deep Dive.ipynb, you mention:

NB: We should set latents requires_grad=True before we do the forward pass of the unet (removing the with torch.no_grad()) if we want more accurate gradients. BUT this requires a lot of extra memory. You'll see both approaches used depending on whose implementation you're looking at.

Can you please clarify the difference between the two approaches? For example, if I had to code this, I would have used torch.no_grad(), but apparently you preferred another approach. What changes computationally and results-wise?

I think adding this as extra info to the notebook would be useful to others, too :)

johnowhitaker (Collaborator) commented

If we set requires_grad=True AFTER getting the noise prediction from the unet (the example shown), then the gradient of the loss function w.r.t. the latents tells us "how do I change these latents such that when I remove this noise it looks good".
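
A minimal sketch of what one guided denoising step looks like in that "shortcut" form (this isn't the notebook's exact code; `unet`, `scheduler`, the per-step `sigma`, `text_embeddings` and the guidance `loss_fn` are assumed to come from a diffusers-style setup like the one in the notebook):

```python
import torch

def guided_step_shortcut(unet, scheduler, latents, t, sigma,
                         text_embeddings, loss_fn, guidance_scale=100):
    # UNet forward pass under no_grad: the noise prediction is a constant
    # as far as autograd is concerned, so very little extra memory is used.
    with torch.no_grad():
        latent_model_input = scheduler.scale_model_input(latents, t)
        noise_pred = unet(latent_model_input, t,
                          encoder_hidden_states=text_embeddings).sample

    # Only AFTER the prediction do we ask autograd to track the latents
    latents = latents.detach().requires_grad_()

    # x0 estimate: differentiable w.r.t. the latents, but noise_pred is fixed
    denoised_latents = latents - sigma * noise_pred

    loss = loss_fn(denoised_latents) * guidance_scale
    cond_grad = torch.autograd.grad(loss, latents)[0]

    # Nudge the latents downhill, then take the normal scheduler step
    latents = latents.detach() - cond_grad * sigma ** 2
    return scheduler.step(noise_pred, t, latents).prev_sample
```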

If we set requires_grad=True BEFORE getting the noise prediction from the unet, then the noise prediction depends on the latents and the gradients can be traced back through the unet. So they tell us "how do I change these latents such that WHEN I FEED THEM THROUGH THE UNET AND THEN REMOVE THE PREDICTED NOISE it looks good".
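
The full-gradient version only moves the `requires_grad_()` call above the UNet forward pass, so the backward pass runs through the whole UNet (same hypothetical names and assumptions as the sketch above):

```python
import torch

def guided_step_full(unet, scheduler, latents, t, sigma,
                     text_embeddings, loss_fn, guidance_scale=100):
    # Ask autograd to track the latents BEFORE the UNet forward pass
    latents = latents.detach().requires_grad_()

    # UNet forward WITH gradient tracking: all activations are kept around
    latent_model_input = scheduler.scale_model_input(latents, t)
    noise_pred = unet(latent_model_input, t,
                      encoder_hidden_states=text_embeddings).sample

    # x0 estimate now depends on the latents both directly and via noise_pred
    denoised_latents = latents - sigma * noise_pred

    loss = loss_fn(denoised_latents) * guidance_scale
    # This backward pass goes through the UNet -> much more memory and compute
    cond_grad = torch.autograd.grad(loss, latents)[0]

    latents = latents.detach() - cond_grad * sigma ** 2
    return scheduler.step(noise_pred.detach(), t, latents).prev_sample
```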

The second case reflects what actually happens during sampling. We want to tweak the latents such that the final result (based on a prediction made with those modified latents) minimizes our loss. The first case tweaks the latents such that a prediction based on the unmodified latents minimizes the loss. It's a subtle difference, but especially in more complicated cases than the demo it does make a difference. For example, with CLIP guidance (where you try to minimize the difference between the generated image and a text or image prompt in CLIP embedding space) the results with the non-shortcut method seem to be better. The downside is much higher memory and compute usage, since the gradients need to be traced back through the UNet.
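
For concreteness, the CLIP guidance loss itself is usually just a cosine distance in embedding space, something like the sketch below (assuming you've already computed CLIP embeddings for the decoded/denoised images and for the prompt; making the image-embedding path differentiable is where the extra memory goes):

```python
import torch
import torch.nn.functional as F

def clip_guidance_loss(image_embeds: torch.Tensor,
                       prompt_embeds: torch.Tensor) -> torch.Tensor:
    # Cosine distance between denoised-image embeddings and the prompt embedding
    image_embeds = F.normalize(image_embeds, dim=-1)
    prompt_embeds = F.normalize(prompt_embeds, dim=-1)
    return (1.0 - (image_embeds * prompt_embeds).sum(dim=-1)).mean()
```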

If you're interested in guidance, there are other improvements worth looking at - https://arxiv.org/abs/2301.11558 shows an example of CLIP guidance in their repo (https://github.com/sWizad/split-diffusion) that uses some clever maths to make the guidance more stable. I tend to use one of their approaches in all my guidance stuff these days.

gvalvano (Author) commented May 4, 2023

I see. This is a very interesting topic and it looks much clearer to me now 😁
I will check the references, too. Thank you!
