diff --git a/diffusion/models/models.py b/diffusion/models/models.py index febcb7c7..8dfedf49 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -48,9 +48,13 @@ def stable_diffusion_2( prediction_type: str = 'epsilon', latent_mean: Union[float, Tuple, str] = 0.0, latent_std: Union[float, Tuple, str] = 5.489980785067252, + beta_schedule: str = 'scaled_linear', + zero_terminal_snr: bool = False, offset_noise: Optional[float] = None, train_metrics: Optional[List] = None, val_metrics: Optional[List] = None, + quasirandomness: bool = False, + train_seed: int = 42, val_seed: int = 1138, precomputed_latents: bool = False, encode_latents_in_fp16: bool = True, @@ -78,10 +82,17 @@ def stable_diffusion_2( latent_std (float, list, str): The std. dev. of the autoencoder latents. Either a float for a single value, a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder checkpoint. Defaults to `1/0.18215`. + beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'. + Default: `scaled_linear`. + zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`. train_metrics (list, optional): List of metrics to compute during training. If None, defaults to [MeanSquaredError()]. val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to [MeanSquaredError()]. + quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise. + Default: `False`. + train_seed (int): Seed to use for generating diffusion process noise during training if using + quasirandomness. Default: `42`. val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not @@ -145,14 +156,26 @@ def stable_diffusion_2( assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) # Make the noise schedulers - noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler') - inference_noise_scheduler = DDIMScheduler(num_train_timesteps=noise_scheduler.config.num_train_timesteps, - beta_start=noise_scheduler.config.beta_start, - beta_end=noise_scheduler.config.beta_end, - beta_schedule=noise_scheduler.config.beta_schedule, - trained_betas=noise_scheduler.config.trained_betas, - clip_sample=noise_scheduler.config.clip_sample, - set_alpha_to_one=noise_scheduler.config.set_alpha_to_one, + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + variance_type='fixed_small', + clip_sample=False, + prediction_type=prediction_type, + sample_max_value=1.0, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + + inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + clip_sample=False, + set_alpha_to_one=False, prediction_type=prediction_type) # Make the composer model @@ -170,6 +193,8 @@ def stable_diffusion_2( offset_noise=offset_noise, train_metrics=train_metrics, val_metrics=val_metrics, + quasirandomness=quasirandomness, + train_seed=train_seed, val_seed=val_seed, precomputed_latents=precomputed_latents, encode_latents_in_fp16=encode_latents_in_fp16, @@ -207,9 +232,13 @@ def stable_diffusion_xl( prediction_type: str = 'epsilon', latent_mean: Union[float, Tuple, str] = 0.0, latent_std: Union[float, Tuple, str] = 7.67754318618, + beta_schedule: str = 'scaled_linear', + zero_terminal_snr: bool = False, offset_noise: Optional[float] = None, train_metrics: Optional[List] = None, val_metrics: Optional[List] = None, + quasirandomness: bool = False, + train_seed: int = 42, val_seed: int = 1138, precomputed_latents: bool = False, encode_latents_in_fp16: bool = True, @@ -247,12 +276,19 @@ def stable_diffusion_xl( latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value, a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder checkpoint. Defaults to `1/0.13025`. + beta_schedule (str): The beta schedule to use. Must be one of 'scaled_linear', 'linear', or 'squaredcos_cap_v2'. + Default: `scaled_linear`. + zero_terminal_snr (bool): Whether to enforce zero terminal SNR. Default: `False`. offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not be used. Default `None`. train_metrics (list, optional): List of metrics to compute during training. If None, defaults to [MeanSquaredError()]. val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to [MeanSquaredError()]. + quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise. + Default: `False`. + train_seed (int): Seed to use for generating diffusion process noise during training if using + quasirandomness. Default: `42`. val_seed (int): Seed to use for generating evaluation images. Defaults to 1138. precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. @@ -360,17 +396,40 @@ def stable_diffusion_xl( resnet._fsdp_wrap = True # Make the noise schedulers - noise_scheduler = DDPMScheduler.from_pretrained(unet_model_name, subfolder='scheduler') - inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.012, - beta_schedule='scaled_linear', - trained_betas=None, - prediction_type=prediction_type, - interpolation_type='linear', - use_karras_sigmas=False, - timestep_spacing='leading', - steps_offset=1) + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + variance_type='fixed_small', + clip_sample=False, + prediction_type=prediction_type, + sample_max_value=1.0, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + if beta_schedule == 'squaredcos_cap_v2': + inference_noise_scheduler = DDIMScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + rescale_betas_zero_snr=zero_terminal_snr) + else: + inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + prediction_type=prediction_type, + interpolation_type='linear', + use_karras_sigmas=False, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) # Make the composer model model = StableDiffusion( @@ -387,6 +446,8 @@ def stable_diffusion_xl( offset_noise=offset_noise, train_metrics=train_metrics, val_metrics=val_metrics, + quasirandomness=quasirandomness, + train_seed=train_seed, val_seed=val_seed, precomputed_latents=precomputed_latents, encode_latents_in_fp16=encode_latents_in_fp16, diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 518b4de4..90d40a67 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -9,6 +9,8 @@ import torch import torch.nn.functional as F from composer.models import ComposerModel +from composer.utils import dist +from scipy.stats import qmc from torchmetrics import MeanSquaredError from tqdm.auto import tqdm @@ -51,6 +53,10 @@ class StableDiffusion(ComposerModel): Default: `None`. val_metrics (list): List of torchmetrics to calculate during validation. Default: `None`. + quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise. + Default: `False`. + train_seed (int): Seed to use for generating diffusion process noise during training if using + quasirandomness. Default: `42`. val_seed (int): Seed to use for generating eval images. Default: `1138`. image_key (str): The name of the image inputs in the dataloader batch. Default: `image_tensor`. @@ -85,6 +91,8 @@ def __init__(self, offset_noise: Optional[float] = None, train_metrics: Optional[List] = None, val_metrics: Optional[List] = None, + quasirandomness: bool = False, + train_seed: int = 42, val_seed: int = 1138, image_key: str = 'image', text_key: str = 'captions', @@ -105,6 +113,8 @@ def __init__(self, raise ValueError(f'prediction type must be one of sample, epsilon, or v_prediction. Got {prediction_type}') self.downsample_factor = downsample_factor self.offset_noise = offset_noise + self.quasirandomness = quasirandomness + self.train_seed = train_seed self.val_seed = val_seed self.image_key = image_key self.image_latents_key = image_latents_key @@ -140,6 +150,8 @@ def __init__(self, # Optional rng generator self.rng_generator: Optional[torch.Generator] = None + if self.quasirandomness: + self.sobol = qmc.Sobol(d=1, scramble=True, seed=self.train_seed) def _apply(self, fn): super(StableDiffusion, self)._apply(fn) @@ -147,6 +159,23 @@ def _apply(self, fn): self.latent_std = fn(self.latent_std) return self + def _generate_timesteps(self, latents: torch.Tensor): + if self.quasirandomness: + # Generate a quasirandom sequence of timesteps equal to the global batch size + global_batch_size = latents.shape[0] * dist.get_world_size() + sampled_fractions = torch.tensor(self.sobol.random(global_batch_size), device=latents.device) + timesteps = (len(self.noise_scheduler) * sampled_fractions).squeeze() + timesteps = torch.floor(timesteps).long() + # Get this device's subset of all the timesteps + idx_offset = dist.get_global_rank() * latents.shape[0] + timesteps = timesteps[idx_offset:idx_offset + latents.shape[0]].to(latents.device) + else: + timesteps = torch.randint(0, + len(self.noise_scheduler), (latents.shape[0],), + device=latents.device, + generator=self.rng_generator) + return timesteps + def set_rng_generator(self, rng_generator: torch.Generator): """Sets the rng generator for the model.""" self.rng_generator = rng_generator @@ -193,10 +222,7 @@ def forward(self, batch): text_pooled_embeds *= batch['drop_caption_mask'].view(-1, 1) # Sample the diffusion timesteps - timesteps = torch.randint(0, - len(self.noise_scheduler), (latents.shape[0],), - device=latents.device, - generator=self.rng_generator) + timesteps = self._generate_timesteps(latents) # Add noise to the inputs (forward diffusion) noise = torch.randn(*latents.shape, device=latents.device, generator=self.rng_generator) if self.offset_noise is not None: diff --git a/tests/test_model.py b/tests/test_model.py index 84dfd80e..bb697a20 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -150,3 +150,14 @@ def test_sdxl_generate(guidance_scale, negative_prompt, mask_pad_tokens): progress_bar=False, ) assert output.shape == (1, 3, 16, 16) + + +def test_quasirandomness(): + # fp16 vae does not run on cpu + model = stable_diffusion_2(pretrained=False, fsdp=False, encode_latents_in_fp16=False, quasirandomness=True) + # Generate many quasi-random samples + fake_latents = torch.randn(2048, 4, 8, 8) + for i in range(10**3): + timesteps = model._generate_timesteps(fake_latents) + assert (timesteps >= 0).all() + assert (timesteps < 1000).all()