Skip to content

Commit

Permalink
Merge pull request #1 from MisterChief95/kohaku_sampler
Browse files Browse the repository at this point in the history
Kohaku sampler
  • Loading branch information
MisterChief95 authored Nov 22, 2024
2 parents d61f052 + af682de commit ccfe635
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 1 deletion.
8 changes: 8 additions & 0 deletions .github/workflows/ruff.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: Ruff
on: [pull_request]
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/ruff-action@v1
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ The source for these new samplers come from projects created by [Koishi-Star](ht
- Euler SMEA
- Euler SMEA Dy
- Euler SMEA Dy Negative
- Kohaku LoNyu Yog

### Comparison

Expand Down
12 changes: 12 additions & 0 deletions extra_euler_samplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,15 @@
from .euler_smea_dy_negative import sample_euler_smea_dy_negative
from .euler_max import sample_euler_max
from .euler_negative import sample_euler_negative
from .kohaku_lonyu_yog import sample_kohaku_lonyu_yog

__all__ = [
"sample_euler_dy",
"sample_euler_dy_negative",
"sample_euler_smea",
"sample_euler_smea_dy",
"sample_euler_smea_dy_negative",
"sample_euler_max",
"sample_euler_negative",
"sample_kohaku_lonyu_yog",
]
55 changes: 55 additions & 0 deletions extra_euler_samplers/kohaku_lonyu_yog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch

from k_diffusion.sampling import default_noise_sampler, get_ancestral_step, to_d

from tqdm.auto import trange


@torch.no_grad()
def sample_kohaku_lonyu_yog(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
s_churn=0.0,
s_tmin=0.0,
s_tmax=float("inf"),
s_noise=1.0,
noise_sampler=None,
eta=1.0,
):
"""Kohaku_LoNyu_Yog"""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
dt = sigma_down - sigmas[i]

if i <= (len(sigmas) - 1) / 2:
x2 = -x
denoised2 = model(x2, sigma_hat * s_in, **extra_args)
d2 = to_d(x2, sigma_hat, denoised2)

x3 = x + ((d + d2) / 2) * dt
denoised3 = model(x3, sigma_hat * s_in, **extra_args)
d3 = to_d(x3, sigma_hat, denoised3)

real_d = (d + d3) / 2
x = x + real_d * dt

x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
else:
x = x + d * dt
return x
78 changes: 78 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
[tool.ruff]
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]

# Same as Black.
line-length = 120
indent-width = 4

# Assume Python 3.8
target-version = "py310"

[tool.ruff.lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
select = ["E4", "E7", "E9", "F"]
ignore = []

# Allow fix for all enabled rules (when `--fix`) is provided.
fixable = ["ALL"]
unfixable = []

# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

# Enable auto-formatting of code examples in docstrings. Markdown,
# reStructuredText code/literal blocks and doctests are all supported.
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = false

# Set the line length limit used when formatting code snippets in
# docstrings.
#
# This only has an effect when the `docstring-code-format` setting is
# enabled.
docstring-code-line-length = "dynamic"
4 changes: 3 additions & 1 deletion scripts/extra_euler_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# See modules_forge/alter_samplers.py for the basis of this class and build_constructor function


class ExtraSampler(KDiffusionSampler):
"""
Overloads KDiffusionSampler to add extra parameters to the constructor
Expand All @@ -17,7 +18,7 @@ def __init__(self, sd_model, sampler_name, options=None):
self.sampler_name = sampler_name
self.unet = sd_model.forge_objects.unet
sampler_function = getattr(extra_euler_samplers, sampler_name)

super().__init__(sampler_function, sd_model, options)

self.extra_params = ["s_churn", "s_tmin", "s_tmax", "s_noise"]
Expand All @@ -38,6 +39,7 @@ def constructor(m):
("Euler SMEA", "sample_euler_smea", ["k_euler_smea"], {}),
("Euler SMEA Dy", "sample_euler_smea_dy", ["k_euler_smea_dy"], {}),
("Euler SMEA Dy Negative", "sample_euler_smea_dy_negative", ["k_euler_smea_dy_negative"], {}),
("Kohaky LoNyu Yog", "sample_kohaku_lonyu_yog", ["k_kohaku_lonyu_yog"], {}),
]

euler_samplers_data_k_diffusion: list[SamplerData] = [
Expand Down

0 comments on commit ccfe635

Please sign in to comment.