Add evaluation using clean-fid (#51)
coryMosaicML authored Jul 20, 2023
1 parent c872947 commit 131f30e
Showing 7 changed files with 519 additions and 0 deletions.
93 changes: 93 additions & 0 deletions diffusion/evaluate.py
@@ -0,0 +1,93 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Evaluate model."""

import operator
from typing import List

import hydra
from composer import Algorithm, ComposerModel
from composer.algorithms.low_precision_groupnorm import apply_low_precision_groupnorm
from composer.algorithms.low_precision_layernorm import apply_low_precision_layernorm
from composer.core import Precision
from composer.loggers import LoggerDestination
from composer.utils import reproducibility
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from torchmetrics.multimodal import CLIPScore

from diffusion.evaluation.clean_fid_eval import CleanFIDEvaluator


def evaluate(config: DictConfig) -> None:
"""Evaluate a model.
Args:
config (DictConfig): Configuration composed by Hydra
"""
reproducibility.seed_all(config.seed)

# The model to evaluate
model: ComposerModel = hydra.utils.instantiate(config.model)

# The dataloader to use for evaluation
eval_dataloader: DataLoader = hydra.utils.instantiate(config.eval_dataloader)

# The CLIPScores metric to use for evaluation
clip_metric: CLIPScore = hydra.utils.instantiate(config.clip_metric)

# Build list of loggers and algorithms.
logger: List[LoggerDestination] = []
algorithms: List[Algorithm] = []

# Set up logging for results
if 'logger' in config:
for log, lg_conf in config.logger.items():
if '_target_' in lg_conf:
print(f'Instantiating logger <{lg_conf._target_}>')
if log == 'wandb':
container = OmegaConf.to_container(config, resolve=True, throw_on_missing=True)
# Use _partial_ so instantiation is deferred until init_kwargs (the resolved config) can be passed in
wandb_logger = hydra.utils.instantiate(lg_conf, _partial_=True)
logger.append(wandb_logger(init_kwargs={'config': container}))
else:
logger.append(hydra.utils.instantiate(lg_conf))

# Some algorithms should also be applied at inference time
if 'algorithms' in config:
for ag_name, ag_conf in config.algorithms.items():
if '_target_' in ag_conf:
print(f'Instantiating algorithm <{ag_conf._target_}>')
algorithms.append(hydra.utils.instantiate(ag_conf))
elif ag_name == 'low_precision_groupnorm':
surgery_target = model
if 'attribute' in ag_conf:
surgery_target = operator.attrgetter(ag_conf.attribute)(model)
apply_low_precision_groupnorm(
model=surgery_target,
precision=Precision(ag_conf['precision']),
optimizers=None,
)
elif ag_name == 'low_precision_layernorm':
surgery_target = model
if 'attribute' in ag_conf:
surgery_target = operator.attrgetter(ag_conf.attribute)(model)
apply_low_precision_layernorm(
model=surgery_target,
precision=Precision(ag_conf['precision']),
optimizers=None,
)

evaluator: CleanFIDEvaluator = hydra.utils.instantiate(
config.evaluator,
model=model,
eval_dataloader=eval_dataloader,
clip_metric=clip_metric,
loggers=logger,
)

def evaluate_model():
evaluator.evaluate()

return evaluate_model()
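
A note on configuration: evaluate() expects a Hydra config with at least seed, model, eval_dataloader, clip_metric, and evaluator entries (logger and algorithms are optional). Below is a minimal sketch of an equivalent config built directly with OmegaConf rather than loaded from YAML; the model and eval_dataloader _target_ paths are hypothetical placeholders to be replaced with whatever your setup uses, while the clip_metric and evaluator targets point at torchmetrics and the evaluator added in this commit.

from omegaconf import OmegaConf

from diffusion.evaluate import evaluate

config = OmegaConf.create({
    'seed': 17,
    # Hypothetical model and dataloader targets -- substitute your own.
    'model': {'_target_': 'diffusion.models.stable_diffusion_2'},
    'eval_dataloader': {'_target_': 'my_project.data.build_eval_dataloader', 'batch_size': 16},
    # CLIP score metric; the evaluator moves it onto the local device in its constructor.
    'clip_metric': {'_target_': 'torchmetrics.multimodal.CLIPScore',
                    'model_name_or_path': 'openai/clip-vit-base-patch16'},
    # The evaluator added in this commit; model, eval_dataloader, clip_metric, and loggers are injected by evaluate().
    'evaluator': {'_target_': 'diffusion.evaluation.clean_fid_eval.CleanFIDEvaluator',
                  'guidance_scales': [1.0, 3.0],
                  'size': 256,
                  'batch_size': 16,
                  'output_dir': '/tmp/clean-fid-eval/'},
})

evaluate(config)
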
4 changes: 4 additions & 0 deletions diffusion/evaluation/__init__.py
@@ -0,0 +1,4 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Evaluation."""
246 changes: 246 additions & 0 deletions diffusion/evaluation/clean_fid_eval.py
@@ -0,0 +1,246 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Evaluation using the clean-fid package."""

import json
import os
from typing import List, Optional

import clip
import torch
import wandb
from cleanfid import fid
from composer import ComposerModel, Trainer
from composer.core import get_precision_context
from composer.loggers import LoggerDestination, WandBLogger
from composer.utils import dist
from torch.utils.data import DataLoader
from torchmetrics.multimodal import CLIPScore
from torchvision.transforms.functional import to_pil_image
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizerBase

os.environ['TOKENIZERS_PARALLELISM'] = 'false'


class CleanFIDEvaluator:
"""Evaluator for CLIP, FID, KID, CLIP-FID scores using clean-fid.
See https://github.com/GaParmar/clean-fid for more information on clean-fid.
CLIP scores are computed using the torchmetrics CLIPScore metric.
Args:
model (ComposerModel): The model to evaluate.
eval_dataloader (DataLoader): The dataloader to use for evaluation.
clip_metric (CLIPScore): The CLIPScore metric to use for evaluation.
load_path (str, optional): The path to load the model from. Default: ``None``.
guidance_scales (List[float]): The guidance scales to use for evaluation.
Default: ``[1.0]``.
size (int): The size of the images to generate. Default: ``256``.
batch_size (int): The per-device batch size to use for evaluation. Default: ``16``.
image_key (str): The key in the dataloader batch that contains the images. Default: ``'image'``.
caption_key (str): The key in the dataloader batch that contains the tokenized captions. Default: ``'caption'``.
loggers (List[LoggerDestination], optional): The loggers to use for logging results. Default: ``None``.
seed (int): The seed to use for evaluation. Default: ``17``.
output_dir (str): The directory to save results to. Default: ``/tmp/``.
num_samples (int, optional): The maximum number of samples to generate. Depending on the batch size, the actual
number may be slightly higher. If not specified, all the samples in the dataloader will be used.
Default: ``None``.
precision (str): The precision to use for evaluation. Default: ``'amp_fp16'``.
prompts (List[str], optional): The prompts to use for image visualization.
Default: ``['A shiba inu wearing a blue sweater']``.
"""

def __init__(self,
model: ComposerModel,
eval_dataloader: DataLoader,
clip_metric: CLIPScore,
load_path: Optional[str] = None,
guidance_scales: Optional[List[float]] = None,
size: int = 256,
batch_size: int = 16,
image_key: str = 'image',
caption_key: str = 'caption',
loggers: Optional[List[LoggerDestination]] = None,
seed: int = 17,
output_dir: str = '/tmp/',
num_samples: Optional[int] = None,
precision: str = 'amp_fp16',
prompts: Optional[List[str]] = None):
self.model = model
self.tokenizer: PreTrainedTokenizerBase = model.tokenizer
self.eval_dataloader = eval_dataloader
self.clip_metric = clip_metric
self.load_path = load_path
self.guidance_scales = guidance_scales if guidance_scales is not None else [1.0]
self.size = size
self.batch_size = batch_size
self.image_key = image_key
self.caption_key = caption_key
self.loggers = loggers
self.seed = seed
self.output_dir = output_dir
self.num_samples = num_samples if num_samples is not None else float('inf')
self.precision = precision
self.prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater']

# Init loggers
if self.loggers and dist.get_local_rank() == 0:
for logger in self.loggers:
if isinstance(logger, WandBLogger):
wandb.init(**logger._init_kwargs)

# Load the model
Trainer(model=self.model,
load_path=self.load_path,
load_weights_only=True,
eval_dataloader=self.eval_dataloader,
seed=self.seed)

# Move CLIP metric to device
self.device = dist.get_local_rank()
self.clip_metric = self.clip_metric.to(self.device)

# Predownload the CLIP model for computing clip-fid
_, _ = clip.load('ViT-B/32', device=self.device)

def _generate_images(self, guidance_scale: float):
"""Core image generation function. Generates images at a given guidance scale.
Args:
guidance_scale (float): The guidance scale to use for image generation.
"""
# Verify the output dirs exist; if they don't, create them
real_image_path = os.path.join(self.output_dir, f'real_images_gs_{guidance_scale}')
gen_image_path = os.path.join(self.output_dir, f'gen_images_gs_{guidance_scale}')
if not os.path.exists(real_image_path) and dist.get_local_rank() == 0:
os.makedirs(real_image_path)
if not os.path.exists(gen_image_path) and dist.get_local_rank() == 0:
os.makedirs(gen_image_path)

# Reset the CLIP metric
self.clip_metric.reset()

# Storage for prompts
prompts = {}
# Iterate over the eval dataloader
num_batches = len(self.eval_dataloader)
starting_seed = self.seed + num_batches * dist.get_local_rank()
for batch_id, batch in tqdm(enumerate(self.eval_dataloader)):
# Break if enough samples have been generated
if batch_id * self.batch_size * dist.get_world_size() >= self.num_samples:
break

real_images = batch[self.image_key]
captions = batch[self.caption_key]
# Use a new seed for each batch, since model.generate fixes its randomness from the provided seed.
seed = starting_seed + batch_id
# Generate images from the captions
with get_precision_context(self.precision):
generated_images = self.model.generate(tokenized_prompts=captions,
height=self.size,
width=self.size,
guidance_scale=guidance_scale,
seed=seed,
progress_bar=False) # type: ignore
# Get the prompts from the tokens
text_captions = self.tokenizer.batch_decode(captions, skip_special_tokens=True)
self.clip_metric.update((generated_images * 255).to(torch.uint8), text_captions)
# Save the real images
# Verify that the real images are in the proper range
if real_images.min() < 0.0 or real_images.max() > 1.0:
raise ValueError(
f'Images are expected to be in the range [0, 1]. Got max {real_images.max()} and min {real_images.min()}'
)
for i, img in enumerate(real_images):
to_pil_image(img).save(f'{real_image_path}/{batch_id}_{i}_rank_{dist.get_local_rank()}.png')
prompts[f'{batch_id}_{i}_rank_{dist.get_local_rank()}'] = text_captions[i]
# Save the generated images
for i, img in enumerate(generated_images):
to_pil_image(img).save(f'{gen_image_path}/{batch_id}_{i}_rank_{dist.get_local_rank()}.png')

# Save the prompts as json
json.dump(prompts, open(f'{real_image_path}/prompts_rank_{dist.get_local_rank()}.json', 'w'))

def _compute_metrics(self, guidance_scale: float):
"""Compute metrics for the generated images at a given guidance scale.
Args:
guidance_scale (float): The guidance scale whose generated images to evaluate.
Returns:
Dict[str, float]: The computed metrics.
"""
# Paths where the real and generated images for this guidance scale were saved
real_image_path = os.path.join(self.output_dir, f'real_images_gs_{guidance_scale}')
gen_image_path = os.path.join(self.output_dir, f'gen_images_gs_{guidance_scale}')

metrics = {}
# CLIP score
clip_score = self.clip_metric.compute()
metrics['CLIP-score'] = clip_score
print(f'{guidance_scale} CLIP score: {clip_score}')

# Need to tell clean-fid which device to use
device = torch.device(self.device)
# Standard FID
fid_score = fid.compute_fid(real_image_path,
gen_image_path,
device=device,
use_dataparallel=False,
verbose=False)
metrics['FID'] = fid_score
print(f'{guidance_scale} FID: {fid_score}')
# CLIP-FID from https://arxiv.org/abs/2203.06026
clip_fid_score = fid.compute_fid(real_image_path,
gen_image_path,
mode='clean',
model_name='clip_vit_b_32',
device=device,
use_dataparallel=False,
verbose=False)
metrics['CLIP-FID'] = clip_fid_score
print(f'{guidance_scale} CLIP-FID: {clip_fid_score}')
# KID
kid_score = fid.compute_kid(real_image_path,
gen_image_path,
device=device,
use_dataparallel=False,
verbose=False)
metrics['KID'] = kid_score
print(f'{guidance_scale} KID: {kid_score}')
return metrics

def _generate_images_from_prompts(self, guidance_scale: float):
"""Generate images from prompts for visualization."""
if self.prompts:
with get_precision_context(self.precision):
generated_images = self.model.generate(prompt=self.prompts,
height=self.size,
width=self.size,
guidance_scale=guidance_scale,
seed=self.seed) # type: ignore
else:
generated_images = []
return generated_images

def evaluate(self):
# Generate images and compute metrics for each guidance scale
for guidance_scale in self.guidance_scales:
dist.barrier()
# Generate images and compute metrics
self._generate_images(guidance_scale=guidance_scale)
# Need to wait until all ranks have finished generating images before computing metrics
dist.barrier()
# Compute the metrics on the generated images
metrics = self._compute_metrics(guidance_scale=guidance_scale)
# Generate images from prompts for visualization
generated_images = self._generate_images_from_prompts(guidance_scale=guidance_scale)
# Log metrics and images on rank 0
if self.loggers and dist.get_local_rank() == 0:
for logger in self.loggers:
for metric, value in metrics.items():
logger.log_metrics({f'{guidance_scale}/{metric}': value})
for prompt, image in zip(self.prompts, generated_images):
logger.log_images(images=image, name=f'{prompt}_gs_{guidance_scale}')
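
For reference, the evaluator can also be driven directly from Python without Hydra. The sketch below assumes two hypothetical helpers, build_model() and build_eval_dataloader(), standing in for however the ComposerModel (which must expose a tokenizer and a generate() method) and the image/caption dataloader are constructed in your setup; the checkpoint path is likewise illustrative.

from torchmetrics.multimodal import CLIPScore

from diffusion.evaluation.clean_fid_eval import CleanFIDEvaluator

# Hypothetical helpers: any ComposerModel with `tokenizer` and `generate()`, and any DataLoader
# yielding batches with 'image' tensors in [0, 1] and tokenized 'caption' tensors, will do.
model = build_model()
eval_dataloader = build_eval_dataloader(batch_size=16)

evaluator = CleanFIDEvaluator(
    model=model,
    eval_dataloader=eval_dataloader,
    clip_metric=CLIPScore(model_name_or_path='openai/clip-vit-base-patch16'),
    load_path='/path/to/checkpoint.pt',  # hypothetical checkpoint to load weights from
    guidance_scales=[1.0, 3.0, 7.0],
    size=256,
    batch_size=16,
    output_dir='/tmp/clean-fid-eval/',
    num_samples=10_000,
)

# Saves real and generated images per guidance scale, then reports FID, CLIP-FID, KID, and CLIP score.
evaluator.evaluate()
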
26 changes: 26 additions & 0 deletions run_eval.py
@@ -0,0 +1,26 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Run evaluation."""

import textwrap

import hydra
from omegaconf import DictConfig

from diffusion.evaluate import evaluate


@hydra.main(version_base=None)
def main(config: DictConfig) -> None:
"""Hydra wrapper for evaluation."""
if not config:
raise ValueError(
textwrap.dedent("""\
Config path and name not specified!
Please specify these by using --config-path and --config-name, respectively."""))
return evaluate(config)


if __name__ == '__main__':
main()
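
Usage note: with a config like the sketch shown after evaluate.py saved as a YAML file (for example yamls/eval-clean-fid.yaml, a hypothetical path), this script would typically be launched on all ranks with the Composer launcher, along the lines of: composer run_eval.py --config-path yamls --config-name eval-clean-fid. The --config-path and --config-name flags are the standard Hydra arguments referenced by the error message above.
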
2 changes: 2 additions & 0 deletions setup.py
@@ -16,6 +16,8 @@
'xformers==0.0.16',
'triton==2.0.0',
'torchmetrics[image]==0.11.3',
'clean-fid',
'clip@git+https://github.com/openai/CLIP.git',
]

extras_require = {}
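
The two new dependencies back the evaluator above: clean-fid provides fid.compute_fid and fid.compute_kid for the FID, CLIP-FID, and KID numbers, while the OpenAI clip package lets CleanFIDEvaluator pre-download the ViT-B/32 weights on each device so the CLIP-FID backbone is cached before image generation starts.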
