-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompute_fid_scores.py
55 lines (39 loc) · 1.84 KB
/
compute_fid_scores.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import json
import pandas as pd
import torch
import argparse
from cleanfid import fid
def _iter_experiments(sd_gen_base):
    """Yield ``(basename, metadata)`` for each ``.json`` file in *sd_gen_base*.

    Files are visited in sorted file-name order; non-JSON files are skipped.
    ``basename`` is the file name without its ``.json`` extension and
    ``metadata`` is the parsed JSON content.
    """
    for file_name in sorted(os.listdir(sd_gen_base)):
        basename, ext = os.path.splitext(file_name)
        if ext != ".json":
            continue
        # Load and close the file before yielding, so it is not held open
        # while the (long) FID computation for this experiment runs.
        with open(os.path.join(sd_gen_base, file_name), "r", encoding="utf-8") as f:
            data = json.load(f)
        yield basename, data


def main(args):
    """Compute FID for every generation run described in ``args.sd_gen_dir``.

    Each ``*.json`` file in ``args.sd_gen_dir`` describes one experiment and
    must contain the keys ``output_dir`` (directory of generated images),
    ``duration_mean`` and ``duration_std`` (copied verbatim to the results).
    The FID between ``args.im_val_dir`` and each ``output_dir`` is computed
    with cleanfid using ``args.mode``.

    Results are written to ``SD_gen_results_{args.mode}_fid.csv`` in the
    current working directory (semicolon-separated, UTF-8, no index column).
    If no JSON files are found, the CSV contains only the header row.

    Args:
        args: namespace with ``im_val_dir``, ``sd_gen_dir`` and ``mode``.

    Raises:
        KeyError: if a JSON file lacks one of the required keys.
    """
    im_val_dir = args.im_val_dir
    results_dict = {"experiment": [], "duration_mean": [], "duration_std": [], "fid": []}
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for basename, data in _iter_experiments(args.sd_gen_dir):
        fid_score = fid.compute_fid(
            fdir1=im_val_dir,
            fdir2=data["output_dir"],
            mode=args.mode,
            model_name="inception_v3",  # the missing comma here was a SyntaxError
            num_workers=12,
            batch_size=32,
            device=device,
        )
        results_dict["experiment"].append(basename)
        results_dict["duration_mean"].append(data["duration_mean"])
        results_dict["duration_std"].append(data["duration_std"])
        results_dict["fid"].append(fid_score)
        # Report progress per experiment instead of dumping the whole growing dict.
        print(f"{basename}: FID = {fid_score}")

    data_csv = pd.DataFrame.from_dict(results_dict)
    data_csv.to_csv(f"SD_gen_results_{args.mode}_fid.csv", sep=";", encoding='utf-8', index=False)
if __name__ == "__main__":
    # The first positional argument of ArgumentParser is `prog` (the program
    # name shown in usage), so the summary sentence belongs in `description`.
    # Keeping the default add_help=True makes `-h/--help` available.
    parser = argparse.ArgumentParser(
        description='Computation of FID between reference ImageNet set and SD generations',
    )
    parser.add_argument('--mode', default="clean", type=str,
                        help="mode passed to cleanfid's fid.compute_fid (default: clean)")
    parser.add_argument('--im_val_dir', default="path/to/im_val_set", type=str,
                        help="directory of reference images")
    parser.add_argument('--sd_gen_dir', default="path/to/sd_generations", type=str,
                        help="directory containing one .json description per generation run")
    args = parser.parse_args()
    main(args)