-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathddp_test_nerf.py
172 lines (140 loc) · 6.58 KB
/
ddp_test_nerf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import torch
# import torch.nn as nn
import torch.optim
import torch.distributed
# from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing
import numpy as np
import os
# from collections import OrderedDict
# from ddp_model import NerfNet
import time
from data_loader_split import load_data_split
from utils import mse2psnr, colorize_np, to8b
import imageio
from ddp_train_nerf import config_parser, setup_logger, setup, cleanup, render_single_image, create_nerf
import logging
logger = logging.getLogger(__package__)
def _configure_chunk_size(rank, args):
    """Set ``args.N_rand`` and ``args.chunk_size`` from GPU ``rank``'s total memory.

    Mutates ``args`` in place. When the device reports 3 GB or less, neither
    attribute is touched, so whatever values ``args`` already carried are used.
    """
    # Query the device once instead of once per branch.
    total_mem_gb = torch.cuda.get_device_properties(rank).total_memory / 1e9
    if total_mem_gb > 24:
        # NOTE(review): this tier is given *smaller* sizes than the >14 GB tier
        # below; preserved as-is, but confirm it is intentional.
        logger.info('setting batch size according to 48G gpu')
        args.N_rand = 512
        args.chunk_size = 4096
    elif total_mem_gb > 14:
        logger.info('setting batch size according to 24G gpu')
        args.N_rand = 1024
        args.chunk_size = 8192
    elif total_mem_gb > 6:
        logger.info('setting batch size according to 12G gpu')
        args.N_rand = 512
        args.chunk_size = 4096
    elif total_mem_gb > 3:
        logger.info('setting batch size according to 4G gpu')
        args.N_rand = 128
        args.chunk_size = 1024


def _write_png(out_dir, name, im):
    """Quantize ``im`` with ``to8b`` and write it to ``out_dir/name``."""
    imageio.imwrite(os.path.join(out_dir, name), to8b(im))


def _save_view_outputs(out_dir, fname, ray_sampler, out):
    """Log PSNR (when ground truth exists) and write every rendered buffer of one view.

    ``out`` is the result dict of the last rendering level (``ret[-1]``); its
    values are converted with ``.numpy()``. Each buffer is written to
    ``out_dir`` as ``<prefix><fname>``, in a fixed order.
    """
    rgb = out['rgb'].numpy()
    # compute psnr if ground-truth is available
    if ray_sampler.img_path is not None:
        gt_im = ray_sampler.get_img()
        psnr = mse2psnr(np.mean((gt_im - rgb) * (gt_im - rgb)))
        logger.info('{}: psnr={}'.format(fname, psnr))
    _write_png(out_dir, fname, rgb)

    # Colour buffers written without further processing.
    for key, prefix in (('fg_rgb', 'fg_'), ('pure_rgb', 'pure_'), ('bg_rgb', 'bg_')):
        _write_png(out_dir, prefix + fname, out[key].numpy())

    # Depth maps are colour-mapped with 'jet' and get a colour bar appended.
    for key, prefix in (('fg_depth', 'fg_depth_'), ('bg_depth', 'bg_depth_')):
        depth = colorize_np(out[key].numpy(), cmap_name='jet', append_cbar=True)
        _write_png(out_dir, prefix + fname, depth)

    _write_png(out_dir, 'fg_albedo_' + fname, out['fg_albedo'].numpy())

    # (x + 1) / 2 maps [-1, 1] to [0, 1] for display
    # (assumes components lie in [-1, 1] -- TODO confirm).
    normal = out['fg_normal'].numpy()
    _write_png(out_dir, 'fg_normal_' + fname, (normal + 1) / 2)
    _write_png(out_dir, 'viewdir_' + fname, (out['viewdir'].numpy() + 1) / 2)

    # Shading max(0, n . l) with a fixed light direction (1, 1, 1)/sqrt(3).
    # The single-channel result is broadcast back to the normal map's shape.
    light = np.maximum(0, np.sum(normal * np.array([1.0, 1.0, 1.0]) / np.sqrt(3), -1, keepdims=True))
    _write_png(out_dir, 'shaded_' + fname, normal * 0 + light)

    _write_png(out_dir, 'shadow_' + fname, out['shadow'].numpy())

    # Irradiance may exceed 1; rescale by its maximum before quantization.
    irradiance = out['irradiance'].numpy()
    if irradiance.max() > 1:
        irradiance = irradiance / irradiance.max()
    _write_png(out_dir, 'irradiance_' + fname, irradiance)


