Skip to content

Commit

Permalink
Add orthophoto rendering support. (#2648)
Browse files Browse the repository at this point in the history
* add orthophoto CameraType.

* test case of orthophoto camera.

* Fix the rotation errors caused by the usage of left-handed system

* Optimized ortho-cam,  it can be used with other camera models now.

* fix TypeError of coords in `test_multi_camera_type()`

---------

Co-authored-by: Alexander Kristoffersen <[email protected]>
  • Loading branch information
LeaFendd and akristoffersen authored Jan 8, 2024
1 parent cc61caf commit 1e23781
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
17 changes: 17 additions & 0 deletions nerfstudio/cameras/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class CameraType(Enum):
OMNIDIRECTIONALSTEREO_R = auto()
VR180_L = auto()
VR180_R = auto()
ORTHOPHOTO = auto()
FISHEYE624 = auto()


Expand All @@ -63,6 +64,7 @@ class CameraType(Enum):
"OMNIDIRECTIONALSTEREO_R": CameraType.OMNIDIRECTIONALSTEREO_R,
"VR180_L": CameraType.VR180_L,
"VR180_R": CameraType.VR180_R,
"ORTHOPHOTO": CameraType.ORTHOPHOTO,
"FISHEYE624": CameraType.FISHEYE624,
}

