diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index a1f679c..bfa8f44 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -4,6 +4,7 @@ import torch from torch import nn from torchdiffeq import odeint +import torchsde from tqdm.auto import trange, tqdm from . import utils @@ -50,6 +51,62 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.): return sigma_down, sigma_up +def default_noise_sampler(x): + return lambda sigma, sigma_next: torch.randn_like(x) + + +class BatchedBrownianTree: + """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get('w0', torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2 ** 63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will + use one BrownianTree per batch item, each with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + @torch.no_grad() def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" @@ -72,9 +129,10 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.): +def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -85,7 +143,8 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis # Euler method dt = sigma_down - sigmas[i] x = x + d * dt - x = x + torch.randn_like(x) * sigma_up + if sigmas[i + 1] > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x @@ -150,9 +209,10 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.): - """Ancestral sampling with DPM-Solver inspired second-order steps.""" +def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + """Ancestral sampling with DPM-Solver second-order steps.""" extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -173,7 +233,7 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 - x = x + torch.randn_like(x) * sigma_up + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x @@ -318,7 +378,8 @@ def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) return x_3, eps_cache - def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.): + def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None): + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler if not t_end > t_start and eta: raise ValueError('eta must be 0 for reverse sampling') @@ -352,11 +413,12 @@ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.): else: x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache) - x = x + su * s_noise * torch.randn_like(x) + x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next)) return x - def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1.): + def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None): + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler if order not in {2, 3}: raise ValueError('order should be 2 or 3') forward = t_end > t_start @@ -395,7 +457,7 @@ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078 accept = pid.propose_step(error) if accept: x_prev = x_low - x = x_high + su * s_noise * torch.randn_like(x_high) + x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t)) s = t info['n_accept'] += 1 else: @@ -410,7 +472,7 @@ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078 @torch.no_grad() -def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1.): +def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None): """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: raise ValueError('sigma_min and sigma_max must not be 0') @@ -418,11 +480,11 @@ def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) if callback is not None: dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) - return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise) + return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler) @torch.no_grad() -def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., return_info=False): +def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False): """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: raise ValueError('sigma_min and sigma_max must not be 0') @@ -430,16 +492,17 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) if callback is not None: dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) - x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise) + x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler) if return_info: return x, info return x @torch.no_grad() -def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1.): +def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() @@ -455,7 +518,7 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, dt = sigma_down - sigmas[i] x = x + d * dt else: - # DPM-Solver-2++(2S) + # DPM-Solver++(2S) t, t_next = t_fn(sigmas[i]), t_fn(sigma_down) r = 1 / 2 h = t_next - t @@ -464,7 +527,50 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2 # Noise addition - x = x + torch.randn_like(x) * s_noise * sigma_up + if sigmas[i + 1] > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + return x + + +@torch.no_grad() +def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): + """DPM-Solver++ (stochastic).""" + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Euler method + d = to_d(x, sigmas[i], denoised) + dt = sigmas[i + 1] - sigmas[i] + x = x + d * dt + else: + # DPM-Solver++ + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + h = t_next - t + s = t + h * r + fac = 1 / (2 * r) + + # Step 1 + sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta) + s_ = t_fn(sd) + x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised + x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su + denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) + + # Step 2 + sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta) + t_next_ = t_fn(sd) + denoised_d = (1 - fac) * denoised + fac * denoised_2 + x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d + x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su return x diff --git a/setup.cfg b/setup.cfg index e863cf9..841c25e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ install_requires = scipy torch torchdiffeq + torchsde torchvision tqdm wandb