Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional quasirandom timesteps, zero terminal SNR, cosine schedule for SD models #138

Merged
merged 5 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 86 additions & 19 deletions diffusion/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,13 @@ def stable_diffusion_2(
prediction_type: str = 'epsilon',
latent_mean: Union[float, Tuple, str] = 0.0,
latent_std: Union[float, Tuple, str] = 5.489980785067252,
beta_schedule: str = 'scaled_linear',
zero_terminal_snr: bool = False,
offset_noise: Optional[float] = None,
train_metrics: Optional[List] = None,
val_metrics: Optional[List] = None,
quasirandomness: bool = False,
train_seed: int = 42,
val_seed: int = 1138,
precomputed_latents: bool = False,
encode_latents_in_fp16: bool = True,
Expand Down Expand Up @@ -78,10 +82,17 @@ def stable_diffusion_2(
latent_std (float, list, str): The std. dev. of the autoencoder latents. Either a float for a single value,
a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder
checkpoint. Defaults to `1/0.18215`.
beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'.
Default: `scaled_linear`.
zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`.
train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
[MeanSquaredError()].
val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to
[MeanSquaredError()].
quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise.
Default: `False`.
train_seed (int): Seed to use for generating diffusion process noise during training if using
quasirandomness. Default: `42`.
val_seed (int): Seed to use for generating evaluation images. Defaults to 1138.
precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not
Expand Down Expand Up @@ -145,14 +156,29 @@ def stable_diffusion_2(
assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple)

# Make the noise schedulers
noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler')
inference_noise_scheduler = DDIMScheduler(num_train_timesteps=noise_scheduler.config.num_train_timesteps,
beta_start=noise_scheduler.config.beta_start,
beta_end=noise_scheduler.config.beta_end,
beta_schedule=noise_scheduler.config.beta_schedule,
trained_betas=noise_scheduler.config.trained_betas,
clip_sample=noise_scheduler.config.clip_sample,
set_alpha_to_one=noise_scheduler.config.set_alpha_to_one,
noise_scheduler = DDPMScheduler(num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule=beta_schedule,
trained_betas=None,
variance_type='fixed_small',
clip_sample=False,
prediction_type=prediction_type,
thresholding=False,
dynamic_thresholding_ratio=0.995,
clip_sample_range=1.0,
Landanjs marked this conversation as resolved.
Show resolved Hide resolved
sample_max_value=1.0,
timestep_spacing='leading',
steps_offset=1,
rescale_betas_zero_snr=zero_terminal_snr)

inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule=beta_schedule,
trained_betas=None,
clip_sample=False,
set_alpha_to_one=False,
prediction_type=prediction_type)

# Make the composer model
Expand All @@ -170,6 +196,8 @@ def stable_diffusion_2(
offset_noise=offset_noise,
train_metrics=train_metrics,
val_metrics=val_metrics,
quasirandomness=quasirandomness,
train_seed=train_seed,
val_seed=val_seed,
precomputed_latents=precomputed_latents,
encode_latents_in_fp16=encode_latents_in_fp16,
Expand Down Expand Up @@ -207,9 +235,13 @@ def stable_diffusion_xl(
prediction_type: str = 'epsilon',
latent_mean: Union[float, Tuple, str] = 0.0,
latent_std: Union[float, Tuple, str] = 7.67754318618,
beta_schedule: str = 'scaled_linear',
zero_terminal_snr: bool = False,
offset_noise: Optional[float] = None,
train_metrics: Optional[List] = None,
val_metrics: Optional[List] = None,
quasirandomness: bool = False,
train_seed: int = 42,
val_seed: int = 1138,
precomputed_latents: bool = False,
encode_latents_in_fp16: bool = True,
Expand Down Expand Up @@ -247,12 +279,19 @@ def stable_diffusion_xl(
latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value,
a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder
checkpoint. Defaults to `1/0.13025`.
beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'.
Default: `scaled_linear`.
zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`.
offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not
be used. Default `None`.
train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
[MeanSquaredError()].
val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to
[MeanSquaredError()].
quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise.
Default: `False`.
train_seed (int): Seed to use for generating diffusion process noise during training if using
quasirandomness. Default: `42`.
val_seed (int): Seed to use for generating evaluation images. Defaults to 1138.
precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
Expand Down Expand Up @@ -360,17 +399,43 @@ def stable_diffusion_xl(
resnet._fsdp_wrap = True

# Make the noise schedulers
noise_scheduler = DDPMScheduler.from_pretrained(unet_model_name, subfolder='scheduler')
inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule='scaled_linear',
trained_betas=None,
prediction_type=prediction_type,
interpolation_type='linear',
use_karras_sigmas=False,
timestep_spacing='leading',
steps_offset=1)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule=beta_schedule,
trained_betas=None,
variance_type='fixed_small',
clip_sample=False,
prediction_type=prediction_type,
thresholding=False,
dynamic_thresholding_ratio=0.995,
clip_sample_range=1.0,
sample_max_value=1.0,
timestep_spacing='leading',
steps_offset=1,
rescale_betas_zero_snr=zero_terminal_snr)
if beta_schedule == 'squaredcos_cap_v2':
inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule=beta_schedule,
trained_betas=None,
clip_sample=False,
set_alpha_to_one=False,
prediction_type=prediction_type,
rescale_betas_zero_snr=zero_terminal_snr)
else:
inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule=beta_schedule,
trained_betas=None,
prediction_type=prediction_type,
interpolation_type='linear',
use_karras_sigmas=False,
timestep_spacing='leading',
steps_offset=1,
rescale_betas_zero_snr=zero_terminal_snr)