def _render_split(rank, args, models, start, split):
    """Render every view of ``split``; only rank 0 logs timings and writes files.

    Output goes to ``<basedir>/<expname>/render_<split>_<start:06d>``. A view is
    skipped when its main output file already exists there. Every rank takes
    part in ``render_single_image`` for each view, but only rank 0 saves the
    results.
    """
    out_dir = os.path.join(args.basedir, args.expname,
                           'render_{}_{:06d}'.format(split, start))
    if rank == 0:
        os.makedirs(out_dir, exist_ok=True)

    ###### load data and create ray samplers; each process should do this
    ray_samplers = load_data_split(args.datadir, args.scene, split,
                                   try_load_min_depth=args.load_min_depth,
                                   resolution_level=args.resolution_level)
    n_views = len(ray_samplers)
    for idx, ray_sampler in enumerate(ray_samplers):
        # Name outputs after the source image when there is one.
        fname = '{:06d}.png'.format(idx)
        if ray_sampler.img_path is not None:
            fname = os.path.basename(ray_sampler.img_path)
        if os.path.isfile(os.path.join(out_dir, fname)):
            logger.info('Skipping {}'.format(fname))
            continue

        t_begin = time.time()
        # Optionally spin the environment one full turn across the split.
        rot_angle = idx / n_views * np.pi * 2 if args.rotate_test_env else 0
        ret = render_single_image(rank, args.world_size, models, ray_sampler, args.chunk_size,
                                  start, rot_angle=rot_angle, img_name=fname)
        dt = time.time() - t_begin

        if rank == 0:  # only main process should do this
            logger.info('Rendered {} in {} seconds'.format(fname, dt))
            # only save last level
            _save_view_outputs(out_dir, fname, ray_sampler, ret[-1])
        torch.cuda.empty_cache()


def ddp_test_nerf(rank, args):
    """Render all splits listed in ``args.render_splits`` as DDP worker ``rank``.

    This is the per-process target handed to ``torch.multiprocessing.spawn`` by
    ``test``. It sets up the process group, picks batch/chunk sizes from GPU
    memory (mutating ``args``), builds the models via ``create_nerf``, renders
    each split, and finally tears the process group down.

    Args:
        rank: index of this worker; also used as its CUDA device index.
        args: parsed config; ``render_splits`` is a comma-separated list of split names.
    """
    ###### set up multi-processing
    setup(rank, args.world_size)
    ###### set up logger
    setup_logger()
    ###### decide chunk size according to gpu memory
    _configure_chunk_size(rank, args)
    ###### create network; each process should do this
    start, models = create_nerf(rank, args)

    # Ignore empty entries (e.g. produced by a trailing comma) instead of
    # passing '' to load_data_split as a split name.
    render_splits = [s.strip() for s in args.render_splits.strip().split(',') if s.strip()]
    # start testing
    for split in render_splits:
        _render_split(rank, args, models, start, split)

    # clean up for multi-processing
    cleanup()
def test():
    """Parse the command line and render the requested splits on all workers.

    Spawns ``args.world_size`` processes running ``ddp_test_nerf``. A
    ``world_size`` of -1 means "use every visible CUDA device".

    Raises:
        RuntimeError: if the resolved world size is smaller than 1 (for example,
            no CUDA device is visible). Otherwise ``torch.multiprocessing.spawn``
            would start zero workers and return without rendering anything.
    """
    parser = config_parser()
    args = parser.parse_args()
    logger.info(parser.format_values())

    if args.world_size == -1:
        args.world_size = torch.cuda.device_count()
        logger.info('Using # gpus: {}'.format(args.world_size))
    if args.world_size < 1:
        raise RuntimeError(
            'world_size must be >= 1 (got {}); is a CUDA device visible?'.format(args.world_size))

    torch.multiprocessing.spawn(ddp_test_nerf,
                                args=(args,),
                                nprocs=args.world_size,
                                join=True)
if __name__ == '__main__':
    # Configure logging in the launching process first, then hand off to
    # test(), which parses the config and spawns the rendering workers.
    setup_logger()
    test()