Skip to content

Commit

Permalink
Add ModelSamplingStableCascade to control the shift sampling parameter.
Browse files Browse the repository at this point in the history
shift is 2.0 by default on Stage C and 1.0 by default on Stage B.
  • Loading branch information
comfyanonymous committed Feb 18, 2024
1 parent 6bcf57f commit 8b60d33
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
12 changes: 8 additions & 4 deletions comfy/model_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,16 @@ def __init__(self, model_config=None):
else:
sampling_settings = {}

self.num_timesteps = 1000
self.shift = sampling_settings.get("shift", 1.0)
cosine_s=8e-3
self.set_parameters(sampling_settings.get("shift", 1.0))

def set_parameters(self, shift=1.0, cosine_s=8e-3):
self.shift = shift
self.cosine_s = torch.tensor(cosine_s)
sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2

#This part is just for compatibility with some schedulers in the codebase
self.num_timesteps = 1000
sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
for x in range(self.num_timesteps):
t = x / self.num_timesteps
sigmas[x] = self.sigma(t)
Expand Down
27 changes: 27 additions & 0 deletions comfy_extras/nodes_model_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,32 @@ class ModelSamplingAdvanced(sampling_base, sampling_type):
m.add_object_patch("model_sampling", model_sampling)
return (m, )

class ModelSamplingStableCascade:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}),
}}

RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "advanced/model"

def patch(self, model, shift):
m = model.clone()

sampling_base = comfy.model_sampling.StableCascadeSampling
sampling_type = comfy.model_sampling.EPS

class ModelSamplingAdvanced(sampling_base, sampling_type):
pass

model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift)
m.add_object_patch("model_sampling", model_sampling)
return (m, )

class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
Expand Down Expand Up @@ -171,5 +197,6 @@ def rescale_cfg(args):
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
"ModelSamplingStableCascade": ModelSamplingStableCascade,
"RescaleCFG": RescaleCFG,
}

0 comments on commit 8b60d33

Please sign in to comment.