Skip to content

Commit

Permalink
a big cleanup: data fix seed; ngp cubic box
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruilong Li committed Dec 16, 2023
1 parent 6ab97ae commit e6647a0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
9 changes: 6 additions & 3 deletions examples/datasets/nerf_360_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ def __init__(
)
self.K = torch.tensor(self.K).to(torch.float32).to(device)
self.height, self.width = self.images.shape[1:3]
self.g = torch.Generator(device=device)
self.g.manual_seed(42)

def __len__(self):
return len(self.images)
Expand All @@ -274,7 +276,7 @@ def preprocess(self, data):

if self.training:
if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device)
color_bkgd = torch.rand(3, device=self.images.device, generator=self.g)
elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black":
Expand Down Expand Up @@ -304,14 +306,15 @@ def fetch_data(self, index):
len(self.images),
size=(num_rays,),
device=self.images.device,
generator=self.g,
)
else:
image_id = [index] * num_rays
x = torch.randint(
0, self.width, size=(num_rays,), device=self.images.device
0, self.width, size=(num_rays,), device=self.images.device, generator=self.g
)
y = torch.randint(
0, self.height, size=(num_rays,), device=self.images.device
0, self.height, size=(num_rays,), device=self.images.device, generator=self.g
)
else:
image_id = [index]
Expand Down
9 changes: 6 additions & 3 deletions examples/datasets/nerf_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def __init__(
self.camtoworlds = self.camtoworlds.to(device)
self.K = self.K.to(device)
assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
self.g = torch.Generator(device=device)
self.g.manual_seed(42)

def __len__(self):
return len(self.images)
Expand All @@ -141,7 +143,7 @@ def preprocess(self, data):

if self.training:
if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device)
color_bkgd = torch.rand(3, device=self.images.device, generator=self.g)
elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black":
Expand Down Expand Up @@ -172,14 +174,15 @@ def fetch_data(self, index):
len(self.images),
size=(num_rays,),
device=self.images.device,
generator=self.g,
)
else:
image_id = [index] * num_rays
x = torch.randint(
0, self.WIDTH, size=(num_rays,), device=self.images.device
0, self.WIDTH, size=(num_rays,), device=self.images.device, generator=self.g
)
y = torch.randint(
0, self.HEIGHT, size=(num_rays,), device=self.images.device
0, self.HEIGHT, size=(num_rays,), device=self.images.device, generator=self.g
)
else:
image_id = [index]
Expand Down
7 changes: 7 additions & 0 deletions examples/radiance_fields/ngp.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ def __init__(
super().__init__()
if not isinstance(aabb, torch.Tensor):
aabb = torch.tensor(aabb, dtype=torch.float32)

# Turns out rectangle aabb will leads to uneven collision so bad performance.
# We enforce a cube aabb here.
center = (aabb[..., :num_dim] + aabb[..., num_dim:]) / 2.0
size = (aabb[..., num_dim:] - aabb[..., :num_dim]).max()
aabb = torch.cat([center - size / 2.0, center + size / 2.0], dim=-1)

self.register_buffer("aabb", aabb)
self.num_dim = num_dim
self.use_viewdirs = use_viewdirs
Expand Down

0 comments on commit e6647a0

Please sign in to comment.