Expand Down Expand Up @@ -834,6 +836,21 @@ def _compute_rays_for_vr180(
# assign final camera origins
c2w[..., :3, 3] = vr180_origins

elif CameraType.ORTHOPHOTO.value in cam_types:
# here the focal length determines the imaging area: the smaller fx, the bigger the imaging area.
mask = (self.camera_type[true_indices] == CameraType.ORTHOPHOTO.value).squeeze(-1)
dir_mask = torch.stack([mask, mask, mask], dim=0)
# in an orthophoto cam, all rays share the same direction, dir = R @ [0, 0, -1]; R will be applied afterwards.
directions_stack[dir_mask] = torch.tensor(
[0.0, 0.0, -1.0], dtype=directions_stack.dtype, device=directions_stack.device
)
# in orthophoto cam, ray origins are grids, then transform grids with c2w, c2w @ P.
grids = coord[mask]
grids[..., 1] *= -1.0 # convert to left-hand system.
grids = torch.cat([grids, torch.zeros_like(grids[..., -1:]), torch.ones_like(grids[..., -1:])], dim=-1)
grids = torch.matmul(c2w[mask], grids[..., None]).squeeze(-1)
c2w[..., :3, 3][mask] = grids

elif CameraType.FISHEYE624.value in cam_types:
mask = (self.camera_type[true_indices] == CameraType.FISHEYE624.value).squeeze(-1) # (num_rays)
coord_mask = torch.stack([mask, mask, mask], dim=0)
Expand Down
66 changes: 66 additions & 0 deletions tests/cameras/test_cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,70 @@ def test_pinhole_camera():
pinhole_camera.generate_rays(camera_indices=0, coords=coords)


def test_orthophoto_camera():
    """Check ORTHOPHOTO rays against a PERSPECTIVE camera with the same pose and intrinsics."""
    # Camera pose: a rotation block and a translation column.
    rotation = torch.Tensor(
        [
            [0.5, -0.14644661, 0.85355339],
            [0.5, 0.85355339, -0.14644661],
            [-0.70710678, 0.5, 0.5],
        ]
    ).unsqueeze(0)
    translation = torch.Tensor([[0.5, 0, -0.5]])
    # Assemble the (1, 3, 4) camera-to-world matrix as [R | T].
    c2w = torch.cat([rotation, translation.unsqueeze(-1)], dim=-1)

    # Both cameras share these intrinsics; only the camera type differs.
    intrinsics = dict(cx=1.5, cy=1.5, fx=1.0, fy=1.0)

    ortho_cam = Cameras(camera_to_worlds=c2w, camera_type=CameraType.ORTHOPHOTO, **intrinsics)
    ortho_rays = ortho_cam.generate_rays(camera_indices=0)

    # The PERSPECTIVE camera serves as the reference to validate ORTHOPHOTO.
    pinhole_cam = Cameras(camera_to_worlds=c2w, camera_type=CameraType.PERSPECTIVE, **intrinsics)
    pinhole_rays = pinhole_cam.generate_rays(camera_indices=0)

    assert ortho_rays.shape == pinhole_rays.shape

    # Every orthophoto ray direction should equal the pinhole ray at pixel [1, 1],
    # which this test treats as the center ray.
    center_direction = pinhole_rays.directions[1, 1]
    expected_directions = center_direction.broadcast_to(ortho_rays.directions.shape)
    assert torch.allclose(ortho_rays.directions, expected_directions)

    # Orthophoto origins form a grid whose mean should equal the pinhole origin.
    mean_origin = ortho_rays.origins.mean(dim=(0, 1))
    assert torch.allclose(mean_origin, pinhole_rays.origins[1, 1])


def test_multi_camera_type():
    """Test ray generation for a `Cameras` batch that mixes two camera types.

    One PERSPECTIVE and one ORTHOPHOTO camera are stored in the same `Cameras`
    object, and `generate_rays` is called with a single camera index and with a
    tensor of camera indices, both with and without explicit `coords`. Only the
    shapes of the returned rays are checked.
    """
    num_cams = [2]
    c2w = torch.eye(4)[None, :3, :].broadcast_to(*num_cams, 3, 4)
    # `torch.tensor` with explicit float / int literals replaces the legacy
    # `torch.Tensor(...)` constructor; the resulting dtypes are unchanged
    # (float32 for intrinsics, int64 for the image size).
    cx = torch.tensor([20.0]).broadcast_to(*num_cams, 1)
    cy = torch.tensor([10.0]).broadcast_to(*num_cams, 1)
    fx = torch.tensor([10.0]).broadcast_to(*num_cams, 1)
    fy = torch.tensor([10.0]).broadcast_to(*num_cams, 1)
    h = torch.tensor([40]).broadcast_to(*num_cams, 1)
    w = torch.tensor([20]).broadcast_to(*num_cams, 1)
    camera_type = [CameraType.PERSPECTIVE, CameraType.ORTHOPHOTO]
    multitype_cameras = Cameras(c2w, fx, fy, cx, cy, w, h, camera_type=camera_type)

    # A local, seeded generator keeps the random inputs reproducible without
    # touching the global RNG state used by other tests.
    rng = torch.Generator().manual_seed(0)

    # test `generate_rays`, 1 cam: one ray per pixel, shaped (h, w).
    ray0 = multitype_cameras.generate_rays(camera_indices=0)
    assert ray0.shape == torch.Size([40, 20])

    # test `generate_rays`, multiple cams.
    num_rays = [30, 30]
    # (*num_rays, num_cameras_batch_dims)
    camera_indices = torch.randint(0, 2, [*num_rays, len(num_cams)], generator=rng)
    ray1 = multitype_cameras.generate_rays(camera_indices=camera_indices)
    assert ray1.shape == torch.Size([40, 20, *num_rays])

    # test `_generate_rays_from_coords`, 1 cam.
    coords = torch.randint(0, 10, [*num_rays, 2], generator=rng).float()  # (*num_rays, 2)
    ray2 = multitype_cameras.generate_rays(camera_indices=0, coords=coords)
    assert ray2.shape == torch.Size([*num_rays])

    # test `_generate_rays_from_coords`, multiple cams.
    ray3 = multitype_cameras.generate_rays(camera_indices=camera_indices, coords=coords)
    assert ray3.shape == torch.Size([*num_rays])


def test_equirectangular_camera():
"""Test that the equirectangular camera model works."""
height = 100 # width is twice the height
Expand Down Expand Up @@ -311,3 +375,5 @@ def _check_cam_shapes(cam: Cameras, _batch_size):
test_pinhole_camera()
test_equirectangular_camera()
test_camera_as_tensordataclass()
test_orthophoto_camera()
test_multi_camera_type()

0 comments on commit 1e23781

Please sign in to comment.