-
Notifications
You must be signed in to change notification settings - Fork 66
/
inference_img.py
118 lines (108 loc) · 4.45 KB
/
inference_img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import cv2
import torch
import argparse
from torch.nn import functional as F
import warnings
# Silence every Python warning emitted from this point on.
warnings.filterwarnings("ignore")

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")

# Inference only: switch autograd off globally instead of per call.
torch.set_grad_enabled(False)

if _cuda_ok:
    # Let cuDNN autotune convolution algorithms (pays off when input shapes
    # stay the same between calls, as they do within a single run here).
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
# Command-line interface. Two modes:
#   * --ratio R (0 < R): write img0, one frame at position R, img1.
#   * otherwise (--ratio 0, the default): write 2**exp + 1 evenly spaced frames.
# Example: python inference_img.py --img a.png b.png --exp 2
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--img', dest='img', nargs=2, required=True,
                    help='paths of the two input frames; if both end in .exr they are read and written as EXR')
parser.add_argument('--exp', default=4, type=int,
                    help='when --ratio is 0, output 2**exp + 1 frames (2**exp - 1 interpolated between the inputs)')
parser.add_argument('--ratio', default=0, type=float,
                    help='inference ratio between two images with 0 - 1 range; 0 (default) disables it and --exp is used instead')
parser.add_argument('--rthreshold', default=0.02, type=float,
                    help='tolerance for --ratio with pre-3.9 models: bisection stops once the produced frame lies within rthreshold/2 of the requested ratio')
parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles')
parser.add_argument('--model', dest='modelDir', type=str, default='train_log', help='directory with trained model files')
args = parser.parse_args()
# Load a RIFE HD model, trying architectures in this order:
#   model.RIFE_HDv2 (v2.x) -> train_log.RIFE_HDv3 (v3.x) -> model.RIFE_HD (v1.x).
# An attempt can fail at import time (module not present) or inside
# load_model (weights in args.modelDir don't fit that architecture); either
# way the next candidate is tried. `except Exception` rather than a bare
# `except:` so Ctrl-C / SystemExit are not silently treated as a fallback.
# If the last candidate also fails, its exception propagates with the
# earlier failures attached as implicit exception context.
# NOTE(review): the meaning of load_model's second argument (-1) is defined
# in the model modules, not here — presumably a rank/device index; confirm.
try:
    try:
        from model.RIFE_HDv2 import Model
        model = Model()
        model.load_model(args.modelDir, -1)
        print("Loaded v2.x HD model.")
    except Exception:
        from train_log.RIFE_HDv3 import Model
        model = Model()
        model.load_model(args.modelDir, -1)
        print("Loaded v3.x HD model.")
except Exception:
    from model.RIFE_HD import Model
    model = Model()
    model.load_model(args.modelDir, -1)
    print("Loaded v1.x HD model")

# Model classes without a `version` attribute are treated as version 0, so
# the `model.version >= 3.9` checks later in the script take the legacy paths.
if not hasattr(model, 'version'):
    model.version = 0
