-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathmain.py
149 lines (118 loc) · 5.03 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Label Convert
Convert from YOLO -> VOC | VOC -> YOLO
"""
import argparse
import multiprocessing
import os
from xml.etree import ElementTree
from PIL import Image
from pascal_voc_writer import Writer
import config
def yolo2voc(txt_file: str) -> None:
    """Convert one YOLO label file to a Pascal VOC XML annotation.

    Reads ``config.LABEL_DIR/<txt_file>``, where each non-blank line holds
    ``class x_center y_center width height`` with box values given as
    fractions of the image size. The pixel size is taken from the matching
    image ``config.IMAGE_DIR/<stem>.jpg``. Boxes are clamped to the image and
    written to ``config.XML_DIR/<stem>.xml``. Columns after the fifth are
    ignored.

    Args:
        txt_file: File name (not a path) of the YOLO label file, relative to
            ``config.LABEL_DIR``.
    """
    stem = os.path.splitext(txt_file)[0]
    image_path = os.path.join(config.IMAGE_DIR, f"{stem}.jpg")
    # Only the size is needed; the context manager releases the file handle,
    # which would otherwise leak once per converted file in each worker.
    with Image.open(image_path) as image:
        w, h = image.size
    # Writer's first argument is the image the annotation describes: it is
    # what gets recorded in the XML's filename/path fields.
    writer = Writer(image_path, w, h)
    with open(os.path.join(config.LABEL_DIR, txt_file)) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank / whitespace-only lines
            label, x_center, y_center, width, height = fields[:5]
            xc, yc = float(x_center), float(y_center)
            half_w, half_h = float(width) / 2, float(height) / 2
            # Clamp the box to [0, 1] before scaling to pixels.
            x_min = int(w * max(xc - half_w, 0))
            x_max = int(w * min(xc + half_w, 1))
            y_min = int(h * max(yc - half_h, 0))
            y_max = int(h * min(yc + half_h, 1))
            writer.addObject(config.names[int(label)], x_min, y_min, x_max, y_max)
    writer.save(os.path.join(config.XML_DIR, f"{stem}.xml"))
def voc2yolo(xml_file: str) -> None:
    """Convert one Pascal VOC annotation to a normalised YOLO label file.

    Reads ``config.XML_DIR/<xml_file>``. Output is written only if at least
    one ``<object>`` has a ``<name>`` listed in ``config.names``; otherwise
    nothing is written. The output file is ``config.LABEL_DIR/<stem>.txt``.
    It holds one ``class_id x_center y_center width height`` line per kept
    object, with values divided by the ``<size>`` width/height and printed
    to six decimals.

    Objects flagged ``<difficult>1</difficult>`` are skipped, and so are
    objects whose class is not in ``config.names``.

    Args:
        xml_file: File name of the annotation, relative to ``config.XML_DIR``.
    """
    # Parsing from the path (binary) lets the XML parser honour the file's
    # declared encoding instead of the platform's default text encoding.
    tree = ElementTree.parse(os.path.join(config.XML_DIR, xml_file))
    objects = tree.findall("object")
    if not any(obj.find("name").text in config.names for obj in objects):
        return

    size = tree.getroot().find("size")
    # float() also accepts sizes written as e.g. "500.0"; values are only
    # used as divisors.
    width = float(size.find("width").text)
    height = float(size.find("height").text)

    stem = os.path.splitext(xml_file)[0]
    with open(os.path.join(config.LABEL_DIR, f"{stem}.txt"), "w") as out_file:
        for obj in objects:
            name = obj.find("name").text
            if name not in config.names:
                # config.names.index() would raise ValueError and abort the
                # whole file (and the surrounding pool.map).
                continue
            difficult = obj.find("difficult")
            # <difficult> is optional in many VOC exports; absent/empty == 0.
            if difficult is not None and difficult.text and int(difficult.text) == 1:
                continue
            xml_box = obj.find("bndbox")
            x_min = float(xml_box.find("xmin").text)
            y_min = float(xml_box.find("ymin").text)
            x_max = float(xml_box.find("xmax").text)
            y_max = float(xml_box.find("ymax").text)
            # The -1 on the centre follows darknet's voc_label.py convention;
            # kept unchanged so previously generated labels stay reproducible.
            box_x = ((x_min + x_max) / 2.0 - 1) / width
            box_y = ((y_min + y_max) / 2.0 - 1) / height
            box_w = (x_max - x_min) / width
            box_h = (y_max - y_min) / height
            cls_id = config.names.index(name)
            out_file.write(
                f"{cls_id} {box_x:.6f} {box_y:.6f} {box_w:.6f} {box_h:.6f}\n"
            )
def voc2yolo_a(xml_file: str) -> None:
    """Convert one Pascal VOC annotation to a YOLO-style file with absolute coordinates.

    Reads ``config.XML_DIR/<xml_file>``. Output is written only if at least
    one ``<object>`` has a ``<name>`` listed in ``config.names``; otherwise
    nothing is written. The output file is ``config.LABEL_DIR/<stem>.txt``.
    It holds one ``class_id x_min y_min x_max y_max`` line per kept object,
    with each corner rounded to the nearest integer pixel.

    Objects flagged ``<difficult>1</difficult>`` are skipped, and so are
    objects whose class is not in ``config.names``.

    Args:
        xml_file: File name of the annotation, relative to ``config.XML_DIR``.
    """
    # Parsing from the path (binary) lets the XML parser honour the file's
    # declared encoding instead of the platform's default text encoding.
    tree = ElementTree.parse(os.path.join(config.XML_DIR, xml_file))
    objects = tree.findall("object")
    if not any(obj.find("name").text in config.names for obj in objects):
        return

    stem = os.path.splitext(xml_file)[0]
    with open(os.path.join(config.LABEL_DIR, f"{stem}.txt"), "w") as out_file:
        for obj in objects:
            name = obj.find("name").text
            if name not in config.names:
                # config.names.index() would raise ValueError and abort the
                # whole file (and the surrounding pool.map).
                continue
            difficult = obj.find("difficult")
            # <difficult> is optional in many VOC exports; absent/empty == 0.
            if difficult is not None and difficult.text and int(difficult.text) == 1:
                continue
            xml_box = obj.find("bndbox")
            corners = [
                round(float(xml_box.find(tag).text))
                for tag in ("xmin", "ymin", "xmax", "ymax")
            ]
            cls_id = config.names.index(name)
            out_file.write(f"{cls_id} " + " ".join(str(v) for v in corners) + "\n")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert annotation files between YOLO and Pascal VOC formats."
    )
    parser.add_argument("--yolo2voc", action="store_true", help="YOLO to VOC")
    parser.add_argument("--voc2yolo", action="store_true", help="VOC to YOLO")
    parser.add_argument("--voc2yolo_a", action="store_true", help="VOC to YOLO absolute")
    args = parser.parse_args()

    if not (args.yolo2voc or args.voc2yolo or args.voc2yolo_a):
        parser.print_help()

    # (enabled, banner, worker, input dir, input suffix, output dir)
    jobs = [
        (args.yolo2voc, "YOLO to VOC", yolo2voc, config.LABEL_DIR, ".txt", config.XML_DIR),
        (args.voc2yolo, "VOC to YOLO", voc2yolo, config.XML_DIR, ".xml", config.LABEL_DIR),
        (args.voc2yolo_a, "VOC to YOLO absolute", voc2yolo_a, config.XML_DIR, ".xml", config.LABEL_DIR),
    ]
    for enabled, banner, worker, src_dir, suffix, dst_dir in jobs:
        if not enabled:
            continue
        print(banner)
        files = [name for name in os.listdir(src_dir) if name.endswith(suffix)]
        # Workers write into dst_dir; make sure it exists before they start.
        os.makedirs(dst_dir, exist_ok=True)
        with multiprocessing.Pool(os.cpu_count()) as pool:
            pool.map(worker, files)
            # Pool.join() may only be called after close(); calling it on a
            # running pool raises ValueError.
            pool.close()
            pool.join()