From 131f30e27882a43f1009fcaf6a3f787fe0c6fb8d Mon Sep 17 00:00:00 2001 From: coryMosaicML <83666378+coryMosaicML@users.noreply.github.com> Date: Thu, 20 Jul 2023 16:09:56 -0700 Subject: [PATCH] Add evaluation using clean-fid (#51) --- diffusion/evaluate.py | 93 ++++++++++ diffusion/evaluation/__init__.py | 4 + diffusion/evaluation/clean_fid_eval.py | 246 +++++++++++++++++++++++++ run_eval.py | 26 +++ setup.py | 2 + yamls/hydra-yamls/eval-clean-fid.yaml | 62 +++++++ yamls/mosaic-yamls/eval-clean-fid.yaml | 86 +++++++++ 7 files changed, 519 insertions(+) create mode 100644 diffusion/evaluate.py create mode 100644 diffusion/evaluation/__init__.py create mode 100644 diffusion/evaluation/clean_fid_eval.py create mode 100644 run_eval.py create mode 100644 yamls/hydra-yamls/eval-clean-fid.yaml create mode 100644 yamls/mosaic-yamls/eval-clean-fid.yaml diff --git a/diffusion/evaluate.py b/diffusion/evaluate.py new file mode 100644 index 00000000..0bac28b9 --- /dev/null +++ b/diffusion/evaluate.py @@ -0,0 +1,93 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Evaluate model.""" + +import operator +from typing import List + +import hydra +from composer import Algorithm, ComposerModel +from composer.algorithms.low_precision_groupnorm import apply_low_precision_groupnorm +from composer.algorithms.low_precision_layernorm import apply_low_precision_layernorm +from composer.core import Precision +from composer.loggers import LoggerDestination +from composer.utils import reproducibility +from omegaconf import DictConfig, OmegaConf +from torch.utils.data import DataLoader +from torchmetrics.multimodal import CLIPScore + +from diffusion.evaluation.clean_fid_eval import CleanFIDEvaluator + + +def evaluate(config: DictConfig) -> None: + """Evaluate a model. + + Args: + config (DictConfig): Configuration composed by Hydra + """ + reproducibility.seed_all(config.seed) + + # The model to evaluate + model: ComposerModel = hydra.utils.instantiate(config.model) + + # The dataloader to use for evaluation + eval_dataloader: DataLoader = hydra.utils.instantiate(config.eval_dataloader) + + # The CLIPScores metric to use for evaluation + clip_metric: CLIPScore = hydra.utils.instantiate(config.clip_metric) + + # Build list of loggers and algorithms. 
+ logger: List[LoggerDestination] = [] + algorithms: List[Algorithm] = [] + + # Set up logging for results + if 'logger' in config: + for log, lg_conf in config.logger.items(): + if '_target_' in lg_conf: + print(f'Instantiating logger <{lg_conf._target_}>') + if log == 'wandb': + container = OmegaConf.to_container(config, resolve=True, throw_on_missing=True) + # use _partial_ so it doesn't try to init everything + wandb_logger = hydra.utils.instantiate(lg_conf, _partial_=True) + logger.append(wandb_logger(init_kwargs={'config': container})) + else: + logger.append(hydra.utils.instantiate(lg_conf)) + + # Some algorithms should also be applied at inference time + if 'algorithms' in config: + for ag_name, ag_conf in config.algorithms.items(): + if '_target_' in ag_conf: + print(f'Instantiating algorithm <{ag_conf._target_}>') + algorithms.append(hydra.utils.instantiate(ag_conf)) + elif ag_name == 'low_precision_groupnorm': + surgery_target = model + if 'attribute' in ag_conf: + surgery_target = operator.attrgetter(ag_conf.attribute)(model) + apply_low_precision_groupnorm( + model=surgery_target, + precision=Precision(ag_conf['precision']), + optimizers=None, + ) + elif ag_name == 'low_precision_layernorm': + surgery_target = model + if 'attribute' in ag_conf: + surgery_target = operator.attrgetter(ag_conf.attribute)(model) + apply_low_precision_layernorm( + model=surgery_target, + precision=Precision(ag_conf['precision']), + optimizers=None, + ) + + evaluator: CleanFIDEvaluator = hydra.utils.instantiate( + config.evaluator, + model=model, + eval_dataloader=eval_dataloader, + clip_metric=clip_metric, + loggers=logger, + ) + + def evaluate_model(): + evaluator.evaluate() + + return evaluate_model() diff --git a/diffusion/evaluation/__init__.py b/diffusion/evaluation/__init__.py new file mode 100644 index 00000000..a8bd69e7 --- /dev/null +++ b/diffusion/evaluation/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Evaluation.""" diff --git a/diffusion/evaluation/clean_fid_eval.py b/diffusion/evaluation/clean_fid_eval.py new file mode 100644 index 00000000..521eb7ed --- /dev/null +++ b/diffusion/evaluation/clean_fid_eval.py @@ -0,0 +1,246 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Evaluation using the clean-fid package.""" + +import json +import os +from typing import List, Optional + +import clip +import torch +import wandb +from cleanfid import fid +from composer import ComposerModel, Trainer +from composer.core import get_precision_context +from composer.loggers import LoggerDestination, WandBLogger +from composer.utils import dist +from torch.utils.data import DataLoader +from torchmetrics.multimodal import CLIPScore +from torchvision.transforms.functional import to_pil_image +from tqdm.auto import tqdm +from transformers import PreTrainedTokenizerBase + +os.environ['TOKENIZERS_PARALLELISM'] = 'false' + + +class CleanFIDEvaluator: + """Evaluator for CLIP, FID, KID, CLIP-FID scores using clean-fid. + + See https://github.com/GaParmar/clean-fid for more information on clean-fid. + + CLIP scores are computed using the torchmetrics CLIPScore metric. + + Args: + model (ComposerModel): The model to evaluate. + eval_dataloader (DataLoader): The dataloader to use for evaluation. + clip_metric (CLIPScore): The CLIPScore metric to use for evaluation. + load_path (str, optional): The path to load the model from. Default: ``None``. 
+        guidance_scales (List[float]): The guidance scales to use for evaluation.
+            Default: ``[1.0]``.
+        size (int): The size of the images to generate. Default: ``256``.
+        batch_size (int): The per-device batch size to use for evaluation. Default: ``16``.
+        image_key (str): The key to use for the images in the eval dataloader batch. Default: ``'image'``.
+        caption_key (str): The key to use for the captions in the eval dataloader batch. Default: ``'caption'``.
+        loggers (List[LoggerDestination], optional): The loggers to use for logging results. Default: ``None``.
+        seed (int): The seed to use for evaluation. Default: ``17``.
+        output_dir (str): The directory to save results to. Default: ``/tmp/``.
+        num_samples (int, optional): The maximum number of samples to generate. Depending on batch size, the actual
+            number may be slightly higher. If not specified, all the samples in the dataloader will be used.
+            Default: ``None``.
+        precision (str): The precision to use for evaluation. Default: ``'amp_fp16'``.
+        prompts (List[str], optional): The prompts to use for image visualization.
+            Default: ``["A shiba inu wearing a blue sweater"]``.
+
+    """
+
+    def __init__(self,
+                 model: ComposerModel,
+                 eval_dataloader: DataLoader,
+                 clip_metric: CLIPScore,
+                 load_path: Optional[str] = None,
+                 guidance_scales: Optional[List[float]] = None,
+                 size: int = 256,
+                 batch_size: int = 16,
+                 image_key: str = 'image',
+                 caption_key: str = 'caption',
+                 loggers: Optional[List[LoggerDestination]] = None,
+                 seed: int = 17,
+                 output_dir: str = '/tmp/',
+                 num_samples: Optional[int] = None,
+                 precision: str = 'amp_fp16',
+                 prompts: Optional[List[str]] = None):
+        self.model = model
+        self.tokenizer: PreTrainedTokenizerBase = model.tokenizer
+        self.eval_dataloader = eval_dataloader
+        self.clip_metric = clip_metric
+        self.load_path = load_path
+        self.guidance_scales = guidance_scales if guidance_scales is not None else [1.0]
+        self.size = size
+        self.batch_size = batch_size
+        self.image_key = image_key
+        self.caption_key = caption_key
+        self.loggers = loggers
+        self.seed = seed
+        self.output_dir = output_dir
+        self.num_samples = num_samples if num_samples is not None else float('inf')
+        self.precision = precision
+        self.prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater']
+
+        # Init loggers
+        if self.loggers and dist.get_local_rank() == 0:
+            for logger in self.loggers:
+                if isinstance(logger, WandBLogger):
+                    wandb.init(**logger._init_kwargs)
+
+        # Load the model weights from the checkpoint via a Trainer
+        Trainer(model=self.model,
+                load_path=self.load_path,
+                load_weights_only=True,
+                eval_dataloader=self.eval_dataloader,
+                seed=self.seed)
+
+        # Move CLIP metric to device
+        self.device = dist.get_local_rank()
+        self.clip_metric = self.clip_metric.to(self.device)
+
+        # Predownload the CLIP model for computing CLIP-FID
+        _, _ = clip.load('ViT-B/32', device=self.device)
+
+    def _generate_images(self, guidance_scale: float):
+        """Core image generation function. Generates images at a given guidance scale.
+
+        Args:
+            guidance_scale (float): The guidance scale to use for image generation.
+ """ + # Verify output dirs exist, if they don't, create them + real_image_path = os.path.join(self.output_dir, f'real_images_gs_{guidance_scale}') + gen_image_path = os.path.join(self.output_dir, f'gen_images_gs_{guidance_scale}') + if not os.path.exists(real_image_path) and dist.get_local_rank() == 0: + os.makedirs(real_image_path) + if not os.path.exists(gen_image_path) and dist.get_local_rank() == 0: + os.makedirs(gen_image_path) + + # Reset the CLIP metric + self.clip_metric.reset() + + # Storage for prompts + prompts = {} + # Iterate over the eval dataloader + num_batches = len(self.eval_dataloader) + starting_seed = self.seed + num_batches * dist.get_local_rank() + for batch_id, batch in tqdm(enumerate(self.eval_dataloader)): + # Break if enough samples have been generated + if batch_id * self.batch_size * dist.get_world_size() >= self.num_samples: + break + + real_images = batch[self.image_key] + captions = batch[self.caption_key] + # Ensure a new seed for each batch, as randomness in model.generate is fixed. + seed = starting_seed + batch_id + # Generate images from the captions + with get_precision_context(self.precision): + generated_images = self.model.generate(tokenized_prompts=captions, + height=self.size, + width=self.size, + guidance_scale=guidance_scale, + seed=seed, + progress_bar=False) # type: ignore + # Get the prompts from the tokens + text_captions = self.tokenizer.batch_decode(captions, skip_special_tokens=True) + self.clip_metric.update((generated_images * 255).to(torch.uint8), text_captions) + # Save the real images + # Verify that the real images are in the proper range + if real_images.min() < 0.0 or real_images.max() > 1.0: + raise ValueError( + f'Images are expected to be in the range [0, 1]. Got max {real_images.max()} and min {real_images.min()}' + ) + for i, img in enumerate(real_images): + to_pil_image(img).save(f'{real_image_path}/{batch_id}_{i}_rank_{dist.get_local_rank()}.png') + prompts[f'{batch_id}_{i}_rank_{dist.get_local_rank()}'] = text_captions[i] + # Save the generated images + for i, img in enumerate(generated_images): + to_pil_image(img).save(f'{gen_image_path}/{batch_id}_{i}_rank_{dist.get_local_rank()}.png') + + # Save the prompts as json + json.dump(prompts, open(f'{real_image_path}/prompts_rank_{dist.get_local_rank()}.json', 'w')) + + def _compute_metrics(self, guidance_scale: float): + """Compute metrics for the generated images at a given guidance scale. + + Args: + guidance_scale (float): The guidance scale to use for image generation. + + Returns: + Dict[str, float]: The computed metrics. 
+ """ + # Path to find the generated images in + real_image_path = os.path.join(self.output_dir, f'real_images_gs_{guidance_scale}') + gen_image_path = os.path.join(self.output_dir, f'gen_images_gs_{guidance_scale}') + + metrics = {} + # CLIP score + clip_score = self.clip_metric.compute() + metrics['CLIP-score'] = clip_score + print(f'{guidance_scale} CLIP score: {clip_score}') + + # Need to tell clean-fid which device to use + device = torch.device(self.device) + # Standard FID + fid_score = fid.compute_fid(real_image_path, + gen_image_path, + device=device, + use_dataparallel=False, + verbose=False) + metrics['FID'] = fid_score + print(f'{guidance_scale} FID: {fid_score}') + # CLIP-FID from https://arxiv.org/abs/2203.06026 + clip_fid_score = fid.compute_fid(real_image_path, + gen_image_path, + mode='clean', + model_name='clip_vit_b_32', + device=device, + use_dataparallel=False, + verbose=False) + metrics['CLIP-FID'] = clip_fid_score + print(f'{guidance_scale} CLIP-FID: {clip_fid_score}') + # KID + kid_score = fid.compute_kid(real_image_path, + gen_image_path, + device=device, + use_dataparallel=False, + verbose=False) + metrics['KID'] = kid_score + print(f'{guidance_scale} KID: {kid_score}') + return metrics + + def _generate_images_from_prompts(self, guidance_scale: float): + """Generate images from prompts for visualization.""" + if self.prompts: + with get_precision_context(self.precision): + generated_images = self.model.generate(prompt=self.prompts, + height=self.size, + width=self.size, + guidance_scale=guidance_scale, + seed=self.seed) # type: ignore + else: + generated_images = [] + return generated_images + + def evaluate(self): + # Generate images and compute metrics for each guidance scale + for guidance_scale in self.guidance_scales: + dist.barrier() + # Generate images and compute metrics + self._generate_images(guidance_scale=guidance_scale) + # Need to wait until all ranks have finished generating images before computing metrics + dist.barrier() + # Compute the metrics on the generated images + metrics = self._compute_metrics(guidance_scale=guidance_scale) + # Generate images from prompts for visualization + generated_images = self._generate_images_from_prompts(guidance_scale=guidance_scale) + # Log metrics and images on rank 0 + if self.loggers and dist.get_local_rank() == 0: + for logger in self.loggers: + for metric, value in metrics.items(): + logger.log_metrics({f'{guidance_scale}/{metric}': value}) + for prompt, image in zip(self.prompts, generated_images): + logger.log_images(images=image, name=f'{prompt}_gs_{guidance_scale}') diff --git a/run_eval.py b/run_eval.py new file mode 100644 index 00000000..c09c1871 --- /dev/null +++ b/run_eval.py @@ -0,0 +1,26 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Run evaluation.""" + +import textwrap + +import hydra +from omegaconf import DictConfig + +from diffusion.evaluate import evaluate + + +@hydra.main(version_base=None) +def main(config: DictConfig) -> None: + """Hydra wrapper for evaluation.""" + if not config: + raise ValueError( + textwrap.dedent("""\ + Config path and name not specified! 
+ Please specify these by using --config-path and --config-name, respectively.""")) + return evaluate(config) + + +if __name__ == '__main__': + main() diff --git a/setup.py b/setup.py index da12bdcf..d05090ce 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,8 @@ 'xformers==0.0.16', 'triton==2.0.0', 'torchmetrics[image]==0.11.3', + 'clean-fid', + 'clip@git+https://github.com/openai/CLIP.git', ] extras_require = {} diff --git a/yamls/hydra-yamls/eval-clean-fid.yaml b/yamls/hydra-yamls/eval-clean-fid.yaml new file mode 100644 index 00000000..3073b79b --- /dev/null +++ b/yamls/hydra-yamls/eval-clean-fid.yaml @@ -0,0 +1,62 @@ +image_size: 256 # This is the image resolution to evaluate at (assumes square images) +batch_size: 16 +name: clean-fid-eval # Name for the eval run for logging +project: diffusion-clean-fid-eval # Name of the wandb project for logging +seed: 42 # Random seed. This affects the randomness used in image generation. + +model: # This is the model to evaluate + _target_: diffusion.models.models.stable_diffusion_2 + pretrained: false + precomputed_latents: false + encode_latents_in_fp16: true + fsdp: false + val_metrics: + - _target_: torchmetrics.MeanSquaredError + val_guidance_scales: [] + loss_bins: [] +eval_dataloader: + _target_: diffusion.datasets.build_streaming_image_caption_dataloader + remote: + - # Remote(s) for the evaluation dataset go here + local: + - # Local(s) for the evaluation dataset go here + batch_size: ${batch_size} + resize_size: ${image_size} + image_key: image # This should be set to the image key specific to the eval dataset + caption_key: captions # This should be set to the caption key specific to the eval dataset + transform: # How to transform the images for evaluation + - _target_ : diffusion.datasets.laion.transforms.LargestCenterSquare + size: ${image_size} + - _target_: torchvision.transforms.ToTensor + dataloader_kwargs: + drop_last: false + shuffle: false + num_workers: 8 + pin_memory: true + streaming_kwargs: + shuffle: false +clip_metric: # This is the metric used to compute CLIP score, which is not part of clean-fid + _target_: torchmetrics.multimodal.CLIPScore + model_name_or_path: openai/clip-vit-base-patch16 +logger: + wandb: + _target_: composer.loggers.wandb_logger.WandBLogger + name: ${name} + project: ${project} + group: ${name} +evaluator: + _target_: diffusion.evaluation.clean_fid_eval.CleanFIDEvaluator + load_path: # Path to the checkpoint to load and evaluate. + guidance_scales: + - 1.0 + - 1.5 + - 2.0 + - 3.0 + - 4.0 + - 5.0 + - 6.0 + - 7.0 + - 8.0 + size: ${image_size} + batch_size: ${batch_size} + seed: ${seed} diff --git a/yamls/mosaic-yamls/eval-clean-fid.yaml b/yamls/mosaic-yamls/eval-clean-fid.yaml new file mode 100644 index 00000000..73988d73 --- /dev/null +++ b/yamls/mosaic-yamls/eval-clean-fid.yaml @@ -0,0 +1,86 @@ +run_name: diffusion-clean-fid-eval +image: mosaicml/pytorch_vision:1.13.1_cu117-python3.10-ubuntu20.04 +compute: + gpus: 8 # Number of GPUs to use. Note evaluating with clean-fid currently only supports 1 node. + ## These configurations are optional + # cluster: TODO # Name of the cluster to use for this run + # gpu_type: a100_80gb # Type of GPU to use. We use a100_80gb in our experiments +integrations: + - integration_type: "git_repo" + git_repo: mosaicml/diffusion + git_branch: main + pip_install: . + - integration_type: "wandb" + project: # Insert wandb project name + entity: # Insert wandb entity name + +# To run eval, one must specify the path to the dataset in the remote field. 
+# Guidance scale can also be specified here, 3.0 is a good starting point +# load_path is the location of the checkpoint. If load_path is not specified, the pretrained model will be used. +command: | + cd diffusion + HYDRA_FULL_ERROR=1 composer run_eval.py --config-path /mnt/config --config-name parameters + +parameters: + image_size: 256 # This is the image resolution to evaluate at (assumes square images) + batch_size: 16 + name: clean-fid-eval # Name for the eval run for logging + project: diffusion-clean-fid-eval # Name of the wandb project for logging + seed: 42 # Random seed. This affects the randomness used in image generation. + + model: # This is the model to evaluate + _target_: diffusion.models.models.stable_diffusion_2 + pretrained: false + precomputed_latents: false + encode_latents_in_fp16: true + fsdp: false + val_metrics: + - _target_: torchmetrics.MeanSquaredError + val_guidance_scales: [] + loss_bins: [] + eval_dataloader: + _target_: diffusion.datasets.build_streaming_image_caption_dataloader + remote: + - # Remote(s) for the evaluation dataset go here + local: + - # Local(s) for the evaluation dataset go here + batch_size: ${batch_size} + resize_size: ${image_size} + image_key: image # This should be set to the image key specific to the eval dataset + caption_key: captions # This should be set to the caption key specific to the eval dataset + transform: # How to transform the images for evaluation + - _target_ : diffusion.datasets.laion.transforms.LargestCenterSquare + size: ${image_size} + - _target_: torchvision.transforms.ToTensor + dataloader_kwargs: + drop_last: false + shuffle: false + num_workers: 8 + pin_memory: true + streaming_kwargs: + shuffle: false + clip_metric: # This is the metric used to compute CLIP score, which is not part of clean-fid + _target_: torchmetrics.multimodal.CLIPScore + model_name_or_path: openai/clip-vit-base-patch16 + logger: + wandb: + _target_: composer.loggers.wandb_logger.WandBLogger + name: ${name} + project: ${project} + group: ${name} + evaluator: + _target_: diffusion.evaluation.clean_fid_eval.CleanFIDEvaluator + load_path: # Path to the checkpoint to load and evaluate. + guidance_scales: + - 1.0 + - 1.5 + - 2.0 + - 3.0 + - 4.0 + - 5.0 + - 6.0 + - 7.0 + - 8.0 + size: ${image_size} + batch_size: ${batch_size} + seed: ${seed}
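For reference, once the dataset remote/local fields and the evaluator load_path in yamls/hydra-yamls/eval-clean-fid.yaml are filled in, the evaluation added by this patch can be launched locally from the repository root with a command along the lines of (illustrative; the flags follow run_eval.py's hydra interface shown above):

    composer run_eval.py --config-path yamls/hydra-yamls --config-name eval-clean-fid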