Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gaussian splatting support for Aria #2785

Merged
merged 10 commits into from
Jan 19, 2024
2 changes: 1 addition & 1 deletion nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def fisheye624_unproject_helper(uv, params, max_iters: int = 5):
function so this solves an optimization problem using Newton's method to get
the inverse.
Inputs:
uv: BxNx3 tensor of 2D pixels to be projected
uv: BxNx2 tensor of 2D pixels to be unprojected
params: Bx16 tensor of Fisheye624 parameters formatted like this:
[f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
or Bx15 tensor of Fisheye624 parameters formatted like this:
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/cameras/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def _compute_rays_for_vr180(

assert distortion_params is not None
masked_coords = pcoord_stack[coord_mask, :]
# The fisheye unprojection does not rely on planar/pinhold unprojection, thus the method needs
# The fisheye unprojection does not rely on planar/pinhole unprojection, thus the method needs
# to access the focal length and principal points directly.
camera_params = torch.cat(
[
Expand Down
276 changes: 166 additions & 110 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torch.nn import Parameter
from tqdm import tqdm

from nerfstudio.cameras.camera_utils import fisheye624_project, fisheye624_unproject_helper
from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion
from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig, TDataset
Expand Down Expand Up @@ -135,70 +136,20 @@ def cache_images(self, cache_images_option):
continue
distortion_params = camera.distortion_params.numpy()
image = data["image"].numpy()
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
distortion_params = np.array(
[
distortion_params[0],
distortion_params[1],
distortion_params[4],
distortion_params[5],
distortion_params[2],
distortion_params[3],
0,
0,
]
)
if np.any(distortion_params):
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
else:
newK = K
roi = 0, 0, image.shape[1], image.shape[0]
# crop the image and update the intrinsics accordingly
x, y, w, h = roi
image = image[y : y + h, x : x + w]
if "depth_image" in data:
data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
# update the width, height
self.train_dataset.cameras.width[i] = w
self.train_dataset.cameras.height[i] = h
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
if np.any(distortion_params):
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
mask = mask[y : y + h, x : x + w]
data["mask"] = torch.from_numpy(mask).bool()
K = newK

elif camera.camera_type.item() == CameraType.FISHEYE.value:
distortion_params = np.array(
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
)
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
)
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
)
# and then remap:
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
data["mask"] = torch.from_numpy(mask).bool()
K = newK
else:
raise NotImplementedError("Only perspective and fisheye cameras are supported")

K, image, mask = _undistort_image(camera, distortion_params, data, image, K)
data["image"] = torch.from_numpy(image)
if mask is not None:
data["mask"] = mask

cached_train.append(data)

self.train_dataset.cameras.fx[i] = float(K[0, 0])
self.train_dataset.cameras.fy[i] = float(K[1, 1])
self.train_dataset.cameras.cx[i] = float(K[0, 2])
self.train_dataset.cameras.cy[i] = float(K[1, 2])
self.train_dataset.cameras.width[i] = image.shape[1]
self.train_dataset.cameras.height[i] = image.shape[0]

CONSOLE.log("Caching / undistorting eval images")
for i in tqdm(range(len(self.eval_dataset)), leave=False):
Expand All @@ -210,68 +161,20 @@ def cache_images(self, cache_images_option):
continue
distortion_params = camera.distortion_params.numpy()
image = data["image"].numpy()
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
distortion_params = np.array(
[
distortion_params[0],
distortion_params[1],
distortion_params[4],
distortion_params[5],
distortion_params[2],
distortion_params[3],
0,
0,
]
)
if np.any(distortion_params):
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
else:
newK = K
roi = 0, 0, image.shape[1], image.shape[0]
# crop the image and update the intrinsics accordingly
x, y, w, h = roi
image = image[y : y + h, x : x + w]
# update the width, height
self.eval_dataset.cameras.width[i] = w
self.eval_dataset.cameras.height[i] = h
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
if np.any(distortion_params):
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
mask = mask[y : y + h, x : x + w]
data["mask"] = torch.from_numpy(mask).bool()
K = newK

elif camera.camera_type.item() == CameraType.FISHEYE.value:
distortion_params = np.array(
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
)
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
)
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
)
# and then remap:
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
data["mask"] = torch.from_numpy(mask).bool()
K = newK
else:
raise NotImplementedError("Only perspective and fisheye cameras are supported")

K, image, mask = _undistort_image(camera, distortion_params, data, image, K)
data["image"] = torch.from_numpy(image)
if mask is not None:
data["mask"] = mask

cached_eval.append(data)