# NOTE(review): eval()/device() are methods of the project Model class;
# presumably they switch to inference mode and move weights to the active
# device — confirm against the model modules.
model.eval()
model.device()
def _read_frame(path, is_exr):
    """Load one input frame as a (1, C, H, W) tensor on `device`.

    Args:
        path: image file to read.
        is_exr: when True the file is read with IMREAD_ANYDEPTH and its pixel
            values are kept unscaled; otherwise it is read as a 3-channel
            8-bit image, resized to 448x256 (width x height) and scaled to
            [0, 1].

    Raises:
        OSError: if OpenCV cannot decode the file (cv2.imread reports failure
            by returning None instead of raising).
    """
    if is_exr:
        frame = cv2.imread(path, cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
    else:
        # IMREAD_COLOR (not IMREAD_UNCHANGED) always yields a 3-channel 8-bit
        # image: grayscale files can't break the 3-axis transpose below and
        # 16-bit files aren't mis-scaled by the /255. Alpha is dropped.
        # NOTE(review): presumably the model expects exactly 3 channels.
        frame = cv2.imread(path, cv2.IMREAD_COLOR)
    if frame is None:
        raise OSError(f"cv2.imread could not read {path!r} (missing file or unsupported format)")
    if not is_exr:
        # NOTE(review): only non-EXR inputs are resized to a fixed 448x256,
        # so PNG outputs are always 448x256 — confirm this is intended.
        frame = cv2.resize(frame, (448, 256))
    tensor = torch.tensor(frame.transpose(2, 0, 1)).to(device)  # HWC -> CHW
    if not is_exr:
        tensor = tensor / 255.
    return tensor.unsqueeze(0)  # add batch dim -> (1, C, H, W)


# EXR handling is used only when *both* inputs are .exr files.
_both_exr = args.img[0].endswith('.exr') and args.img[1].endswith('.exr')
img0 = _read_frame(args.img[0], _both_exr)
img1 = _read_frame(args.img[1], _both_exr)
if img0.shape != img1.shape:
    # The padding below is computed from img0 alone; mismatched frames (possible
    # for EXR, which is not resized) would otherwise reach the model misaligned.
    raise ValueError(f"Input frames differ in shape: {tuple(img0.shape)} vs {tuple(img1.shape)}")

# Zero-pad height and width on the bottom/right up to the next multiple of 64.
# The unpadded size (h, w) is kept so outputs can be cropped back later.
n, c, h, w = img0.shape
ph = ((h - 1) // 64 + 1) * 64
pw = ((w - 1) // 64 + 1) * 64
padding = (0, pw - w, 0, ph - h)  # (left, right, top, bottom) for the last two dims
img0 = F.pad(img0, padding)
img1 = F.pad(img1, padding)
def _frame_at_ratio(net, img0, img1, ratio, rthreshold, rmaxcycles):
    """Return one frame at position `ratio` between img0 (0.0) and img1 (1.0).

    Models with version >= 3.9 take the ratio directly. Older models only
    produce the midpoint of two frames, so `ratio` is approached by bisection:
    take the midpoint of the current bracketing pair and keep the half that
    contains `ratio`, stopping once the midpoint lies within rthreshold / 2
    of `ratio` or after `rmaxcycles` inferences. A ratio within
    rthreshold / 2 of an endpoint returns that endpoint unchanged.
    """
    if net.version >= 3.9:
        return net.inference(img0, img1, ratio)

    half_tol = rthreshold / 2
    lo_ratio, hi_ratio = 0.0, 1.0
    if ratio <= lo_ratio + half_tol:
        return img0
    if ratio >= hi_ratio - half_tol:
        return img1

    lo_img, hi_img = img0, img1
    # Run at least one inference so a frame is always produced, even when
    # rmaxcycles <= 0 (otherwise `middle` would never be bound).
    for _ in range(max(rmaxcycles, 1)):
        middle = net.inference(lo_img, hi_img)
        middle_ratio = (lo_ratio + hi_ratio) / 2
        if ratio - half_tol <= middle_ratio <= ratio + half_tol:
            break
        if ratio > middle_ratio:
            lo_img, lo_ratio = middle, middle_ratio
        else:
            hi_img, hi_ratio = middle, middle_ratio
    return middle


def _frames_by_exp(net, img0, img1, exp):
    """Return 2**exp + 1 evenly spaced frames from img0 to img1 (inclusive).

    Models with version >= 3.9 are queried once per intermediate ratio
    k / 2**exp. Older models only produce midpoints, so the sequence is
    refined `exp` times, inserting a midpoint between every adjacent pair.
    """
    if net.version >= 3.9:
        steps = 2 ** exp
        middles = [net.inference(img0, img1, k / steps) for k in range(1, steps)]
        return [img0, *middles, img1]

    frames = [img0, img1]
    for _ in range(exp):
        refined = []
        for left, right in zip(frames, frames[1:]):
            refined.append(left)
            refined.append(net.inference(left, right))
        refined.append(img1)
        frames = refined
    return frames


# A non-zero --ratio selects single-frame mode; otherwise --exp controls how
# many intermediate frames are generated.
if args.ratio:
    img_list = [
        img0,
        _frame_at_ratio(model, img0, img1, args.ratio, args.rthreshold, args.rmaxcycles),
        img1,
    ]
else:
    img_list = _frames_by_exp(model, img0, img1, args.exp)
# Write every frame to ./output as img0, img1, ... in sequence order,
# cropping away the padding added before inference ([:h, :w]).
os.makedirs('output', exist_ok=True)
_write_exr = args.img[0].endswith('.exr') and args.img[1].endswith('.exr')
for i, frame in enumerate(img_list):
    pixels = frame[0]  # drop the batch dimension -> (C, H, W)
    if _write_exr:
        out_path = f'output/img{i}.exr'
        # EXR keeps the float pixel values as-is, stored as half floats.
        data = pixels.cpu().numpy().transpose(1, 2, 0)[:h, :w]
        params = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]
    else:
        out_path = f'output/img{i}.png'
        # Back to 8-bit: clamp so values slightly outside [0, 1] don't wrap
        # around in the uint8 cast, and round so k/255*255 landing just below
        # k (float error) isn't truncated to k-1.
        data = (pixels * 255).clamp(0, 255).round().byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
        params = []
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(out_path, data, params):
        raise OSError(f"cv2.imwrite failed to write {out_path!r}")