-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinference.py
63 lines (45 loc) · 1.8 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import os.path as osp
import json
import sys
import yaml
from glob import glob
import torch
import cv2
from torch import cuda
from tqdm import tqdm
from base.detect import detect
from base.model import EAST
from utils.argeParser import parse_args
CHECKPOINT_EXTENSIONS = ['.pth', '.ckpt']
LANGUAGE_LIST = ['chinese', 'japanese', 'thai', 'vietnamese']
def do_inference(input_size, batch_size, data_dir, model_dir, pth_path, device, output_dir, output_fname, split='test'):
    """Run EAST text detection over receipt images and dump results as UFO-format JSON.

    Args:
        input_size: Side length passed to ``detect`` for image resizing.
        batch_size: Number of images fed to the model per ``detect`` call.
        data_dir: Root directory containing ``<lang>_receipt/img/<split>/`` folders
            for each language in ``LANGUAGE_LIST``.
        model_dir: Directory holding the checkpoint file.
        pth_path: Checkpoint file name without the ``.pth`` extension.
        device: Torch device the model is moved to.
        output_dir: Directory the result file is written into (created if missing).
        output_fname: Output file name without the ``.csv`` extension.
        split: Dataset split sub-folder to read images from (default ``'test'``).

    Raises:
        FileNotFoundError: If an image file cannot be decoded by OpenCV.
    """
    model = EAST(pretrained=False).to(device)

    # Get path to the checkpoint file.
    ckpt_fpath = osp.join(model_dir, pth_path + '.pth')

    # exist_ok avoids the check-then-act race of exists()+makedirs().
    os.makedirs(output_dir, exist_ok=True)

    model.load_state_dict(torch.load(ckpt_fpath, map_location='cpu'))
    model.eval()

    image_fnames, by_sample_bboxes = [], []

    # Flat list of every image path across all language receipt folders,
    # preserving LANGUAGE_LIST order (avoids the quadratic sum(..., []) idiom).
    image_fpaths = [
        fpath
        for lang in LANGUAGE_LIST
        for fpath in glob(osp.join(data_dir, f'{lang}_receipt/img/{split}/*'))
    ]

    images = []
    print('Inference in progress')
    for image_fpath in tqdm(image_fpaths):
        image = cv2.imread(image_fpath)
        if image is None:
            # cv2.imread returns None on missing/corrupt files; fail loudly
            # with the offending path instead of a cryptic TypeError below.
            raise FileNotFoundError(f'Failed to read image: {image_fpath}')
        image_fnames.append(osp.basename(image_fpath))
        images.append(image[:, :, ::-1])  # BGR -> RGB for the model.
        if len(images) == batch_size:
            by_sample_bboxes.extend(detect(model, images, input_size))
            images = []

    # Flush the final partial batch, if any.
    if len(images):
        by_sample_bboxes.extend(detect(model, images, input_size))

    # Assemble UFO-format result: {images: {fname: {words: {idx: {points: [...]}}}}}.
    ufo_result = dict(images=dict())
    for image_fname, bboxes in zip(image_fnames, by_sample_bboxes):
        words_info = {idx: dict(points=bbox.tolist()) for idx, bbox in enumerate(bboxes)}
        ufo_result['images'][image_fname] = dict(words=words_info)

    # NOTE(review): the file carries a '.csv' extension but contains JSON —
    # presumably a submission-format requirement; kept as-is. Verify downstream.
    with open(osp.join(output_dir, output_fname + '.csv'), 'w') as f:
        json.dump(ufo_result, f, indent=4)
def main(args):
    """Entry point: forward every parsed CLI argument to ``do_inference``."""
    do_inference(**vars(args))
if __name__ == '__main__':
    # Parse CLI arguments for the 'inference' mode, then run.
    args = parse_args('inference')
    main(args)