self.eval_dataset.cameras.fx[i] = float(K[0, 0])
self.eval_dataset.cameras.fy[i] = float(K[1, 1])
self.eval_dataset.cameras.cx[i] = float(K[0, 2])
self.eval_dataset.cameras.cy[i] = float(K[1, 2])
self.eval_dataset.cameras.width[i] = image.shape[1]
self.eval_dataset.cameras.height[i] = image.shape[0]

if cache_images_option == "gpu":
for cache in cached_train:
Expand Down Expand Up @@ -416,3 +319,156 @@ def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
assert len(self.eval_dataset.cameras.shape) == 1, "Assumes single batch dimension"
camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device)
return camera, data


def _undistort_image(
    camera: Cameras, distortion_params: np.ndarray, data: dict, image: np.ndarray, K: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Optional[torch.Tensor]]:
    """Undistort a single image (and its mask) for the given camera model.

    Three camera types are handled: ``PERSPECTIVE`` (OpenCV pinhole undistortion
    followed by a crop to the valid region), ``FISHEYE`` (OpenCV fisheye
    rectification) and ``FISHEYE624`` (a remap built by projecting the rays of a
    synthetic pinhole camera back into the distorted image).

    Args:
        camera: Camera the image was taken with. Its ``camera_type`` selects the
            branch; the FISHEYE624 branch additionally reads ``fx``, ``fy``,
            ``cx``, ``cy``, ``width``, ``height`` and ``metadata``.
        distortion_params: Distortion coefficients of ``camera`` as a numpy array.
        data: Per-image dictionary. ``data["mask"]`` is read when present. In the
            PERSPECTIVE branch ``data["depth_image"]`` is cropped **in place**
            when present; no other key is modified here.
        image: Distorted image as a numpy array (rows first).
        K: 3x3 intrinsics matrix of the distorted image.

    Returns:
        A ``(K, image, mask)`` tuple: the intrinsics of the undistorted image, the
        undistorted image, and the undistorted boolean mask (``None`` when no
        mask was supplied and the branch does not synthesize one).

    Raises:
        NotImplementedError: If the camera type is not one of the three above.
        AssertionError: In the FISHEYE624 branch, if ``data`` already contains a
            mask, if the parameter vector is not 16 long, or if
            ``camera.metadata["fisheye_crop_radius"]`` is missing or not a float.
    """
    mask = None
    camera_type = camera.camera_type.item()

    if camera_type == CameraType.PERSPECTIVE.value:
        # Re-order the coefficients and pad to the 8-element vector handed to OpenCV.
        # NOTE(review): presumably maps [k1, k2, k3, k4, p1, p2] to OpenCV's
        # (k1, k2, p1, p2, k3, k4, k5, k6) layout -- confirm against the dataparser.
        distortion_params = np.array(
            [
                distortion_params[0],
                distortion_params[1],
                distortion_params[4],
                distortion_params[5],
                distortion_params[2],
                distortion_params[3],
                0,
                0,
            ]
        )
        # Evaluated once; the image and the mask must take the same path.
        has_distortion = bool(np.any(distortion_params))
        if has_distortion:
            newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
            image = cv2.undistort(image, K, distortion_params, None, newK)  # type: ignore
        else:
            newK = K
            roi = 0, 0, image.shape[1], image.shape[0]
        # crop the image and update the intrinsics accordingly
        x, y, w, h = roi
        image = image[y : y + h, x : x + w]
        if x or y:
            # The crop moves the image origin to (x, y), so the principal point has
            # to move with it. Work on a copy so the caller's matrix is untouched.
            newK = np.array(newK, copy=True)
            newK[0, 2] -= x
            newK[1, 2] -= y
        if "depth_image" in data:
            data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
        if "mask" in data:
            mask = data["mask"].numpy()
            mask = mask.astype(np.uint8) * 255
            if has_distortion:
                mask = cv2.undistort(mask, K, distortion_params, None, newK)  # type: ignore
            mask = mask[y : y + h, x : x + w]
            mask = torch.from_numpy(mask).bool()
        K = newK

    elif camera_type == CameraType.FISHEYE.value:
        # Only the first four coefficients are passed to the OpenCV fisheye model.
        distortion_params = np.array(
            [distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
        )
        newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
            K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
        )
        map1, map2 = cv2.fisheye.initUndistortRectifyMap(
            K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
        )
        # and then remap:
        image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)
        if "mask" in data:
            mask = data["mask"].numpy()
            mask = mask.astype(np.uint8) * 255
            mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
            mask = torch.from_numpy(mask).bool()
        K = newK

    elif camera_type == CameraType.FISHEYE624.value:
        # Parameter vector: [fx, fy, cx, cy, <distortion coefficients>].
        fisheye624_params = torch.cat(
            [camera.fx, camera.fy, camera.cx, camera.cy, torch.from_numpy(distortion_params)], dim=0
        )
        assert fisheye624_params.shape == (16,)
        # This branch synthesizes its own mask from the crop radius, so a
        # user-provided mask is rejected.
        assert (
            "mask" not in data
            and camera.metadata is not None
            and "fisheye_crop_radius" in camera.metadata
            and isinstance(camera.metadata["fisheye_crop_radius"], float)
        )
        fisheye_crop_radius = camera.metadata["fisheye_crop_radius"]

        # Approximate the FOV of the unmasked region of the camera.
        upper, lower, left, right = fisheye624_unproject_helper(
            torch.tensor(
                [
                    [camera.cx, camera.cy - fisheye_crop_radius],
                    [camera.cx, camera.cy + fisheye_crop_radius],
                    [camera.cx - fisheye_crop_radius, camera.cy],
                    [camera.cx + fisheye_crop_radius, camera.cy],
                ],
                dtype=torch.float32,
            )[None],
            params=fisheye624_params[None],
        ).squeeze(dim=0)
        # Larger of the vertical and horizontal angles between opposite rays.
        fov_radians = torch.max(
            torch.acos(torch.sum(upper * lower / torch.linalg.norm(upper) / torch.linalg.norm(lower))),
            torch.acos(torch.sum(left * right / torch.linalg.norm(left) / torch.linalg.norm(right))),
        )

        # Heuristics to determine parameters of an undistorted image.
        undist_h = int(fisheye_crop_radius * 2)
        undist_w = int(fisheye_crop_radius * 2)
        undistort_focal = undist_h / (2 * torch.tan(fov_radians / 2.0))
        undist_K = torch.eye(3)
        undist_K[0, 0] = undistort_focal  # fx
        undist_K[1, 1] = undistort_focal  # fy
        undist_K[0, 2] = (undist_w - 1) / 2.0  # cx; for a 1x1 image, center should be at (0, 0).
        undist_K[1, 2] = (undist_h - 1) / 2.0  # cy

        # Undistorted 2D coordinates -> rays -> reproject to distorted UV coordinates.
        # `indexing="ij"` pins the layout the (undist_w, undist_h, 3) assert below
        # expects and silences torch's "indexing will be required" warning.
        undist_uv_homog = torch.stack(
            [
                *torch.meshgrid(
                    torch.arange(undist_w, dtype=torch.float32),
                    torch.arange(undist_h, dtype=torch.float32),
                    indexing="ij",
                ),
                torch.ones((undist_w, undist_h), dtype=torch.float32),
            ],
            dim=-1,
        )
        assert undist_uv_homog.shape == (undist_w, undist_h, 3)
        dist_uv = (
            fisheye624_project(
                xyz=(
                    torch.einsum(
                        "ij,bj->bi",
                        torch.linalg.inv(undist_K),
                        undist_uv_homog.reshape((undist_w * undist_h, 3)),
                    )[None]
                ),
                params=fisheye624_params[None, :],
            )
            .reshape((undist_w, undist_h, 2))
            .numpy()
        )
        # NOTE(review): the grid is laid out (w, h) and the v component is used as
        # map1 / the u component as map2 -- confirm this transposition is intended.
        map1 = dist_uv[..., 1]
        map2 = dist_uv[..., 0]

        # Use correspondence to undistort image.
        image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)

        # Compute undistorted mask as well.
        # Start from a disc of radius `fisheye_crop_radius` around the centre of the
        # distorted image, then push it through the same remap as the image.
        dist_h = camera.height.item()
        dist_w = camera.width.item()
        mask = np.mgrid[:dist_h, :dist_w]
        mask[0, ...] -= dist_h // 2
        mask[1, ...] -= dist_w // 2
        mask = np.linalg.norm(mask, axis=0) < fisheye_crop_radius
        mask = torch.from_numpy(
            cv2.remap(
                mask.astype(np.uint8) * 255,
                map1,
                map2,
                interpolation=cv2.INTER_LINEAR,
                borderMode=cv2.BORDER_CONSTANT,
                borderValue=0,
            )
            / 255.0
        ).bool()
        K = undist_K.numpy()

    else:
        raise NotImplementedError("Only perspective and fisheye cameras are supported")

    return K, image, mask
Loading
Loading