-
Notifications
You must be signed in to change notification settings - Fork 3
/
train_model_labels.py
108 lines (94 loc) · 3.83 KB
/
train_model_labels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import json
import random
import os
import numpy as np
import cv2
from detectron2.utils.visualizer import Visualizer
from detectron2.data.datasets import register_coco_instances
from detectron2.data import MetadataCatalog
from detectron2.data import Metadata
from detectron2.data import DatasetCatalog
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import ColorMode
def visualize_input(metadata, count):
    """Save `count` randomly chosen ground-truth visualizations of a dataset.

    Looks up the dataset registered under ``metadata.name`` in the
    ``DatasetCatalog``, draws the COCO annotations on each sampled image,
    and writes the results to ``images/<dataset>_<image>``.

    Args:
        metadata: detectron2 ``Metadata`` with at least ``name`` set and the
            annotation classes used for drawing.
        count: number of images to sample; clamped to the dataset size so
            ``random.sample`` cannot raise ``ValueError``.
    """
    name = metadata.get("name")
    dataset_dicts = DatasetCatalog.get(name)
    # Create the output directory once, not on every loop iteration.
    os.makedirs('images', exist_ok=True)
    for record in random.sample(dataset_dicts, min(count, len(dataset_dicts))):
        full_path = record['file_name']
        # os.path.basename is portable, unlike splitting on '/'.
        file_name = os.path.basename(full_path)
        img = cv2.imread(full_path)
        # detectron2 visualizers expect RGB; OpenCV loads BGR, hence ::-1.
        visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, scale=1.0)
        vis = visualizer.draw_dataset_dict(record)
        out_path = f'images/{name}_{file_name}'
        print(out_path)
        cv2.imwrite(out_path, vis.get_image()[:, :, ::-1])
def main():
    """Register COCO label datasets, train Mask R-CNN, and predict on samples.

    Reads the path prefix and dataset configuration from ``config/``,
    registers every listed dataset with detectron2, trains (or resumes)
    a Mask R-CNN model, then runs the trained predictor on up to 10
    random test images, writing visualizations to ``images/``.

    Returns:
        list: raw predictor outputs, one dict per sampled test image.
    """
    # Context managers close the config files deterministically;
    # readline() avoids loading the whole file just for the first line.
    with open('config/overall_prefix.txt') as fh:
        prefix = fh.readline().strip()
    with open('config/training_data_ocr.json') as fh:
        conf = json.load(fh)

    metadata = None  # kept in outer scope for the prediction visualizer below
    train = []
    test_images = f'{prefix}inhs_images_smaller/'
    for img_dir, datasets in conf.items():
        ims = f'{prefix}{img_dir}'
        for dataset in datasets:
            json_file = f'datasets/{dataset}'
            name = dataset.split('.')[0]
            train.append(name)
            # Only needed to visualize this particular dataset with
            # `visualize_input` after the loop; any dataset would work.
            if name == '1_labels':
                metadata = Metadata(evaluator_type='coco', image_root=ims,
                                    json_file=json_file,
                                    name=name,
                                    thing_classes=['label'],
                                    thing_dataset_id_to_contiguous_id={1: 0}
                                    )
            register_coco_instances(name, {}, json_file, ims)
    # visualize_input(metadata, 1)

    cfg = get_cfg()
    cfg.merge_from_file("config/mask_rcnn_R_50_FPN_3x.yaml")
    cfg.DATASETS.TRAIN = tuple(train)
    cfg.DATASETS.TEST = ()  # no metrics implemented yet
    cfg.DATALOADER.NUM_WORKERS = 2
    # Initialize from model zoo weights.
    cfg.MODEL.WEIGHTS = "detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl"
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.02
    cfg.SOLVER.MAX_ITER = 50000
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # single class: 'label'
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=True)
    trainer.train()

    # NOTE(review): DefaultTrainer writes model_final.pth by default —
    # confirm that ocr_model_NEW.pth exists or was renamed after training.
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "ocr_model_NEW.pth")
    # Set the testing threshold for this model.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1
    predictor = DefaultPredictor(cfg)

    names = os.listdir(test_images)
    outputs = []
    # Create the output directory once, outside the loop; clamp the sample
    # size so random.sample cannot raise when fewer than 10 images exist.
    os.makedirs('images', exist_ok=True)
    for i, d in enumerate(random.sample(names, min(10, len(names))), start=1):
        im = cv2.imread(test_images + d)
        outputs.append(predictor(im))
        v = Visualizer(im[:, :, ::-1],
                       metadata=metadata,
                       scale=0.8,
                       # remove the colors of unsegmented pixels
                       instance_mode=ColorMode.IMAGE_BW
                       )
        v = v.draw_instance_predictions(outputs[-1]["instances"].to("cpu"))
        print(f'{i}: {d}')
        print(f'images/prediction_{d}')
        cv2.imwrite(f'images/prediction_{d}', v.get_image()[:, :, ::-1])
    return outputs
# Script entry point: run training + prediction when executed directly.
if __name__ == '__main__':
    main()