# -*- coding: utf-8 -*-
"""nerf.ipynb
Automatically generated by Colaboratory.
"""
from google.colab import drive
drive.mount('/content/drive')
"""# Loading Blender Data"""
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import imageio.v2 as imageio
import os
import cv2
import json
def load_blender_data(directory, scale_factor=1):
    """
    This function takes in a base directory and reads the transforms_train.json
    file, which contains the [R | t] (extrinsic / camera-to-world / pose)
    matrix for every training image.
    It also takes a downscale factor in (0, 1] to help with memory and speed
    limitations, because the original 800x800 images were crashing even on
    Google Colab.
    """
    with open(os.path.join(directory, "transforms_train.json"), "r") as fp:
        transforms = json.load(fp)
    imgs = []
    extrinsics = []
    for frame in transforms['frames']:
        fname = os.path.join(directory, frame['file_path']) + '.png'
        Rt = frame['transform_matrix']
        curr_img = imageio.imread(fname)
        resized = cv2.resize(curr_img, None, fx=scale_factor, fy=scale_factor,
                             interpolation=cv2.INTER_AREA)  # Not sure if cv2.INTER_LINEAR vs cv2.INTER_AREA affects final model quality, but it probably does
        imgs.append(resized)
        extrinsics.append(np.array(Rt))
    H, W = imgs[0].shape[:2]
    focal_length = 0.5 * W / np.tan(0.5 * float(
        transforms['camera_angle_x']))  # Focal length in pixel units (calculation done in tutorial 8) using the camera_angle_x FOV
    imgs = (np.array(imgs) / 255.).astype(np.float32)  # Keeps pixel values in [0, 1] (normalization)
    extrinsics = torch.from_numpy(np.array(extrinsics).astype(np.float32))
    return imgs, extrinsics, H, W, focal_length
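# A minimal usage sketch (assuming the nerf_synthetic lego scene lives at the
# path used further below; the exact shapes are what the lego training split
# should give at a 0.25 downscale):
#   imgs, extrinsics, H, W, focal_length = load_blender_data("./drive/MyDrive/nerf_synthetic/lego", scale_factor=0.25)
#   imgs.shape        -> (100, 200, 200, 4)       # N x H x W x RGBA
#   extrinsics.shape  -> torch.Size([100, 4, 4])  # one camera-to-world matrix per image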
"""# The actual model/MLP"""
def positional_encoding(v, device, L=10):
    """
    Performs positional encoding from the paper (Section 5.1).
    The encoding applies sin() and cos() to every entry of v with different
    frequencies from 2^0 * pi to 2^{L-1} * pi to increase dimensionality and
    help the network represent high-frequency detail.
    Example: v = [10] and L = 3
    -> [10, sin(2^0 * pi * 10), cos(2^0 * pi * 10), sin(2^1 * pi * 10), cos(2^1 * pi * 10), sin(2^2 * pi * 10), cos(2^2 * pi * 10)]
    """
    # We want to include the original (x, y, z) in the encoded vector to retain the original "information"
    res = [v]
    for i in range(L):
        freq = (2 ** i) * torch.pi * v
        res.append(torch.sin(freq))  # Applying sin and cos to every entry as done in the paper
        res.append(torch.cos(freq))
    res = torch.cat(res, dim=-1).to(device)  # Concatenate all terms along the last dimension
    return res
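# A quick dimensionality sanity check (a sketch, not part of the pipeline):
# a 3-vector encoded with L = 10 should come out as 3 + 6 * 10 = 63 values.
#   positional_encoding(torch.rand(3), torch.device("cpu"), L=10).shape  -> torch.Size([63])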
class NerfModel(nn.Module):
    """
    This network structure was made by referencing Fig. 7 of the NeRF paper,
    but it is much smaller due to computational limits. I decided to reduce it
    to only 4 layers, otherwise training would be extremely slow, but I still
    retained the 256 hidden channels per hidden layer.
    I referenced the TinyNeRF available at https://github.com/bmild/nerf/blob/master/tiny_nerf.ipynb to help
    with this, as I needed to be sure I was implementing the MLP correctly in addition to making sure I
    didn't remove any important features.
    The final layer outputs a tuple of size 4 that contains R, G, B, sigma,
    where sigma is the volume density mentioned throughout the paper.
    You can change the number of hidden_channels to your liking; 128
    might be better for training speed.
    """
    def __init__(self, L_val_for_input=10, L_val_for_d=4, hidden_channels=256):
        super().__init__()
        self.L = L_val_for_input
        self.L_d = L_val_for_d
        self.relu = nn.functional.relu
        # Note: the input layer is 6 * L + 3. This is because positional encoding
        # goes from R -> R^{2L}, so an R^3 vector is taken to an R^{6L} vector,
        # and concatenating the original input gives a total vector size of 6L + 3.
        self.layer1 = nn.Linear(3 + 6 * L_val_for_input, hidden_channels)
        self.layer2 = nn.Linear(hidden_channels, hidden_channels)
        self.layer3 = nn.Linear(hidden_channels, hidden_channels // 2)
        self.layer4 = nn.Linear(hidden_channels // 2, 4)

    def forward(self, x):
        l1 = self.relu(self.layer1(x))
        l2 = self.relu(self.layer2(l1))
        l3 = self.relu(self.layer3(l2))
        l4 = self.layer4(l3)
        return l4
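# A minimal forward-pass sketch (CPU, dummy data) to confirm the shapes line up;
# the 8 rays x 100 samples batch is an arbitrary choice:
#   m = NerfModel(hidden_channels=128)
#   encoded = positional_encoding(torch.rand(8, 100, 3), torch.device("cpu"))  # -> (8, 100, 63)
#   m(encoded).shape  -> torch.Size([8, 100, 4])   # RGB + sigma per sample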
"""# Rendering and Ray Projecting/Marching"""
# Code referenced from line 123 of https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py,
# modified to fit PyTorch
def get_rays(H, W, focal, c2w, device):
    """
    Given an image height, width, focal length, and [R|t] matrix, this
    generates all 2D (i, j) pixel coords using a meshgrid, then computes the
    corresponding ray origin and direction for each pixel coordinate w.r.t.
    the current view/pose.
    """
    # Note: PyTorch doesn't have an equivalent of tf.meshgrid(indexing='xy'),
    # so we need to transpose along the last two dimensions, where the first
    # dimension represents the H direction and the second dimension represents
    # the W direction.
    i, j = torch.meshgrid(torch.arange(W).to(c2w), torch.arange(H).to(c2w))
    i = i.transpose(-1, -2)
    j = j.transpose(-1, -2)
    # This essentially creates a difference vector from pixel (i, j) to the
    # center of the image, then divides by the focal length to account for
    # scale. That tells us what direction we need to move in from the center
    # of our camera to get to our point of concern. Note that we multiply the
    # y coordinate by -1 to account for traditional image indexing ((0, 0) at
    # the top left).
    # I also had the idea of using the inverse of the intrinsic matrix to get a
    # ray direction; even though we don't get a proper depth that way, we don't
    # fully need one since this is only the direction vector of a line.
    directions = torch.stack([(i - W/2) / focal, -(j - H/2) / focal, -torch.ones_like(i)], dim=-1)  # .to(device)
    # Apply c2w to every direction vector to make it relative to the current view/pose
    ray_directions = torch.sum(directions[..., None, :] * c2w[:3, :3], dim=-1).to(device)
    # The translation vector from c2w is our ray origin, because we want the ray
    # to start from our camera's position and go outwards
    ray_origins = c2w[:3, -1].expand(ray_directions.shape).to(device)
    return ray_origins, ray_directions
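# A minimal sketch of what get_rays returns (assuming H = W = 200, as with the
# 0.25 downscale used below):
#   o, d = get_rays(200, 200, focal_length, extrinsics[0], torch.device("cpu"))
#   o.shape, d.shape  -> torch.Size([200, 200, 3]) each; one origin/direction per pixel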
def render_code(scene_colours, ray_origins, t_vals):
    """
    This function renders the scene using the volume density and RGB values,
    along with the ray origins and the t values we plugged into each o + t * d ray.
    The rendering is from Section 4 of the NeRF paper. It involves a ray
    marching algorithm and a volume density function to render the scene.
    It uses the quadrature estimate of the volume rendering integral along with
    the alpha compositing mentioned on the same page.
    """
    sigma = nn.functional.relu(scene_colours[..., 3])  # We get sigma by applying ReLU to the output of the MLP
    rgb = torch.sigmoid(scene_colours[..., :3])  # We put our colors through a sigmoid as mentioned in Fig. 7
    one_e_10 = torch.tensor([1e10], dtype=ray_origins.dtype, device=ray_origins.device)  # Added to prevent multiplication by 0 / floating point error
    # This computes the delta_i values, which are the differences between t_i and t_{i+1}
    delta = torch.cat((t_vals[..., 1:] - t_vals[..., :-1], one_e_10.expand(t_vals[..., :1].shape)), dim=-1)
    e_sig_delt = torch.exp(-sigma * delta)  # Computes the intermediate e^{-sigma_i * delta_i} terms
    alpha_i = 1. - e_sig_delt  # Computes alpha_i
    # Note: PyTorch doesn't have an equivalent of tf.cumprod(exclusive=True),
    # so we need to roll by one element and then set the first element to 1
    cprod = torch.cumprod(e_sig_delt + 1e-10, -1)  # The cumulative product is basically each T_i
    cprod = torch.roll(cprod, 1, -1)
    cprod[..., 0] = 1
    alpha_coeffs = alpha_i * cprod  # Computes the coefficient for each c_i
    rgb_map = (alpha_coeffs[..., None] * rgb).sum(dim=-2)  # Final C^(r) sum
    return rgb_map  # Final RGB estimate for the whole scene
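# For reference, the quadrature estimate this implements (Eq. 3 of the paper) is
#   C_hat(r) = sum_i T_i * (1 - exp(-sigma_i * delta_i)) * c_i,
#   T_i      = exp(-sum_{j<i} sigma_j * delta_j),
# where delta_i = t_{i+1} - t_i; the cumprod/roll above is the exclusive
# product that builds each T_i.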
def get_points_along_line(Model, ray_origins, ray_directions, device, near=2, far=6, n_samples=100):
    """
    Takes in o and d vectors along with near and far bounds. We take evenly
    spaced points from [near, far], which splits the interval into equally
    sized bins. We then generate uniform random noise from [0, 1), scale it to
    the bin width, and add it to the bin boundaries; this is the stratified
    sampling mentioned on page 6.
    Essentially, we take points along every line defined by ray_origins and
    ray_directions, using the jittered evenly spaced points as the t inputs.
    We basically move along each line.
    """
    # Evenly separate [near, far] into n_samples values; this creates our buckets/bins
    t_vals = torch.linspace(near, far, n_samples).to(ray_origins)
    # This determines the shape of the noise we add to t_vals, depending on the number of t_vals
    noise_shape = ray_origins.shape[:-1] + (n_samples,)
    # The updated t_vals with random noise
    t_vals = t_vals + torch.rand(noise_shape).to(ray_origins) * (far - near) / n_samples
    # Computes o + t * d for every o, every d, and every t in t_vals
    points_along_all_lines = ray_origins[..., None, :] + ray_directions[..., None, :] * t_vals[..., :, None]
    # Positionally encode all inputs before feeding them into the model
    # (positional_encoding already moves the result onto the target device)
    encoded_points_along_all_lines = positional_encoding(points_along_all_lines, device)
    predictions = Model(encoded_points_along_all_lines)
    unflattened_shape = points_along_all_lines.shape[:-1] + (4,)
    predictions = torch.reshape(predictions, unflattened_shape)
    # Render the scene with the alpha compositing, using the model's guesses
    # for the RGB values and volume density.
    return render_code(predictions, ray_origins, t_vals)
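# A minimal end-to-end sketch on dummy data (CPU, shapes only, no training);
# the 64x64 image, identity pose, and 70.0 focal length are arbitrary choices:
#   dev = torch.device("cpu")
#   m = NerfModel(hidden_channels=128)
#   o, d = get_rays(64, 64, 70.0, torch.eye(4), dev)
#   get_points_along_line(m, o, d, dev, n_samples=32).shape  -> torch.Size([64, 64, 3])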
"""# Training Loop"""
import gc
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated() / 1024 ** 2)
print(torch.cuda.memory_reserved() / 1024 ** 2)
if torch.cuda.is_available():
    d = "cuda"
else:
    d = "cpu"
c = torch.device("cpu")
device = torch.device(d)
print(device)
model = NerfModel(hidden_channels=128)
model.to(device)
# Note that the data used here is from https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1
# from the official NeRF repo, specifically nerf_synthetic.zip
imgs, extrinsics, H, W, focal_length = load_blender_data("./drive/MyDrive/nerf_synthetic/lego", scale_factor=0.25)
# extrinsics = extrinsics.to(device)
imgs = torch.from_numpy(imgs[..., :3]) #.to(device) # Don't need alpha values, only want RGB
losses = []
xs = []
training_vid_frames = []
# In the paper, they use 5e-4 decayed exponentially to 5e-5, but for now I decided to keep it constant at 5e-3
learning_rate = 5e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
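# If you want the paper's exponential decay instead of a constant rate, a
# minimal sketch (not used in this run; gamma is an arbitrary choice) would be:
#   scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)
#   ... then call scheduler.step() once per iteration inside the loop below.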
for i in range(2501):
    gc.collect()
    torch.cuda.empty_cache()
    idx = np.random.randint(imgs.shape[0])
    pred = get_points_along_line(model, *get_rays(H, W, focal_length, extrinsics[idx], device), device, n_samples=75)
    pred = pred.to(c)
    loss = nn.functional.mse_loss(pred, imgs[idx])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if i % 10 == 0:
        gc.collect()
        torch.cuda.empty_cache()
        print("On Iteration: ", i, "the MSE loss is: ", loss.item())  # Print loss every 10 iterations
        o, d = get_rays(H, W, focal_length, extrinsics[21], device)  # Always render from POV 21
        pred = get_points_along_line(model, o, d, device)
        pred = pred.to(c)
        psnr = -10. * torch.log10(loss)
        losses.append(float(psnr.item()))
        xs.append(i)
        # Plot the model prediction
        plt.figure(figsize=(15, 6))
        plt.subplot(131)
        plt.imshow(pred.detach().cpu().numpy())
        plt.title(f"Model Prediction on Iteration {i}")
        training_vid_frames.append(pred.detach().cpu().numpy())
        # Plot the original fixed view
        plt.subplot(132)
        plt.imshow(imgs[21].detach().cpu().numpy())
        plt.title("Original View")
        # Plot PSNR over iterations
        plt.subplot(133)
        plt.plot(xs, losses)
        plt.title("Peak Signal-to-Noise Ratio (PSNR)")
        plt.show()
gc.collect()
torch.cuda.empty_cache()
f = 'training.mp4'
print(len(training_vid_frames))
imageio.mimwrite(f, training_vid_frames, fps=10, quality=9)
"""# Novel View Synthesis"""
def custom_pose(xangle, yangle, d):
    """
    Given two angles, one about the x-axis and one about the y-axis, and a
    translation distance d, this function creates a new pose matrix to be fed
    into the learned model so we can render from a new POV.
    """
    xangle, yangle = np.radians(xangle), np.radians(yangle)
    # Note that all the matrices are 4x4s, with the bottom-right entry kept at 1
    # to maintain the homogeneous coordinates [X, Y, Z, 1]
    # X rotation matrix
    x_sin = np.sin(xangle)
    x_cos = np.cos(xangle)
    x_rotation_matrix = np.eye(4)
    x_rotation_matrix[1, 1], x_rotation_matrix[2, 2] = x_cos, x_cos
    x_rotation_matrix[1, 2] = -x_sin
    x_rotation_matrix[2, 1] = x_sin
    # Y rotation matrix
    y_sin = np.sin(yangle)
    y_cos = np.cos(yangle)
    y_rotation_matrix = np.eye(4)
    y_rotation_matrix[0, 0], y_rotation_matrix[2, 2] = y_cos, y_cos
    y_rotation_matrix[0, 2] = -y_sin
    y_rotation_matrix[2, 0] = y_sin
    # Z-axis translation matrix
    translation = np.eye(4)
    translation[2, 3] = d
    new_pose = y_rotation_matrix @ x_rotation_matrix @ translation
    # Converting camera coordinates to world coordinates
    transformation_matrix = np.array([
        [-1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]])
    return (transformation_matrix @ new_pose).astype(np.float32)
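# A minimal sketch of how this is used below: build one pose and render it
# (angles in degrees; the -30/60 combination is just an example choice):
#   pose = torch.from_numpy(custom_pose(-30, 60, 4)).to(device)   # 4x4 camera-to-world matrix
#   o, d = get_rays(H, W, focal_length, pose, device)
#   frame = get_points_along_line(model, o, d, device, n_samples=75)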
gc.collect()
torch.cuda.empty_cache()
video_frames = []
for angle in np.linspace(0., 360., 120, endpoint=False):
    gc.collect()
    torch.cuda.empty_cache()
    curr_pose = torch.from_numpy(custom_pose(-angle, 60, 4)).to(device)  # use torch.from_numpy(custom_pose(-30, angle, 4)).to(device) for the first orbiting view
    o, d = get_rays(H, W, focal_length, curr_pose, device)
    pred = get_points_along_line(model, o, d, device, n_samples=75)
    pred = pred.to(c)
    video_frames.append(pred.detach().cpu().numpy())
    del pred
    del o
    del d
gc.collect()
torch.cuda.empty_cache()
f = 'test.mp4'
print(len(video_frames))
imageio.mimwrite(f, video_frames, fps=10, quality=9)