# Make the composer model
model = StableDiffusion(
Expand All @@ -387,6 +452,8 @@ def stable_diffusion_xl(
offset_noise=offset_noise,
train_metrics=train_metrics,
val_metrics=val_metrics,
quasirandomness=quasirandomness,
train_seed=train_seed,
val_seed=val_seed,
precomputed_latents=precomputed_latents,
encode_latents_in_fp16=encode_latents_in_fp16,
Expand Down
32 changes: 28 additions & 4 deletions diffusion/models/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import torch.nn.functional as F
from composer.models import ComposerModel
from composer.utils import dist
from torchmetrics import MeanSquaredError
from tqdm.auto import tqdm

Expand Down Expand Up @@ -51,6 +52,10 @@ class StableDiffusion(ComposerModel):
Default: `None`.
val_metrics (list): List of torchmetrics to calculate during validation.
Default: `None`.
quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise.
Default: `False`.
train_seed (int): Seed to use for generating diffusion process noise during training if using
quasirandomness. Default: `42`.
val_seed (int): Seed to use for generating eval images. Default: `1138`.
image_key (str): The name of the image inputs in the dataloader batch.
Default: `image_tensor`.
Expand Down Expand Up @@ -85,6 +90,8 @@ def __init__(self,
offset_noise: Optional[float] = None,
train_metrics: Optional[List] = None,
val_metrics: Optional[List] = None,
quasirandomness: bool = False,
train_seed: int = 42,
val_seed: int = 1138,
image_key: str = 'image',
text_key: str = 'captions',
Expand All @@ -105,6 +112,8 @@ def __init__(self,
raise ValueError(f'prediction type must be one of sample, epsilon, or v_prediction. Got {prediction_type}')
self.downsample_factor = downsample_factor
self.offset_noise = offset_noise
self.quasirandomness = quasirandomness
self.train_seed = train_seed
self.val_seed = val_seed
self.image_key = image_key
self.image_latents_key = image_latents_key
Expand Down Expand Up @@ -140,13 +149,25 @@ def __init__(self,

# Optional rng generator
self.rng_generator: Optional[torch.Generator] = None
if self.quasirandomness:
self.sobol_engine = torch.quasirandom.SobolEngine(dimension=1, scramble=True, seed=self.train_seed)

def _apply(self, fn):
super(StableDiffusion, self)._apply(fn)
self.latent_mean = fn(self.latent_mean)
self.latent_std = fn(self.latent_std)
return self

def _generate_quasirandom_timesteps(self, latents: torch.Tensor):
# Generate a quasirandom sequence of timesteps equal to the global batch size
global_batch_size = latents.shape[0] * dist.get_world_size()
timesteps = (len(self.noise_scheduler) * self.sobol_engine.draw(global_batch_size)).squeeze()
Landanjs marked this conversation as resolved.
Show resolved Hide resolved
timesteps = torch.floor(timesteps).long().clamp(0, len(self.noise_scheduler) - 1)
Landanjs marked this conversation as resolved.
Show resolved Hide resolved
# Get this device's subset of all the timesteps
idx_offset = dist.get_global_rank() * latents.shape[0]
timesteps = timesteps[idx_offset:idx_offset + latents.shape[0]]
return timesteps.to(latents.device)

def set_rng_generator(self, rng_generator: torch.Generator):
"""Sets the rng generator for the model."""
self.rng_generator = rng_generator
Expand Down Expand Up @@ -193,10 +214,13 @@ def forward(self, batch):
text_pooled_embeds *= batch['drop_caption_mask'].view(-1, 1)

# Sample the diffusion timesteps
timesteps = torch.randint(0,
len(self.noise_scheduler), (latents.shape[0],),
device=latents.device,
generator=self.rng_generator)
if self.quasirandomness:
timesteps = self._generate_quasirandom_timesteps(latents)
else:
timesteps = torch.randint(0,
len(self.noise_scheduler), (latents.shape[0],),
device=latents.device,
generator=self.rng_generator)
Landanjs marked this conversation as resolved.
Show resolved Hide resolved
# Add noise to the inputs (forward diffusion)
noise = torch.randn(*latents.shape, device=latents.device, generator=self.rng_generator)
if self.offset_noise is not None:
Expand Down
Loading