Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only download CLIP on rank 0 when doing eval #135

Merged
merged 4 commits into from
Apr 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions diffusion/evaluation/clean_fid_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import json
import os
from typing import List, Optional
from typing import Dict, List, Optional

import clip
import torch
Expand Down Expand Up @@ -50,6 +50,7 @@ class CleanFIDEvaluator:
precision (str): The precision to use for evaluation. Default: ``'amp_fp16'``.
prompts (List[str], optional): The prompts to use for image visualization.
Default: ``["A shiba inu wearing a blue sweater"]``.
additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.

"""

Expand All @@ -69,7 +70,8 @@ def __init__(self,
output_dir: str = '/tmp/',
num_samples: Optional[int] = None,
precision: str = 'amp_fp16',
prompts: Optional[List[str]] = None):
prompts: Optional[List[str]] = None,
additional_generate_kwargs: Optional[Dict] = None):
Landanjs marked this conversation as resolved.
Show resolved Hide resolved
self.model = model
self.tokenizer: PreTrainedTokenizerBase = model.tokenizer
self.eval_dataloader = eval_dataloader
Expand All @@ -86,6 +88,7 @@ def __init__(self,
self.num_samples = num_samples if num_samples is not None else float('inf')
self.precision = precision
self.prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater']
self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
self.sdxl = model.sdxl

# Init loggers
Expand All @@ -107,7 +110,13 @@ def __init__(self,
self.clip_metric = self.clip_metric.to(self.device)

# Predownload the CLIP model for computing clip-fid
_, _ = clip.load('ViT-B/32', device=self.device)
clip_url = 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt'
clip_name = os.path.basename(clip_url)
clip_path = os.path.expanduser('~/.cache/clip')
if dist.get_local_rank() == 0:
clip.clip._download(clip_url, clip_path)
with dist.local_rank_zero_download_and_wait(os.path.join(clip_path, clip_name)):
clip.load('ViT-B/32', device=self.device)

def _generate_images(self, guidance_scale: float):
"""Core image generation function. Generates images at a given guidance scale.
Expand Down Expand Up @@ -156,7 +165,8 @@ def _generate_images(self, guidance_scale: float):
seed=seed,
crop_params=crop_params,
input_size_params=input_size_params,
progress_bar=False) # type: ignore
progress_bar=False,
**self.additional_generate_kwargs) # type: ignore
# Get the prompts from the tokens
text_captions = self.tokenizer.batch_decode(captions, skip_special_tokens=True)
self.clip_metric.update((generated_images * 255).to(torch.uint8), text_captions)
Expand Down Expand Up @@ -233,7 +243,8 @@ def _generate_images_from_prompts(self, guidance_scale: float):
height=self.size,
width=self.size,
guidance_scale=guidance_scale,
seed=self.seed) # type: ignore
seed=self.seed,
**self.additional_generate_kwargs) # type: ignore
else:
generated_images = []
return generated_images
Expand Down
Loading