From 4314f9101a2f3bd7f11ba4290d2a7e2e64b4ceea Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Wed, 23 Nov 2022 09:32:43 -0800 Subject: [PATCH] Add CompVisVDenoiser wrapper --- k_diffusion/external.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/k_diffusion/external.py b/k_diffusion/external.py index 2f1d258..79b51ce 100644 --- a/k_diffusion/external.py +++ b/k_diffusion/external.py @@ -136,3 +136,42 @@ def __init__(self, model, quantize=False, device='cpu'): def get_eps(self, *args, **kwargs): return self.inner_model.apply_model(*args, **kwargs) + + +class DiscreteVDDPMDenoiser(DiscreteSchedule): + """A wrapper for discrete schedule DDPM models that output v.""" + + def __init__(self, model, alphas_cumprod, quantize): + super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) + self.inner_model = model + self.sigma_data = 1. + + def get_scalings(self, sigma): + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_skip, c_out, c_in + + def get_v(self, *args, **kwargs): + return self.inner_model(*args, **kwargs) + + def loss(self, input, noise, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * utils.append_dims(sigma, input.ndim) + model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) + target = (input - c_skip * noised_input) / c_out + return (model_output - target).pow(2).flatten(1).mean(1) + + def forward(self, input, sigma, **kwargs): + c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip + + +class CompVisVDenoiser(DiscreteVDDPMDenoiser): + """A wrapper for CompVis diffusion models that output v.""" + + def __init__(self, model, quantize=False, device='cpu'): + super().__init__(model, model.alphas_cumprod, quantize=quantize) + + def get_v(self, x, t, cond, **kwargs): + return self.inner_model.apply_model(x, t, cond)