diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml
index f67c362f..261edd5e 100644
--- a/.github/workflows/code-quality.yaml
+++ b/.github/workflows/code-quality.yaml
@@ -24,7 +24,6 @@ jobs:
     strategy:
       matrix:
         python_version:
-          - "3.8"
           - "3.9"
           - "3.10"
         pip_deps:
diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml
index 8b43017f..b43ce0ae 100644
--- a/.github/workflows/pr-cpu.yaml
+++ b/.github/workflows/pr-cpu.yaml
@@ -19,10 +19,6 @@ jobs:
     strategy:
       matrix:
         include:
-          - name: 'cpu-3.8-1.11'
-            container: mosaicml/pytorch:1.11.0_cpu-python3.8-ubuntu20.04
-            markers: 'not gpu'
-            pytest_command: 'coverage run -m pytest'
           - name: 'cpu-3.9-1.12'
             container: mosaicml/pytorch:1.12.1_cpu-python3.9-ubuntu20.04
             markers: 'not gpu'
diff --git a/README.md b/README.md
index 0f400650..29f17978 100644
--- a/README.md
+++ b/README.md
@@ -39,27 +39,14 @@ Results from our Mosaic Diffusion model after training for 550k iterations at 25
 Here are the system settings we recommend to start training your own diffusion models:
 
 - Use a Docker image with PyTorch 1.13+, e.g. [MosaicML's PyTorch base image](https://hub.docker.com/r/mosaicml/pytorch/tags)
-  - Recommended tag: `mosaicml/pytorch_vision:1.13.1_cu117-python3.10-ubuntu20.04`
+  - Recommended tag: `mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04`
   - This image comes pre-configured with the following dependencies:
-    - PyTorch Version: 1.13.1
-    - CUDA Version: 11.7
+    - PyTorch Version: 2.1.2
+    - CUDA Version: 12.1
     - Python Version: 3.10
     - Ubuntu Version: 20.04
 - Use a system with NVIDIA GPUs
-- For running on NVIDIA H100s, use a docker image with PyTorch 2.0+ e.g. [MosaicML's PyTorch base image](https://hub.docker.com/r/mosaicml/pytorch/tags)
-  - Recommended tag: `mosaicml/pytorch_vision:2.0.1_cu118-python3.10-ubuntu20.04`
-  - This image comes pre-configured with the following dependencies:
-    - PyTorch Version: 2.0.1
-    - CUDA Version: 11.8
-    - Python Version: 3.10
-    - Ubuntu Version: 20.04
-  - Depending on the training config, an additional install of `xformers` may be needed:
-    ```
-    pip install -U ninja
-    pip install -U git+https://github.com/facebookresearch/xformers
-    ```
-
 # How many GPUs do I need?
 
 We benchmarked the U-Net training throughput as we scale the number of A100 GPUs from 8 to 128. Our time estimates are based on training Stable Diffusion 2.0 base on 1,126,400,000 images at 256x256 resolution and 1,740,800,000 images at 512x512 resolution. Our cost estimates are based on $2 / A100-hour. Since the time and cost estimates are for the U-Net only, these only hold if the VAE and CLIP latents are computed before training. It took 3,784 A100-hours (cost of $7,600) to pre-compute the VAE and CLIP latents offline. If you are computing VAE and CLIP latents while training, expect a 1.4x increase in time and cost.
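The cost figures in the README hunk above reduce to simple arithmetic: A100-hours are images divided by per-GPU throughput (in images per second) divided by 3600, dollars are A100-hours times the $2 / A100-hour rate, and a roughly 1.4x multiplier applies when VAE and CLIP latents are computed during training instead of ahead of time. Below is a minimal sketch of that calculation; `unet_training_cost` and `images_per_sec_per_gpu` are illustrative names, and the throughput value must come from your own benchmark run, not from this snippet.

```python
# Sketch of the cost model described in the README. The throughput argument is a
# placeholder: substitute the per-GPU throughput you measure on your own hardware.
def unet_training_cost(num_images: int,
                       images_per_sec_per_gpu: float,
                       dollars_per_a100_hour: float = 2.0,
                       latents_precomputed: bool = True) -> float:
    """Estimate the U-Net training cost in dollars for one resolution stage."""
    a100_hours = num_images / images_per_sec_per_gpu / 3600
    cost = a100_hours * dollars_per_a100_hour
    if not latents_precomputed:
        # README estimate: ~1.4x more time and cost when latents are computed while training
        cost *= 1.4
    return cost


# Example for the 256x256 stage (1,126,400,000 images), with a user-supplied throughput:
# cost_256 = unet_training_cost(1_126_400_000, images_per_sec_per_gpu=measured_throughput)
```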
diff --git a/diffusion/models/autoencoder.py b/diffusion/models/autoencoder.py
index 7da24827..21b2c55e 100644
--- a/diffusion/models/autoencoder.py
+++ b/diffusion/models/autoencoder.py
@@ -6,7 +6,7 @@
 Based on the implementation from https://github.com/CompVis/stable-diffusion
 """
 
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple
 
 import lpips
 import torch
@@ -16,6 +16,8 @@
 from composer.utils import dist
 from composer.utils.file_helpers import get_file
 from diffusers import AutoencoderKL
+from diffusers.models.autoencoders.vae import DecoderOutput
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
 from torchmetrics import MeanMetric, MeanSquaredError, Metric
 from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
 from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
@@ -662,7 +664,7 @@ def update_metric(self, batch, outputs, metric):
         metric.update(outputs['x_recon'], batch[self.input_key])
 
 
-class ComposerDiffusersAutoEncoder(ComposerAutoEncoder):
+class ComposerDiffusersAutoEncoder(ComposerModel):
     """Composer wrapper for the Huggingface Diffusers Autoencoder.
 
     Args:
@@ -672,24 +674,83 @@ class ComposerDiffusersAutoEncoder(ComposerAutoEncoder):
     """
 
     def __init__(self, model: AutoencoderKL, autoencoder_loss: AutoEncoderLoss, input_key: str = 'image'):
-        super().__init__(model, autoencoder_loss, input_key)
+        super().__init__()
         self.model = model
         self.autoencoder_loss = autoencoder_loss
         self.input_key = input_key
+        # Set up train metrics
+        train_metrics = [MeanSquaredError()]
+        self.train_metrics = {metric.__class__.__name__: metric for metric in train_metrics}
+        # Set up val metrics
+        psnr_metric = PeakSignalNoiseRatio(data_range=2.0)
+        ssim_metric = StructuralSimilarityIndexMeasure(data_range=2.0)
+        lpips_metric = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
+        val_metrics = [MeanSquaredError(), MeanMetric(), lpips_metric, psnr_metric, ssim_metric]
+        self.val_metrics = {metric.__class__.__name__: metric for metric in val_metrics}
+
     def get_last_layer_weight(self) -> torch.Tensor:
         """Get the weight of the last layer of the decoder."""
         return self.model.decoder.conv_out.weight
 
     def forward(self, batch):
-        latent_dist = self.model.encode(batch[self.input_key])['latent_dist']
+        encoder_output = self.model.encode(batch[self.input_key], return_dict=True)
+        assert isinstance(encoder_output, AutoencoderKLOutput)
+        latent_dist = encoder_output['latent_dist']
         latents = latent_dist.sample()
         mean, log_var = latent_dist.mean, latent_dist.logvar
-        recon = self.model.decode(latents).sample
+        output_dist = self.model.decode(latents, return_dict=True)
+        assert isinstance(output_dist, DecoderOutput)
+        recon = output_dist.sample
         return {'x_recon': recon, 'latents': latents, 'mean': mean, 'log_var': log_var}
 
+    def loss(self, outputs, batch):
+        last_layer = self.get_last_layer_weight()
+        return self.autoencoder_loss(outputs, batch, last_layer)
+
+    def eval_forward(self, batch, outputs=None):
+        if outputs is not None:
+            return outputs
+        outputs = self.forward(batch)
+        return outputs
+
+    def get_metrics(self, is_train: bool = False):
+        if is_train:
+            metrics = self.train_metrics
+        else:
+            metrics = self.val_metrics
+
+        if isinstance(metrics, Metric):
+            metrics_dict = {metrics.__class__.__name__: metrics}
+        elif isinstance(metrics, list):
+            metrics_dict = {metric.__class__.__name__: metric for metric in metrics}
+        else:
+            metrics_dict = {}
+            for name, metric in metrics.items():
+                assert isinstance(metric, Metric)
+                metrics_dict[name] = metric
+
+        return metrics_dict
+
+    def update_metric(self, batch, outputs, metric):
+        clamped_imgs = outputs['x_recon'].clamp(-1, 1)
+        if isinstance(metric, MeanMetric):
+            metric.update(torch.square(outputs['latents']))
+        elif isinstance(metric, LearnedPerceptualImagePatchSimilarity):
+            metric.update(clamped_imgs, batch[self.input_key])
+        elif isinstance(metric, PeakSignalNoiseRatio):
+            metric.update(clamped_imgs, batch[self.input_key])
+        elif isinstance(metric, StructuralSimilarityIndexMeasure):
+            metric.update(clamped_imgs, batch[self.input_key])
+        elif isinstance(metric, MeanSquaredError):
+            metric.update(outputs['x_recon'], batch[self.input_key])
+        else:
+            metric.update(outputs['x_recon'], batch[self.input_key])
+
 
-def load_autoencoder(load_path: str, local_path: str = '/tmp/autoencoder_weights.pt', torch_dtype=None):
+def load_autoencoder(load_path: str,
+                     local_path: str = '/tmp/autoencoder_weights.pt',
+                     torch_dtype=None) -> Tuple[AutoEncoder, Optional[Dict]]:
     """Function to load an AutoEncoder from a composer checkpoint without the loss weights.
 
     Will also load the latent statistics if the statistics tracking callback was used.
diff --git a/diffusion/models/layers.py b/diffusion/models/layers.py
index eb4b980e..cd7ee04e 100644
--- a/diffusion/models/layers.py
+++ b/diffusion/models/layers.py
@@ -3,7 +3,7 @@
 
 """Helpful layers and functions for UNet and Autoencoder construction."""
 
-from typing import Optional
+from typing import Optional, TypeVar
 
 import torch
 import torch.nn as nn
@@ -14,8 +14,10 @@
 except:
     pass
 
+_T = TypeVar('_T', bound=nn.Module)
 
-def zero_module(module: torch.nn.Module) -> torch.nn.Module:
+
+def zero_module(module: _T) -> _T:
     """Zero out the parameters of a module and return it."""
     for p in module.parameters():
         p.detach().zero_()
diff --git a/diffusion/models/models.py b/diffusion/models/models.py
index fed0db63..1a4d1b8e 100644
--- a/diffusion/models/models.py
+++ b/diffusion/models/models.py
@@ -115,11 +115,11 @@ def stable_diffusion_2(
     # Make the unet
     if pretrained:
         unet = UNet2DConditionModel.from_pretrained(model_name, subfolder='unet')
-        if autoencoder_path is not None and vae.config['latent_channels'] != 4:
+        if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4:
             raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.')
     else:
         unet_config = PretrainedConfig.get_config_dict(model_name, subfolder='unet')[0]
-        if autoencoder_path is not None:
+        if isinstance(vae, AutoEncoder):
             # Adapt the unet config to account for differing number of latent channels if necessary
             unet_config['in_channels'] = vae.config['latent_channels']
             unet_config['out_channels'] = vae.config['latent_channels']
@@ -271,11 +271,11 @@ def stable_diffusion_xl(
     # Make the unet
     if pretrained:
         unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet')
-        if autoencoder_path is not None and vae.config['latent_channels'] != 4:
+        if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4:
             raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.')
     else:
         unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
-        if autoencoder_path is not None:
+        if isinstance(vae, AutoEncoder):
             # Adapt the unet config to account for differing number of latent channels if necessary
             unet_config['in_channels'] = vae.config['latent_channels']
             unet_config['out_channels'] = vae.config['latent_channels']
@@ -462,6 +462,7 @@ def build_diffusers_autoencoder(model_name: str = 'stabilityai/stable-diffusion-
     else:
         config = PretrainedConfig.get_config_dict(model_name)
         model = AutoencoderKL(**config[0])
+    assert isinstance(model, AutoencoderKL)
 
     # Configure the loss function
     autoencoder_loss = AutoEncoderLoss(input_key=input_key,
@@ -488,12 +489,13 @@ def discrete_pixel_diffusion(clip_model_name: str = 'openai/clip-vit-large-patch
             Defaults to 'epsilon'.
     """
     # Create a pixel space unet
-    unet = UNet2DConditionModel(in_channels=3,
-                                out_channels=3,
-                                attention_head_dim=[5, 10, 20, 20],
-                                cross_attention_dim=768,
-                                flip_sin_to_cos=True,
-                                use_linear_projection=True)
+    unet = UNet2DConditionModel(
+        in_channels=3,
+        out_channels=3,
+        attention_head_dim=[5, 10, 20, 20],  # type: ignore
+        cross_attention_dim=768,
+        flip_sin_to_cos=True,
+        use_linear_projection=True)
     # Get the CLIP text encoder and tokenizer:
     text_encoder = CLIPTextModel.from_pretrained(clip_model_name)
     tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
@@ -562,12 +564,13 @@ def continuous_pixel_diffusion(clip_model_name: str = 'openai/clip-vit-large-pat
             Defaults to 1.56 (pi/2 - 0.01 for stability).
     """
     # Create a pixel space unet
-    unet = UNet2DConditionModel(in_channels=3,
-                                out_channels=3,
-                                attention_head_dim=[5, 10, 20, 20],
-                                cross_attention_dim=768,
-                                flip_sin_to_cos=True,
-                                use_linear_projection=True)
+    unet = UNet2DConditionModel(
+        in_channels=3,
+        out_channels=3,
+        attention_head_dim=[5, 10, 20, 20],  # type: ignore
+        cross_attention_dim=768,
+        flip_sin_to_cos=True,
+        use_linear_projection=True)
     # Get the CLIP text encoder and tokenizer:
     text_encoder = CLIPTextModel.from_pretrained(clip_model_name)
     tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
diff --git a/scripts/precompute_latents.py b/scripts/precompute_latents.py
index b2cb4d4d..9ba8a3d6 100644
--- a/scripts/precompute_latents.py
+++ b/scripts/precompute_latents.py
@@ -13,6 +13,7 @@
 from composer.devices import DeviceGPU
 from composer.utils import dist
 from diffusers import AutoencoderKL
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
 from PIL import Image
 from streaming import MDSWriter, Stream, StreamingDataset
 from torch.utils.data import DataLoader
@@ -251,6 +252,7 @@ def main(args: Namespace) -> None:
     device = DeviceGPU()
     vae = AutoencoderKL.from_pretrained(args.model_name, subfolder='vae', torch_dtype=torch.float16)
+    assert isinstance(vae, AutoencoderKL)
     text_encoder = CLIPTextModel.from_pretrained(args.model_name, subfolder='text_encoder', torch_dtype=torch.float16)
     vae = device.module_to_device(vae)
     text_encoder = device.module_to_device(text_encoder)
@@ -294,8 +296,12 @@ def main(args: Namespace) -> None:
         with torch.no_grad():
             # Encode the images to the latent space with magical scaling number (See https://github.com/huggingface/diffusers/issues/437#issuecomment-1241827515)
-            latents_256 = vae.encode(image_256.half())['latent_dist'].sample().data * 0.18215
-            latents_512 = vae.encode(image_512.half())['latent_dist'].sample().data * 0.18215
+            latent_dist_256 = vae.encode(image_256.half())
+            assert isinstance(latent_dist_256, AutoencoderKLOutput)
+            latents_256 = latent_dist_256['latent_dist'].sample().data * 0.18215
+            latent_dist_512 = vae.encode(image_512.half())
+            assert isinstance(latent_dist_512, AutoencoderKLOutput)
+            latents_512 = latent_dist_512['latent_dist'].sample().data * 0.18215
             # Encode the text. Assume that the text is already tokenized
             conditioning = text_encoder(captions.view(-1, captions.shape[-1]))[0]  # Should be (batch_size, 77, 768)
diff --git a/setup.py b/setup.py
index 19bf10a3..1a33b1b4 100644
--- a/setup.py
+++ b/setup.py
@@ -6,20 +6,20 @@
 from setuptools import find_packages, setup
 
 install_requires = [
-    'mosaicml==0.16.3',
-    'mosaicml-streaming>=0.7.1,<1.0',
+    'mosaicml==0.20.1',
+    'mosaicml-streaming==0.7.4',
     'hydra-core>=1.2',
     'hydra-colorlog>=1.1.0',
-    'diffusers[torch]==0.21.0',
-    'transformers[torch]==4.31.0',
-    'wandb==0.15.4',
-    'xformers==0.0.21',
-    'triton==2.0.0',
-    'torchmetrics[image]==0.11.4',
+    'diffusers[torch]==0.26.3',
+    'transformers[torch]==4.38.2',
+    'wandb==0.16.3',
+    'xformers==0.0.23.post1',
+    'triton==2.1.0',
+    'torchmetrics[image]==1.3.1',
     'lpips==0.1.4',
-    'clean-fid',
-    'clip@git+https://github.com/openai/CLIP.git',
-    'gradio==4.14.0',
+    'clean-fid==0.1.35',
+    'clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33',
+    'gradio==4.19.2',
 ]
 
 extras_require = {}
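Since setup.py now pins exact versions, an environment installed before this change is an easy source of import or API mismatches. Below is a minimal sketch for checking a few of the new pins against the active environment using only the standard library; the dictionary is just a subset of install_requires copied from the hunk above.

```python
# Sketch: compare a few of the pinned requirements against the installed versions.
from importlib.metadata import PackageNotFoundError, version

pins = {
    'mosaicml': '0.20.1',
    'diffusers': '0.26.3',
    'transformers': '4.38.2',
    'torchmetrics': '1.3.1',
}

for package, expected in pins.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f'{package}: not installed (expected {expected})')
        continue
    status = 'ok' if installed == expected else f'expected {expected}'
    print(f'{package}: {installed} ({status})')
```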