Skip to content

Commit

Permalink
Heuristics for computed undistorted parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jan 19, 2024
2 parents f5057d6 + f55bd04 commit 88f9f04
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 18 deletions.
58 changes: 41 additions & 17 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@
import cv2
import numpy as np
import torch
from torch.nn import Parameter
from tqdm import tqdm

from nerfstudio.cameras.camera_utils import fisheye624_project
from nerfstudio.cameras.camera_utils import fisheye624_project, fisheye624_unproject_helper
from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion
from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig, TDataset
Expand All @@ -43,6 +40,8 @@
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.utils.misc import get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE
from torch.nn import Parameter
from tqdm import tqdm


@dataclass
Expand Down Expand Up @@ -377,18 +376,48 @@ def _undistort_image(
mask = torch.from_numpy(mask).bool()
K = newK
elif camera.camera_type.item() == CameraType.FISHEYE624.value:
fisheye624_params = np.concatenate([camera.fx, camera.fy, camera.cx, camera.cy, distortion_params], axis=0)
fisheye624_params = torch.cat(
[camera.fx, camera.fy, camera.cx, camera.cy, torch.from_numpy(distortion_params)], dim=0
)
assert fisheye624_params.shape == (16,)
assert (
"mask" not in data
and camera.metadata is not None
and "fisheye_crop_radius" in camera.metadata
and isinstance(camera.metadata["fisheye_crop_radius"], float)
)
fisheye_crop_radius = camera.metadata["fisheye_crop_radius"]

# Desired parameters of the undistorted image.
import warnings

warnings.warn("Fisheye624 support in the full images datamanager currently assumes data from Project Aria.")
undist_h = 500
undist_w = 500

# Approximate the FOV of the unmasked region of the camera.
upper, lower, left, right = fisheye624_unproject_helper(
torch.tensor(
[
[camera.cx, camera.cy - fisheye_crop_radius],
[camera.cx, camera.cy + fisheye_crop_radius],
[camera.cx - fisheye_crop_radius, camera.cy],
[camera.cx + fisheye_crop_radius, camera.cy],
],
dtype=torch.float32,
)[None],
params=fisheye624_params[None],
).squeeze(dim=0)
fov_radians = torch.max(
torch.acos(torch.sum(upper * lower / torch.linalg.norm(upper) / torch.linalg.norm(lower))),
torch.acos(torch.sum(left * right / torch.linalg.norm(left) / torch.linalg.norm(right))),
)

# Heuristic to determine the parameters of the undistorted image.
undist_h = int(fisheye_crop_radius * 2)
undist_w = int(fisheye_crop_radius * 2)
undistort_focal = undist_h / (2 * torch.tan(fov_radians / 2.0))
undist_K = torch.eye(3)
undist_K[0, 0] = 150.0 # fx
undist_K[1, 1] = 150.0 # fy
undist_K[0, 0] = undistort_focal # fx
undist_K[1, 1] = undistort_focal # fy
undist_K[0, 2] = (undist_w - 1) / 2.0 # cx; for a 1x1 image, center should be at (0, 0).
undist_K[1, 2] = (undist_h - 1) / 2.0 # cy

Expand All @@ -413,7 +442,7 @@ def _undistort_image(
undistort_uv_homog.reshape((undist_w * undist_h, 3)),
)[None]
),
params=torch.from_numpy(fisheye624_params[None, :]),
params=fisheye624_params[None, :],
)
.reshape((undist_w, undist_h, 2))
.numpy()
Expand All @@ -427,16 +456,10 @@ def _undistort_image(
dist_w = camera.width.item()

# Compute distorted mask.
assert (
"mask" not in data
and camera.metadata is not None
and "fisheye_crop_radius" in camera.metadata
and isinstance(camera.metadata["fisheye_crop_radius"], float)
)
mask = np.mgrid[:dist_h, :dist_w]
mask[0, ...] -= dist_h // 2
mask[1, ...] -= dist_w // 2
mask = np.linalg.norm(mask, axis=0) < camera.metadata["fisheye_crop_radius"]
mask = np.linalg.norm(mask, axis=0) < fisheye_crop_radius
mask = torch.from_numpy(
cv2.remap(
mask.astype(np.uint8) * 255,
Expand All @@ -449,6 +472,7 @@ def _undistort_image(
/ 255.0
).bool()
K = undist_K.numpy()

else:
raise NotImplementedError("Only perspective and fisheye cameras are supported")

Expand Down
4 changes: 3 additions & 1 deletion nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,9 @@ def _generate_dataparser_outputs(self, split="train"):
# - applied_transform contains the transformation to saved coordinates from original data coordinates.
applied_transform = None
colmap_path = self.config.data / "colmap/sparse/0"
if "applied_transform" not in meta and colmap_path.exists():
if "applied_transform" in meta:
applied_transform = torch.tensor(meta["applied_transform"], dtype=transform_matrix.dtype)
elif colmap_path.exists():
# For converting from colmap, this was the effective value of applied_transform that was being
# used before we added the applied_transform field to the output dataformat.
meta["applied_transform"] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, -1, 0]]
Expand Down

0 comments on commit 88f9f04

Please sign in to comment.