Add Kohaku LoNyu Yog sampler, ruff workflow and pyproject
MisterChief95 committed Nov 22, 2024
1 parent 19e8c07 commit af682de
Showing 3 changed files with 31 additions and 7 deletions.
11 changes: 11 additions & 0 deletions extra_euler_samplers/__init__.py
@@ -6,3 +6,14 @@
from .euler_max import sample_euler_max
from .euler_negative import sample_euler_negative
from .kohaku_lonyu_yog import sample_kohaku_lonyu_yog
+
+__all__ = [
+    "sample_euler_dy",
+    "sample_euler_dy_negative",
+    "sample_euler_smea",
+    "sample_euler_smea_dy",
+    "sample_euler_smea_dy_negative",
+    "sample_euler_max",
+    "sample_euler_negative",
+    "sample_kohaku_lonyu_yog",
+]
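
For reference, the new __all__ makes the package's public surface explicit. A minimal usage sketch (illustrative only; it assumes the package is importable from the WebUI's Python environment):

# __all__ now controls what a wildcard import exposes.
from extra_euler_samplers import *  # brings in only the eight listed sampler functions

# The new sampler can also be imported by name:
from extra_euler_samplers import sample_kohaku_lonyu_yog
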
24 changes: 18 additions & 6 deletions extra_euler_samplers/kohaku_lonyu_yog.py
@@ -6,27 +6,39 @@


@torch.no_grad()
-def sample_kohaku_lonyu_yog(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0.,
-                            s_tmax=float('inf'), s_noise=1., noise_sampler=None, eta=1.):
+def sample_kohaku_lonyu_yog(
+    model,
+    x,
+    sigmas,
+    extra_args=None,
+    callback=None,
+    disable=None,
+    s_churn=0.0,
+    s_tmin=0.0,
+    s_tmax=float("inf"),
+    s_noise=1.0,
+    noise_sampler=None,
+    eta=1.0,
+):
    """Kohaku_LoNyu_Yog"""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    for i in trange(len(sigmas) - 1, disable=disable):
-        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+        gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
-            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
+            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
        dt = sigma_down - sigmas[i]

        if i <= (len(sigmas) - 1) / 2:
-            x2 = - x
+            x2 = -x
            denoised2 = model(x2, sigma_hat * s_in, **extra_args)
            d2 = to_d(x2, sigma_hat, denoised2)

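As an aside, the function keeps the usual k-diffusion sampler signature (model, x, sigmas, ...), so it can be exercised on its own. A minimal sketch with a toy stand-in denoiser and an assumed sigma schedule; none of this is part of the commit:

import torch

from extra_euler_samplers import sample_kohaku_lonyu_yog


class ToyDenoiser(torch.nn.Module):
    """Stand-in for a real k-diffusion model wrapper; returns a fake denoised prediction."""

    def forward(self, x, sigma, **extra_args):
        return torch.zeros_like(x)


sigmas = torch.linspace(14.6, 0.0, 21)     # assumed toy schedule, decreasing and ending at 0
x = torch.randn(1, 4, 64, 64) * sigmas[0]  # start from noise scaled to the first sigma
sample = sample_kohaku_lonyu_yog(ToyDenoiser(), x, sigmas, disable=True)
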
3 changes: 2 additions & 1 deletion scripts/extra_euler_samplers.py
@@ -7,6 +7,7 @@

# See modules_forge/alter_samplers.py for the basis of this class and build_constructor function


class ExtraSampler(KDiffusionSampler):
"""
Overloads KDiffusionSampler to add extra parameters to the constructor
@@ -17,7 +18,7 @@ def __init__(self, sd_model, sampler_name, options=None):
self.sampler_name = sampler_name
self.unet = sd_model.forge_objects.unet
sampler_function = getattr(extra_euler_samplers, sampler_name)

super().__init__(sampler_function, sd_model, options)

self.extra_params = ["s_churn", "s_tmin", "s_tmax", "s_noise"]
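
For orientation, the build_constructor pattern referenced in the comment above is usually wired up roughly like this; a hedged sketch based on A1111/Forge conventions rather than code from this commit, so treat the names and registration details as assumptions:

from modules import sd_samplers_common

import extra_euler_samplers


def build_constructor(sampler_name):
    # Return a constructor the WebUI can call with the loaded sd_model.
    def constructor(sd_model):
        return ExtraSampler(sd_model, sampler_name)

    return constructor


extra_samplers_data = [
    sd_samplers_common.SamplerData(name, build_constructor(name), [name], {})
    for name in extra_euler_samplers.__all__
]
# The script would then register extra_samplers_data with the WebUI's sampler list
# (the registration call itself is outside the lines shown in this diff).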
