split up sample() to allow backwards pass #4

Open · wants to merge 8 commits into master
3 changes: 0 additions & 3 deletions .gitmodules
@@ -1,3 +0,0 @@
[submodule "CLIP"]
	path = CLIP
	url = https://github.com/openai/CLIP
1 change: 0 additions & 1 deletion CLIP
Submodule CLIP deleted from 40f548
72 changes: 71 additions & 1 deletion diffusion/sampling.py
100644 → 100755
@@ -3,9 +3,79 @@

from . import utils

# These helper functions are the subroutines used by sample_split() below.
def sample_step_pred(model, x, steps, eta, extra_args, ts, alphas, sigmas, i):
    # Get the model output (v, the predicted velocity)
    with torch.cuda.amp.autocast():
        v = model(x, ts * steps[i], **extra_args).float()

    # DDPM/DDIM sampling
    # Predict the denoised image from the model output
    pred = x * alphas[i] - v * sigmas[i]

    return pred, v


def sample_step_noise(model, x, steps, eta, extra_args, ts, alphas, sigmas, i, pred, v):
    # Predict the noise from the model output
    eps = x * sigmas[i] + v * alphas[i]

    # If we are not on the last timestep, compute the noisy image for the
    # next timestep.
    if i < len(steps) - 1:
        # If eta > 0, adjust the scaling factor for the predicted noise
        # downward according to the amount of additional noise to add
        ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
            (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
        adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

        # Recombine the predicted noise and predicted denoised image in the
        # correct proportions for the next step
        x = pred * alphas[i + 1] + eps * adjusted_sigma

        # Add the correct amount of fresh noise
        if eta:
            x = x + torch.randn_like(x) * ddim_sigma

    return x

def sample_setup(model, x, steps, eta, extra_args):
    """Builds the shared sampling state (conditioning timesteps and noise
    schedule) used by sample_step() and sample_noise()."""
    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule
    alphas, sigmas = utils.t_to_alpha_sigma(steps)

    sample_state = [model, steps, eta, extra_args, ts, alphas, sigmas]
    return sample_state

def sample_step(sample_state, x, i, last_pred, last_v):
    model, steps, eta, extra_args, ts, alphas, sigmas = sample_state
    pred, v = sample_step_pred(model, x, steps, eta, extra_args, ts, alphas, sigmas, i)
    return pred, v, x


def sample_noise(sample_state, x, i, last_pred, last_v):
    model, steps, eta, extra_args, ts, alphas, sigmas = sample_state
    if last_pred is not None:
        x = sample_step_noise(model, x, steps, eta, extra_args, ts, alphas, sigmas, i, last_pred, last_v)
    return x

# This new version of sample() calls the helper functions above to do the work.
def sample_split(model, x, steps, eta, extra_args):
    pred = None
    v = None
    sample_state = sample_setup(model, x, steps, eta, extra_args)
    for i in trange(len(steps)):
        pred, v, x = sample_step(sample_state, x, i, pred, v)
        x = sample_noise(sample_state, x, i, pred, v)

    return pred

# This is the original version of sample(), which did everything in one function.

# DDPM/DDIM sampling
@torch.no_grad()
def sample(model, x, steps, eta, extra_args, callback=None):
"""Draws samples from a model given starting noise."""
[remainder of the original sample() body not shown in the diff]
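
Usage note (editor's sketch, not part of this PR): the point of the split is that a caller can now drive the sampling loop itself and enable gradient tracking for individual steps, which the monolithic, @torch.no_grad()-decorated sample() does not allow. The sketch below assumes the sample_setup()/sample_step()/sample_noise() helpers added in this diff plus a caller-supplied loss_fn; the choice to backpropagate only through the final step is purely illustrative.

# Editor's sketch (hypothetical usage, not part of this PR): drive the split-up
# sampling loop by hand and backpropagate through the final step only.
# loss_fn and the single-step gradient choice are assumptions; the sample_*
# helpers and the import path follow this repo's diffusion/sampling.py.
import torch

from diffusion.sampling import sample_setup, sample_step, sample_noise

def sample_with_grad_on_last_step(model, x, steps, eta, extra_args, loss_fn):
    sample_state = sample_setup(model, x, steps, eta, extra_args)
    pred, v = None, None
    for i in range(len(steps)):
        last_step = i == len(steps) - 1
        # Build an autograd graph only for the step we want to differentiate through.
        with torch.set_grad_enabled(last_step):
            pred, v, x = sample_step(sample_state, x, i, pred, v)
            x = sample_noise(sample_state, x, i, pred, v)
    loss = loss_fn(pred)
    loss.backward()  # gradients reach the model through the last sample_step() call
    return pred, loss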