Skip to content

Commit

Permalink
add NotImplementedError, fix docs
Browse files (browse the repository at this point in the history)
  • Loading branch information
jazcollins committed Aug 21, 2023
1 parent b07abcf commit 9739743
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions diffusion/models/models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -44,8 +44,8 @@ def stable_diffusion_2(
prompts.
Args:
model_name (str, optional): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'.
pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
model_name (str): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'.
pretrained (bool): Whether to load pretrained weights. Defaults to True.
prediction_type (str): The type of prediction to use. Must be one of 'sample',
'epsilon', or 'v_prediction'. Defaults to 'epsilon'.
train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
Expand All @@ -54,12 +54,12 @@ def stable_diffusion_2(
[MeanSquaredError(), FrechetInceptionDistance(normalize=True)].
val_guidance_scales (list, optional): List of scales to use for validation guidance. If None, defaults to
[1.0, 3.0, 7.0].
val_seed (int, optional): Seed to use for generating evaluation images. Defaults to 1138.
val_seed (int): Seed to use for generating evaluation images. Defaults to 1138.
loss_bins (list, optional): List of tuples of (min, max) values to use for loss binning. If None, defaults to
[(0, 1)].
precomputed_latents (bool, optional): Whether to use precomputed latents. Defaults to False.
encode_latents_in_fp16 (bool, optional): Whether to encode latents in fp16. Defaults to True.
fsdp (bool, optional): Whether to use FSDP. Defaults to True.
precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
fsdp (bool): Whether to use FSDP. Defaults to True.
"""
if train_metrics is None:
train_metrics = [MeanSquaredError()]
Expand Down Expand Up @@ -144,13 +144,14 @@ def stable_diffusion_xl(
prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2.
Args:
model_name (str, optional): Name of the model to load. Determines the text encoder, tokenizer,
model_name (str): Name of the model to load. Determines the text encoder, tokenizer,
and noise scheduler. Defaults to 'stabilityai/stable-diffusion-2-base'.
unet_model_name (str, optional): Name of the UNet model to load. Defaults to
unet_model_name (str): Name of the UNet model to load. Defaults to
'stabilityai/stable-diffusion-xl-base-1.0'.
vae_model_name (str, optional): Name of the VAE model to load. Defaults to
'madebyollin/sdxl-vae-fp16-fix'.
pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
vae_model_name (str): Name of the VAE model to load. Defaults to
'madebyollin/sdxl-vae-fp16-fix', because the official VAE checkpoint (from
'stabilityai/stable-diffusion-xl-base-1.0') is not compatible with fp16.
pretrained (bool): Whether to load pretrained weights. Defaults to True.
prediction_type (str): The type of prediction to use. Must be one of 'sample',
'epsilon', or 'v_prediction'. Defaults to 'epsilon'.
train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
Expand All @@ -159,12 +160,12 @@ def stable_diffusion_xl(
[MeanSquaredError(), FrechetInceptionDistance(normalize=True)].
val_guidance_scales (list, optional): List of scales to use for validation guidance. If None, defaults to
[1.0, 3.0, 7.0].
val_seed (int, optional): Seed to use for generating evaluation images. Defaults to 1138.
val_seed (int): Seed to use for generating evaluation images. Defaults to 1138.
loss_bins (list, optional): List of tuples of (min, max) values to use for loss binning. If None, defaults to
[(0, 1)].
precomputed_latents (bool, optional): Whether to use precomputed latents. Defaults to False.
encode_latents_in_fp16 (bool, optional): Whether to encode latents in fp16. Defaults to True.
fsdp (bool, optional): Whether to use FSDP. Defaults to True.
precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
fsdp (bool): Whether to use FSDP. Defaults to True.
"""
if train_metrics is None:
train_metrics = [MeanSquaredError()]
Expand All @@ -180,7 +181,7 @@ def stable_diffusion_xl(
metric.requires_grad_(False)

if pretrained:
unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet')
raise NotImplementedError('Full SDXL pipeline not implemented yet.')
else:
config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')
# Currently not doing micro-conditioning, so set config appropriately
Expand Down

0 comments on commit 9739743

Please sign in to comment.