Noise samplers (#46)
* Add noise sampler support

* Fix bug in initial implementation of noise samplers

* Fix seed bug in initial implementation of noise samplers

* Negate BatchedBrownianTree integrals if t0 > t1

* Negate BatchedBrownianTree integrals if t0 > t1 on creation

* Fix bug when noise_sampler was not provided to some samplers

* Noise sampler tweaks

* Add correct stochastic DPM-Solver++
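
The interface these changes introduce is small: a noise sampler is any callable that takes (sigma, sigma_next) and returns a noise tensor shaped like x, and default_noise_sampler below simply ignores the interval and draws fresh Gaussian noise. A minimal sketch of a conforming custom sampler (an illustration, not part of this commit; the helper name is hypothetical):

    import torch

    def make_seeded_noise_sampler(x, seed=0):
        # Hypothetical helper: same contract as default_noise_sampler(x),
        # but driven by a fixed torch.Generator for reproducibility.
        gen = torch.Generator(device=x.device).manual_seed(seed)
        def noise_sampler(sigma, sigma_next):
            # sigma and sigma_next are ignored here, as in the default sampler.
            return torch.randn(x.shape, generator=gen, dtype=x.dtype, device=x.device)
        return noise_sampler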
crowsonkb authored Nov 20, 2022
1 parent 60e5042 commit 7621f11
Showing 2 changed files with 123 additions and 16 deletions.
138 changes: 122 additions & 16 deletions k_diffusion/sampling.py
@@ -4,6 +4,7 @@
 import torch
 from torch import nn
 from torchdiffeq import odeint
+import torchsde
 from tqdm.auto import trange, tqdm

 from . import utils
@@ -50,6 +51,62 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
     return sigma_down, sigma_up


+def default_noise_sampler(x):
+    return lambda sigma, sigma_next: torch.randn_like(x)
+
+
+class BatchedBrownianTree:
+    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
+
+    def __init__(self, x, t0, t1, seed=None, **kwargs):
+        t0, t1, self.sign = self.sort(t0, t1)
+        w0 = kwargs.get('w0', torch.zeros_like(x))
+        if seed is None:
+            seed = torch.randint(0, 2 ** 63 - 1, []).item()
+        self.batched = True
+        try:
+            assert len(seed) == x.shape[0]
+            w0 = w0[0]
+        except TypeError:
+            seed = [seed]
+            self.batched = False
+        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+
+    @staticmethod
+    def sort(a, b):
+        return (a, b, 1) if a < b else (b, a, -1)
+
+    def __call__(self, t0, t1):
+        t0, t1, sign = self.sort(t0, t1)
+        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+        return w if self.batched else w[0]
+
+
+class BrownianTreeNoiseSampler:
+    """A noise sampler backed by a torchsde.BrownianTree.
+
+    Args:
+        x (Tensor): The tensor whose shape, device and dtype to use to generate
+            random samples.
+        sigma_min (float): The low end of the valid interval.
+        sigma_max (float): The high end of the valid interval.
+        seed (int or List[int]): The random seed. If a list of seeds is
+            supplied instead of a single integer, then the noise sampler will
+            use one BrownianTree per batch item, each with its own seed.
+        transform (callable): A function that maps sigma to the sampler's
+            internal timestep.
+    """
+
+    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
+        self.transform = transform
+        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
+        self.tree = BatchedBrownianTree(x, t0, t1, seed)
+
+    def __call__(self, sigma, sigma_next):
+        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
+        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
+
+
 @torch.no_grad()
 def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
     """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -72,9 +129,10 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,


 @torch.no_grad()
-def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.):
+def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with Euler method steps."""
     extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -85,7 +143,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
         # Euler method
         dt = sigma_down - sigmas[i]
         x = x + d * dt
-        x = x + torch.randn_like(x) * sigma_up
+        if sigmas[i + 1] > 0:
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
     return x


@@ -150,9 +209,10 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,


 @torch.no_grad()
-def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.):
-    """Ancestral sampling with DPM-Solver inspired second-order steps."""
+def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+    """Ancestral sampling with DPM-Solver second-order steps."""
     extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -173,7 +233,7 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
         denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
         d_2 = to_d(x_2, sigma_mid, denoised_2)
         x = x + d_2 * dt_2
-        x = x + torch.randn_like(x) * sigma_up
+        x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
     return x


@@ -318,7 +378,8 @@ def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
         x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
         return x_3, eps_cache

-    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.):
+    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
+        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
         if not t_end > t_start and eta:
             raise ValueError('eta must be 0 for reverse sampling')

@@ -352,11 +413,12 @@ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.):
             else:
                 x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)

-            x = x + su * s_noise * torch.randn_like(x)
+            x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

         return x

-    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1.):
+    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
+        noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
         if order not in {2, 3}:
             raise ValueError('order should be 2 or 3')
         forward = t_end > t_start
@@ -395,7 +457,7 @@ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078
             accept = pid.propose_step(error)
             if accept:
                 x_prev = x_low
-                x = x_high + su * s_noise * torch.randn_like(x_high)
+                x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
                 s = t
                 info['n_accept'] += 1
             else:
@@ -410,36 +472,37 @@ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078


 @torch.no_grad()
-def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1.):
+def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
     """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
     if sigma_min <= 0 or sigma_max <= 0:
         raise ValueError('sigma_min and sigma_max must not be 0')
     with tqdm(total=n, disable=disable) as pbar:
         dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
         if callback is not None:
             dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
-        return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise)
+        return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)


 @torch.no_grad()
-def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., return_info=False):
+def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
     """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
     if sigma_min <= 0 or sigma_max <= 0:
         raise ValueError('sigma_min and sigma_max must not be 0')
     with tqdm(disable=disable) as pbar:
         dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
         if callback is not None:
             dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
-        x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise)
+        x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
     if return_info:
         return x, info
     return x


 @torch.no_grad()
-def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1.):
+def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
     extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
     t_fn = lambda sigma: sigma.log().neg()
@@ -455,7 +518,7 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
             dt = sigma_down - sigmas[i]
             x = x + d * dt
         else:
-            # DPM-Solver-2++(2S)
+            # DPM-Solver++(2S)
             t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
             r = 1 / 2
             h = t_next - t
@@ -464,7 +527,50 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
             denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
             x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
         # Noise addition
-        x = x + torch.randn_like(x) * s_noise * sigma_up
+        if sigmas[i + 1] > 0:
+            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
     return x


+@torch.no_grad()
+def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
+    """DPM-Solver++ (stochastic)."""
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    sigma_fn = lambda t: t.neg().exp()
+    t_fn = lambda sigma: sigma.log().neg()
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+        if sigmas[i + 1] == 0:
+            # Euler method
+            d = to_d(x, sigmas[i], denoised)
+            dt = sigmas[i + 1] - sigmas[i]
+            x = x + d * dt
+        else:
+            # DPM-Solver++
+            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
+            h = t_next - t
+            s = t + h * r
+            fac = 1 / (2 * r)
+
+            # Step 1
+            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
+            s_ = t_fn(sd)
+            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
+            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
+            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+
+            # Step 2
+            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
+            t_next_ = t_fn(sd)
+            denoised_d = (1 - fac) * denoised + fac * denoised_2
+            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
+            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
+    return x

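For reference, a minimal usage sketch of the BrownianTreeNoiseSampler added above (the shapes, sigma values, and seeds here are illustrative assumptions):

    import torch
    from k_diffusion.sampling import BrownianTreeNoiseSampler

    x = torch.zeros(2, 3, 8, 8)
    # A list of seeds gives each batch item its own BrownianTree; a single
    # int (or None) shares one tree across the batch.
    ns = BrownianTreeNoiseSampler(x, sigma_min=0.01, sigma_max=80.0, seed=[0, 1])
    # Querying the same (sigma, sigma_next) interval always returns the same
    # noise, so a trajectory stays consistent as the step count changes, and
    # the sign convention from sort() makes ns(a, b) == -ns(b, a).
    eps = ns(80.0, 60.0)
    assert eps.shape == x.shape and torch.equal(eps, ns(80.0, 60.0))
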
1 change: 1 addition & 0 deletions setup.cfg
@@ -24,6 +24,7 @@ install_requires =
     scipy
     torch
     torchdiffeq
+    torchsde
    torchvision
     tqdm
     wandb
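
With torchsde declared as a dependency, an end-to-end sketch of invoking the new stochastic sampler (the toy denoiser and schedule values are assumptions, not part of this commit):

    import torch
    from k_diffusion import sampling

    class ToyDenoiser(torch.nn.Module):
        # Stand-in for a real k-diffusion denoiser wrapper: any module
        # mapping (x, sigma) to a denoised estimate of x works here.
        def forward(self, x, sigma):
            return x / (1 + sigma[:, None, None, None] ** 2)

    model = ToyDenoiser()
    sigmas = sampling.get_sigmas_karras(n=32, sigma_min=0.01, sigma_max=80.0)
    x = torch.randn(4, 3, 64, 64) * sigmas[0]
    # With noise_sampler=None, sample_dpmpp_sde constructs a
    # BrownianTreeNoiseSampler over [sigma_min, sigma_max] itself.
    samples = sampling.sample_dpmpp_sde(model, x, sigmas, eta=1., s_noise=1.)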
