-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
60 lines (52 loc) · 1.83 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from config import *
from PIL import Image
from mindspore import ops
import mindspore
import mindspore.dataset.vision.py_transforms as py_trans
from mindspore.dataset.transforms.py_transforms import Compose
def path_gen(train=False, val=False, test=False):
    """Build the annotation, question and image locations for one split.

    The split is ``'train'`` when ``train`` is truthy, otherwise ``'val'``
    when ``val`` is truthy, otherwise ``'test'``. The ``test`` flag itself
    is accepted but never consulted.

    Returns:
        tuple: ``(annotation_file, question_file, image_dir)``, built by
        appending ``<split>_align.json`` to ``annotations_path`` and
        ``questions_path``, and ``<split>/`` to ``images_path``. These base
        paths are expected to be provided by ``config``.
    """
    split = 'train' if train else ('val' if val else 'test')

    annotation_file = annotations_path + split + '_align.json'
    question_file = questions_path + split + '_align.json'
    image_dir = images_path + split + '/'

    return annotation_file, question_file, image_dir
def decode(image):
    """Wrap an array-like image as a PIL image via ``Image.fromarray``.

    This is the first step of both pipelines built by ``trans_gen``, so that
    the subsequent py_transforms operate on PIL images. The input is
    presumably a NumPy array produced by the dataset loader (TODO confirm).
    """
    pil_image = Image.fromarray(image)
    return pil_image
def trans_gen(train=False, val=False, test=False):
    """Build the image preprocessing pipeline for a dataset split.

    A truthy ``train`` selects the training pipeline, which adds a random
    horizontal flip. Every other flag combination selects the evaluation
    pipeline. ``val`` and ``test`` are accepted but not consulted.

    Both pipelines convert the raw image with ``decode``, resize it to
    224x224, convert it to a tensor, and normalize it with the commonly used
    ImageNet per-channel mean/std values.

    Returns:
        Compose: the composed py_transforms pipeline.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    # Build the list of transforms, applied in order.
    steps = [decode, py_trans.Resize(size=(224, 224))]
    if train:
        # Augmentation is applied only for training; flip probability is 0.2.
        steps.append(py_trans.RandomHorizontalFlip(0.2))
    steps.extend([
        py_trans.ToTensor(),
        py_trans.Normalize(channel_mean, channel_std),
    ])

    # Compose applies the listed functions to each dataset image.
    return Compose(steps)
def batch_accuracy(predicted, answers):
    """Compute the accuracies for a batch of predictions and answers.

    The predicted answer is the argmax of ``predicted`` over axis 1. Its
    entry in ``answers`` (the count of agreeing answers) is looked up,
    scaled by 0.3 and capped at 1.0.

    Args:
        predicted: answer scores; the argmax is taken over axis 1, so this is
            presumably ``(batch, num_answers)`` (TODO confirm).
        answers: tensor with the same layout as ``predicted``. Presumably it
            holds, per candidate answer, the number of annotators who gave it
            (verify against the dataset code).

    Returns:
        Tensor of per-sample accuracies, ``min(0.3 * count, 1.0)``, with
        axis 1 removed (``(batch,)`` for 2-D inputs).
    """
    arg_max = ops.Argmax(axis=1, output_type=mindspore.int32)
    gather = ops.GatherD()
    minimum = ops.Minimum()
    expand_dims = ops.ExpandDims()
    squeeze = ops.Squeeze(1)

    # GatherD needs an index with the same rank as `answers`, so restore the
    # axis that Argmax reduced away.
    predicted_index = expand_dims(arg_max(predicted), 1)
    # Count for the predicted answer, with the size-1 axis dropped again.
    agreeing = squeeze(gather(answers, 1, predicted_index))
    # 0.3 per agreeing answer, capped at 1. With 10 annotators (as in VQA)
    # this equals the official metric averaged over all 9-annotator subsets:
    # 1, 2, 3, >=4 matches -> 0.3, 0.6, 0.9, 1.0.
    return minimum(agreeing * 0.3, 1.0)