diff --git a/diffusion/evaluation/clean_fid_eval.py b/diffusion/evaluation/clean_fid_eval.py
index 45c3faf2..2bac182b 100644
--- a/diffusion/evaluation/clean_fid_eval.py
+++ b/diffusion/evaluation/clean_fid_eval.py
@@ -5,7 +5,7 @@
 
 import json
 import os
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import clip
 import torch
@@ -50,6 +50,7 @@ class CleanFIDEvaluator:
         precision (str): The precision to use for evaluation. Default: ``'amp_fp16'``.
         prompts (List[str], optional): The prompts to use for image visualtization.
             Default: ``["A shiba inu wearing a blue sweater]``.
+        additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.
     """
 
     def __init__(self,
@@ -69,7 +70,8 @@ def __init__(self,
                  output_dir: str = '/tmp/',
                  num_samples: Optional[int] = None,
                  precision: str = 'amp_fp16',
-                 prompts: Optional[List[str]] = None):
+                 prompts: Optional[List[str]] = None,
+                 additional_generate_kwargs: Optional[Dict] = None):
         self.model = model
         self.tokenizer: PreTrainedTokenizerBase = model.tokenizer
         self.eval_dataloader = eval_dataloader
@@ -86,6 +88,7 @@ def __init__(self,
         self.num_samples = num_samples if num_samples is not None else float('inf')
         self.precision = precision
         self.prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater']
+        self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
         self.sdxl = model.sdxl
 
         # Init loggers
@@ -107,7 +110,13 @@ def __init__(self,
         self.clip_metric = self.clip_metric.to(self.device)
 
         # Predownload the CLIP model for computing clip-fid
-        _, _ = clip.load('ViT-B/32', device=self.device)
+        clip_url = 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt'
+        clip_name = os.path.basename(clip_url)
+        clip_path = os.path.expanduser('~/.cache/clip')
+        if dist.get_local_rank() == 0:
+            clip.clip._download(clip_url, clip_path)
+        with dist.local_rank_zero_download_and_wait(os.path.join(clip_path, clip_name)):
+            clip.load('ViT-B/32', device=self.device)
 
     def _generate_images(self, guidance_scale: float):
         """Core image generation function. Generates images at a given guidance scale.
@@ -156,7 +165,8 @@ def _generate_images(self, guidance_scale: float):
                     seed=seed,
                     crop_params=crop_params,
                     input_size_params=input_size_params,
-                    progress_bar=False)  # type: ignore
+                    progress_bar=False,
+                    **self.additional_generate_kwargs)  # type: ignore
                 # Get the prompts from the tokens
                 text_captions = self.tokenizer.batch_decode(captions, skip_special_tokens=True)
                 self.clip_metric.update((generated_images * 255).to(torch.uint8), text_captions)
@@ -233,7 +243,8 @@ def _generate_images_from_prompts(self, guidance_scale: float):
                     height=self.size,
                     width=self.size,
                     guidance_scale=guidance_scale,
-                    seed=self.seed)  # type: ignore
+                    seed=self.seed,
+                    **self.additional_generate_kwargs)  # type: ignore
         else:
             generated_images = []
         return generated_images
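
The first change in this diff threads a dict of extra sampler settings from the evaluator's constructor into every `model.generate` call. Below is a minimal, self-contained sketch of that kwargs-forwarding pattern; `ToyGenerator`, `ToyEvaluator`, and `num_inference_steps` are placeholder names for illustration, not the repo's API.

```python
# Sketch of the kwargs-forwarding pattern this diff introduces: keyword
# arguments supplied once at construction time are splatted into every
# generate() call. All names here are illustrative placeholders.
from typing import Dict, Optional


class ToyGenerator:

    def generate(self, prompt: str, guidance_scale: float = 7.0, **kwargs):
        # A real diffusion model would sample an image; here we echo the config.
        return {'prompt': prompt, 'guidance_scale': guidance_scale, **kwargs}


class ToyEvaluator:

    def __init__(self, model: ToyGenerator, additional_generate_kwargs: Optional[Dict] = None):
        self.model = model
        # Default to an empty dict so the ** splat below is always valid.
        self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}

    def run(self, prompt: str):
        # Fixed arguments come first; the user-supplied extras ride along.
        return self.model.generate(prompt, guidance_scale=5.0, **self.additional_generate_kwargs)


evaluator = ToyEvaluator(ToyGenerator(), additional_generate_kwargs={'num_inference_steps': 50})
print(evaluator.run('A shiba inu wearing a blue sweater'))
```

Defaulting to `{}` rather than passing `None` through keeps the call sites unconditional: the `**` splat is a no-op when no extras were given.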
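The second change replaces the unconditional `clip.load` with a coordinated predownload so that, on a multi-GPU node, only local rank 0 fetches the ViT-B/32 checkpoint while the other ranks wait for the file to appear instead of all downloading at once. The sketch below re-implements that coordination with the standard library only, assuming a `LOCAL_RANK` environment variable as set by launchers such as torchrun; the actual diff uses Composer's `dist.get_local_rank` and the `dist.local_rank_zero_download_and_wait` context manager, and OpenAI CLIP's private `clip.clip._download` helper.

```python
# Sketch of the download-coordination pattern: local rank 0 downloads the
# checkpoint, other ranks poll until the file exists, then all ranks proceed.
# This is a stdlib-only illustration, not Composer's implementation.
import os
import time
import urllib.request
from contextlib import contextmanager


def get_local_rank() -> int:
    # Assumption: the launcher exposes the local rank via LOCAL_RANK (torchrun does).
    return int(os.environ.get('LOCAL_RANK', 0))


@contextmanager
def local_rank_zero_download_and_wait(expected_file_path: str):
    # Non-zero local ranks block until rank 0 has produced the file.
    if get_local_rank() != 0:
        while not os.path.exists(expected_file_path):
            time.sleep(0.1)
    yield


clip_url = 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt'
clip_path = os.path.expanduser('~/.cache/clip')
target = os.path.join(clip_path, os.path.basename(clip_url))

if get_local_rank() == 0 and not os.path.exists(target):
    os.makedirs(clip_path, exist_ok=True)
    urllib.request.urlretrieve(clip_url, target)
with local_rank_zero_download_and_wait(target):
    pass  # safe to load the checkpoint on every rank from here
```

One caveat with waiting on bare file existence: it can race with a partially written download. A hardened version would download to a temporary name and rename atomically; the real `clip.clip._download` additionally verifies the file against the SHA-256 hash embedded in the URL path.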