-
Notifications
You must be signed in to change notification settings - Fork 0
/
constants.py
87 lines (72 loc) · 2.63 KB
/
constants.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import pandas as pd
from torchvision.transforms import (Compose,
ToTensor,
Normalize,
PILToTensor,
Resize,
CenterCrop)
from pathlib import Path
# Per-dataset preprocessing pipelines applied to PIL images before inference.
# CIFAR-10: tensor conversion + per-channel normalization with the standard
# CIFAR-10 train-set mean/std values.
CIFAR10_TRANSFORM = Compose([
    ToTensor(),
    Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
# ImageNet: the conventional 256-resize / 224-center-crop eval pipeline with
# ImageNet mean/std normalization.
IMAGENET_TRANSFORM = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# CLIP-style models: only convert PIL -> uint8 tensor here; presumably the
# model's own processor handles resizing/normalization downstream — TODO confirm.
CLIP_TRANSFORM = PILToTensor()
# Model families (the part before "@" in VALID_MODELS entries) that should use
# CLIP_TRANSFORM instead of the dataset-specific pipelines above.
CLIP_TRANSFORM_MODELS = {"clip", "blip", "altclip", "groupvit", "owlvit"}
# Default compute device: first CUDA GPU when available, else CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Human-readable class names for CIFAR-10, in canonical label-index order.
CIFAR10_LABELS_TEXT = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck"
]
# json file from https://github.com/anishathalye/imagenet-simple-labels/blob/master/imagenet-simple-labels.json
# NOTE(review): this reads a relative path at import time — importing this
# module fails with FileNotFoundError unless the JSON file sits in the current
# working directory. Consider resolving relative to __file__ or loading lazily.
IMAGENET_LABELS_TEXT = list(pd.read_json("imagenet-simple-labels.json")[0])
# NOTE(review): Windows-specific absolute path; callers presumably override it
# on other platforms — verify.
DATA_PATH_DEFAULT = "C:/ml_datasets"
# Default output directory for generated figures (relative to cwd).
FIGURES_PATH_DEFAULT = "figures"
# Supported model identifiers in the form "<family>@<hub-repo-id>", where
# <family> is one of the keys handled by CLIP_TRANSFORM_MODELS and
# <hub-repo-id> is a Hugging Face Hub repository. Split with model_name_parser.
VALID_MODELS_TRANSFORMERS = {
    "clip@openai/clip-vit-large-patch14",
    "clip@openai/clip-vit-base-patch16",
    "clip@openai/clip-vit-base-patch32",
    "clip@openai/clip-vit-large-patch14-336",
    "clip@laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
    "clip@laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    "clip@laion/CLIP-ViT-B-16-laion2B-s34B-b88K",
    "clip@laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
    "altclip@BAAI/AltCLIP",
    "altclip@BAAI/AltCLIP-m9",
    "altclip@BAAI/AltCLIP-m18",
    "groupvit@nvidia/groupvit-gcc-yfcc",
    "groupvit@nvidia/groupvit-gcc-redcaps",
    "owlvit@google/owlvit-base-patch32",
    "owlvit@google/owlvit-base-patch16",
    "owlvit@google/owlvit-large-patch14",
    "blip@Salesforce/blip-itm-base-coco",
    "blip@Salesforce/blip-itm-large-coco",
    "blip@Salesforce/blip-itm-base-flickr"
}
# Placeholder for LAVIS-backed models; currently none are supported.
VALID_MODELS_LAVIS = set()
# Union of all supported model identifiers across backends.
VALID_MODELS = VALID_MODELS_TRANSFORMERS.union(VALID_MODELS_LAVIS)
def model_name_parser(inp):
    """Split a model identifier "<family>@<weights>" into its two parts.

    Parameters
    ----------
    inp : str
        A model identifier such as "clip@openai/clip-vit-base-patch16"
        (see VALID_MODELS for the accepted values).

    Returns
    -------
    list[str]
        ``[family, weights_name]``.
    """
    # maxsplit=1 guarantees exactly two elements whenever "@" is present,
    # even if a weights id ever contains an "@" itself. Behavior is unchanged
    # for every current VALID_MODELS entry (each has exactly one "@").
    return inp.split("@", 1)
def get_output_path(results_dir, dataset_name, model_type, weights_name, image_noun, prefix_mod, suffix_mod,
                    filetype="pt"):
    """Build the path where results for one model/dataset combination are stored.

    The layout is ``<results_dir>/<dataset_name>/<model_type>/<filename>``,
    where the filename is the weights name followed by an underscore-prefixed
    tag for each of ``image_noun``, ``prefix_mod`` and ``suffix_mod`` that is
    non-empty, plus the ``filetype`` extension (default "pt").

    Returns a ``pathlib.Path``; the directories are not created here.
    """
    # Keep only truthy modifiers, each rendered as "_<modifier>", in order.
    tags = "".join(f"_{part}" for part in (image_noun, prefix_mod, suffix_mod) if part)
    filename = f"{weights_name}{tags}.{filetype}"
    return Path(results_dir) / dataset_name / model_type / filename
if __name__ == "__main__":
print(IMAGENET_LABELS_TEXT)