-
Notifications
You must be signed in to change notification settings - Fork 1
/
imgnet_copy.py
51 lines (40 loc) · 1.56 KB
/
imgnet_copy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from os import listdir, makedirs, mkdir, sep
from os.path import exists, join, dirname, curdir
from shutil import rmtree, copyfile
# Everything this script touches lives under ./tmp/tiny-imagenet-200.
ROOT = join(curdir, "tmp", "tiny-imagenet-200")

# Folders the images are read from, one per split.
SRC_ROOT_VAL = join(ROOT, "val", "images")
SRC_ROOT_TEST = join(ROOT, "test", "images")

# Folders the script writes into ("ds" sits next to "images" in each split).
TGT_ROOT_VAL = join(ROOT, "val", "ds")
TGT_ROOT_TEST = join(ROOT, "test", "ds")

# Remove any output folder that already exists ...
# (filter() is lazy, so each folder is checked right before it is removed.)
for _stale_dir in filter(exists, (TGT_ROOT_VAL, TGT_ROOT_TEST)):
    rmtree(_stale_dir)

# ... then create both output folders empty.
for _fresh_dir in (TGT_ROOT_VAL, TGT_ROOT_TEST):
    mkdir(_fresh_dir)
def get_class_mapping(txt_path):
    """Group the file names listed in a tab-separated annotation file by class.

    Every non-blank line must hold at least two tab-separated fields: the
    file name first and the class label second. Any further fields on the
    line are ignored. Blank lines are skipped.

    Args:
        txt_path: Path of the annotation file to read.

    Returns:
        A dict mapping each class label to the list of file names annotated
        with it, in the order they appear in the file.

    Raises:
        OSError: If the file cannot be opened.
        IndexError: If a non-blank line has fewer than two fields.
    """
    mapping = {}
    with open(txt_path, 'r') as fp:
        for i, line in enumerate(fp):
            # Progress output: one line per input line read.
            print(i, txt_path)
            # Drop the line terminator before splitting; otherwise a line
            # with exactly two columns leaves "\n" glued to the class label.
            line = line.rstrip('\r\n')
            if not line.strip():
                # Blank line (e.g. a trailing empty line): nothing to record.
                continue
            splitted = line.split('\t')
            mapping.setdefault(splitted[1], []).append(splitted[0])
    return mapping
def copy_by_class_mapping(mapping: dict, tgt_path, src_path):
    """Copy files into one sub-directory per class.

    For every ``class label -> file names`` entry in *mapping*, each file is
    copied from ``src_path/<name>`` to ``tgt_path/<class label>/<name>``.
    Class directories (and *tgt_path* itself) are created when missing.

    Args:
        mapping: Dict mapping a class label to the list of file names that
            belong to it.
        tgt_path: Directory under which the per-class directories are created.
        src_path: Directory that holds the files to copy.

    Raises:
        OSError: If a directory cannot be created or a file cannot be copied
            (for example, a listed file is missing from *src_path*).
    """
    c = 0  # running count of files handled, used for progress output
    for cls, datapoints in mapping.items():
        cls_dir = join(tgt_path, cls)
        # exist_ok avoids the separate exists() check and also creates
        # tgt_path itself when it is not there yet.
        makedirs(cls_dir, exist_ok=True)
        for dp in datapoints:
            src_file = join(src_path, dp)
            tgt_file = join(cls_dir, dp)
            c += 1
            if c % 100 == 0:
                # Report every 100th file, showing the real destination
                # (inside the class directory).
                print(c)
                print(src_file, tgt_file)
            copyfile(src_file, tgt_file)
# Validation split: the annotation file is looked up in the parent directory
# of TGT_ROOT_VAL; its class table then drives the copy into TGT_ROOT_VAL.
val_annotations_path = join(dirname(TGT_ROOT_VAL), "val_annotations.txt")
val_mapping = get_class_mapping(val_annotations_path)
copy_by_class_mapping(
    mapping=val_mapping,
    tgt_path=TGT_ROOT_VAL,
    src_path=SRC_ROOT_VAL,
)
# Test split: currently disabled.
# NOTE(review): the disabled copy call below passes val_mapping, not
# tst_mapping - confirm which one is intended before re-enabling.
#tst_mapping = get_class_mapping(join(dirname(TGT_ROOT_TEST), "test_annotations.txt"))
#copy_by_class_mapping(val_mapping, TGT_ROOT_TEST, SRC_ROOT_TEST)