-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
37 lines (29 loc) · 1016 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# coding:utf-8
from __future__ import print_function

import sys

import torch
import torch.nn.functional as F
import torchvision.transforms as tfs
from PIL import Image, ImageDraw

from config import opt
from data import PriorBox
from model import SSD
# Run SSD inference on a single image and display the predicted boxes.
# Usage: python test.py [IMAGE_PATH]   (defaults to DEFAULT_IMAGE_PATH below)

# Build the network, load the trained weights and switch to evaluation mode.
net = SSD(opt)
# The model is never moved to a GPU in this script, so load the checkpoint
# onto the CPU; this also lets a GPU-saved checkpoint load on a CPU-only box.
checkpoint = torch.load(opt.ckpt_path, map_location='cpu')
net.load_state_dict(checkpoint['net'])
net.eval()

# Load the test image; an optional command-line argument overrides the path.
DEFAULT_IMAGE_PATH = '/home/j/MYSSD/pytorch-ssd-master/image/img1.jpg'
image_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_IMAGE_PATH
# convert('RGB') so grayscale / RGBA / palette images produce the three
# channels that the 3-element Normalize below expects.
img = Image.open(image_path).convert('RGB')
# Resize a copy to 300x300 for the network; `img` keeps its original size
# so the boxes can be drawn at full resolution.
img1 = img.resize((300, 300))
transform = tfs.Compose([
    tfs.ToTensor(),
    # ImageNet per-channel mean / std.
    tfs.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
img1 = transform(img1)

# Forward pass on a batch of one ([None] adds the batch axis). Inference
# only, so no autograd graph is needed; this also keeps the in-place box
# scaling below off grad-tracking tensors.
with torch.no_grad():
    loc, conf = net(img1[None, :, :, :])

# Decode the raw network outputs into boxes, labels and scores.
prior_box = PriorBox(opt)
# squeeze(0) removes the batch dimension. The softmax has to normalize each
# prior's class scores, i.e. over the class axis rather than across priors;
# conf is presumably (num_priors, num_classes) here (TODO confirm against
# SSD.forward), so the last axis is used.
boxes, labels, scores = prior_box.convert_result(
    loc.squeeze(0), F.softmax(conf.squeeze(0), dim=-1))

# Draw the detections on the original-resolution image. Box coordinates
# appear to be normalized (x1, y1, x2, y2) given the width/height scaling
# below -- assumption, confirm against PriorBox.convert_result.
draw = ImageDraw.Draw(img)
for box in boxes:
    box[::2] *= img.width
    box[1::2] *= img.height
    # Hand PIL plain Python floats instead of 0-d tensors.
    draw.rectangle(box.tolist(), outline='red')
img.show()