Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gaussian splatting support for Aria #2785

Merged
merged 10 commits into from
Jan 19, 2024
2 changes: 1 addition & 1 deletion nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def fisheye624_unproject_helper(uv, params, max_iters: int = 5):
function so this solves an optimization problem using Newton's method to get
the inverse.
Inputs:
uv: BxNx3 tensor of 2D pixels to be projected
uv: BxNx2 tensor of 2D pixels to be unprojected
params: Bx16 tensor of Fisheye624 parameters formatted like this:
[f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
or Bx15 tensor of Fisheye624 parameters formatted like this:
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/cameras/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def _compute_rays_for_vr180(

assert distortion_params is not None
masked_coords = pcoord_stack[coord_mask, :]
# The fisheye unprojection does not rely on planar/pinhold unprojection, thus the method needs
# The fisheye unprojection does not rely on planar/pinhole unprojection, thus the method needs
# to access the focal length and principle points directly.
camera_params = torch.cat(
[
Expand Down
257 changes: 147 additions & 110 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torch.nn import Parameter
from tqdm import tqdm

from nerfstudio.cameras.camera_utils import fisheye624_project
from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion
from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig, TDataset
Expand Down Expand Up @@ -133,70 +134,20 @@ def cache_images(self, cache_images_option):
continue
distortion_params = camera.distortion_params.numpy()
image = data["image"].numpy()
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
distortion_params = np.array(
[
distortion_params[0],
distortion_params[1],
distortion_params[4],
distortion_params[5],
distortion_params[2],
distortion_params[3],
0,
0,
]
)
if np.any(distortion_params):
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
else:
newK = K
roi = 0, 0, image.shape[1], image.shape[0]
# crop the image and update the intrinsics accordingly
x, y, w, h = roi
image = image[y : y + h, x : x + w]
if "depth_image" in data:
data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
# update the width, height
self.train_dataset.cameras.width[i] = w
self.train_dataset.cameras.height[i] = h
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
if np.any(distortion_params):
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
mask = mask[y : y + h, x : x + w]
data["mask"] = torch.from_numpy(mask).bool()
K = newK

elif camera.camera_type.item() == CameraType.FISHEYE.value:
distortion_params = np.array(
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
)
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
)
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
)
# and then remap:
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
data["mask"] = torch.from_numpy(mask).bool()
K = newK
else:
raise NotImplementedError("Only perspective and fisheye cameras are supported")

K, image, mask = _undistort_image(camera, distortion_params, data, image, K)
data["image"] = torch.from_numpy(image)
if mask is not None:
data["mask"] = mask

cached_train.append(data)

self.train_dataset.cameras.fx[i] = float(K[0, 0])
self.train_dataset.cameras.fy[i] = float(K[1, 1])
self.train_dataset.cameras.cx[i] = float(K[0, 2])
self.train_dataset.cameras.cy[i] = float(K[1, 2])
self.train_dataset.cameras.width[i] = image.shape[1]
self.train_dataset.cameras.height[i] = image.shape[0]

CONSOLE.log("Caching / undistorting eval images")
for i in tqdm(range(len(self.eval_dataset)), leave=False):
Expand All @@ -208,68 +159,20 @@ def cache_images(self, cache_images_option):
continue
distortion_params = camera.distortion_params.numpy()
image = data["image"].numpy()
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
distortion_params = np.array(
[
distortion_params[0],
distortion_params[1],
distortion_params[4],
distortion_params[5],
distortion_params[2],
distortion_params[3],
0,
0,
]
)
if np.any(distortion_params):
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
else:
newK = K
roi = 0, 0, image.shape[1], image.shape[0]
# crop the image and update the intrinsics accordingly
x, y, w, h = roi
image = image[y : y + h, x : x + w]
# update the width, height
self.eval_dataset.cameras.width[i] = w
self.eval_dataset.cameras.height[i] = h
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
if np.any(distortion_params):
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
mask = mask[y : y + h, x : x + w]
data["mask"] = torch.from_numpy(mask).bool()
K = newK

elif camera.camera_type.item() == CameraType.FISHEYE.value:
distortion_params = np.array(
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
)
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
)
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
)
# and then remap:
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
data["mask"] = torch.from_numpy(mask).bool()
K = newK
else:
raise NotImplementedError("Only perspective and fisheye cameras are supported")

K, image, mask = _undistort_image(camera, distortion_params, data, image, K)
data["image"] = torch.from_numpy(image)
if mask is not None:
data["mask"] = mask

cached_eval.append(data)

self.eval_dataset.cameras.fx[i] = float(K[0, 0])
self.eval_dataset.cameras.fy[i] = float(K[1, 1])
self.eval_dataset.cameras.cx[i] = float(K[0, 2])
self.eval_dataset.cameras.cy[i] = float(K[1, 2])
self.eval_dataset.cameras.width[i] = image.shape[1]
self.eval_dataset.cameras.height[i] = image.shape[0]

if cache_images_option == "gpu":
for cache in cached_train:
Expand Down Expand Up @@ -414,3 +317,137 @@ def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
assert len(self.eval_dataset.cameras.shape) == 1, "Assumes single batch dimension"
camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device)
return camera, data


def _undistort_image(
camera: Cameras, distortion_params: np.ndarray, data: dict, image: np.ndarray, K: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Optional[torch.Tensor]]:
mask = None
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
distortion_params = np.array(
[
distortion_params[0],
distortion_params[1],
distortion_params[4],
distortion_params[5],
distortion_params[2],
distortion_params[3],
0,
0,
]
)
if np.any(distortion_params):
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
else:
newK = K
roi = 0, 0, image.shape[1], image.shape[0]
# crop the image and update the intrinsics accordingly
x, y, w, h = roi
image = image[y : y + h, x : x + w]
if "depth_image" in data:
data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
if np.any(distortion_params):
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
mask = mask[y : y + h, x : x + w]
mask = torch.from_numpy(mask).bool()
K = newK

elif camera.camera_type.item() == CameraType.FISHEYE.value:
distortion_params = np.array(
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
)
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
)
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
)
# and then remap:
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
mask = torch.from_numpy(mask).bool()
K = newK
elif camera.camera_type.item() == CameraType.FISHEYE624.value:
fisheye624_params = np.concatenate([camera.fx, camera.fy, camera.cx, camera.cy, distortion_params], axis=0)
assert fisheye624_params.shape == (16,)

# Desired parameters of the undistorted image.
import warnings

warnings.warn("Fisheye624 support in the full images datamanager currently assumes data from Project Aria.")
undist_h = 500
undist_w = 500
undist_K = torch.eye(3)
undist_K[0, 0] = 150.0 # fx
undist_K[1, 1] = 150.0 # fy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These seem pretty small at a quick glance! Is this just hardcoded here for testing/prototyping?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this come from the camera calibration itself, so that it works for the various Aria recording data profiles?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, isn't it weird for these to be hardcoded?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this was test code; I added some heuristics for computing this automatically! Not sure if there's a better way...

undist_K[0, 2] = (undist_w - 1) / 2.0 # cx; for a 1x1 image, center should be at (0, 0).
undist_K[1, 2] = (undist_h - 1) / 2.0 # cy

undistort_uv_homog = torch.stack(
[
*torch.meshgrid(
torch.arange(undist_w, dtype=torch.float32),
torch.arange(undist_h, dtype=torch.float32),
),
torch.ones((undist_w, undist_h), dtype=torch.float32),
],
dim=-1,
)
assert undistort_uv_homog.shape == (undist_w, undist_h, 3)

unproj = (
fisheye624_project(
xyz=(
torch.einsum(
"ij,bj->bi",
torch.linalg.inv(undist_K),
undistort_uv_homog.reshape((undist_w * undist_h, 3)),
)[None]
),
params=torch.from_numpy(fisheye624_params[None, :]),
)
.reshape((undist_w, undist_h, 2))
.numpy()
)
map1 = unproj[..., 1]
map2 = unproj[..., 0]

image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)

dist_h = camera.height.item()
dist_w = camera.width.item()

# Compute distorted mask.
assert (
"mask" not in data
and camera.metadata is not None
and "fisheye_crop_radius" in camera.metadata
and isinstance(camera.metadata["fisheye_crop_radius"], float)
)
mask = np.mgrid[:dist_h, :dist_w]
mask[0, ...] -= dist_h // 2
mask[1, ...] -= dist_w // 2
mask = np.linalg.norm(mask, axis=0) < camera.metadata["fisheye_crop_radius"]
mask = torch.from_numpy(
cv2.remap(
mask.astype(np.uint8) * 255,
map1,
map2,
interpolation=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=0,
)
/ 255.0
).bool()
K = undist_K.numpy()
else:
raise NotImplementedError("Only perspective and fisheye cameras are supported")

return K, image, mask
1 change: 1 addition & 0 deletions nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ def _generate_dataparser_outputs(self, split="train"):
if "ply_file_path" in meta:
ply_file_path = data_dir / meta["ply_file_path"]
metadata.update(self._load_3D_points(ply_file_path, transform_matrix, scale_factor))
print("loaded points!")

dataparser_outputs = DataparserOutputs(
image_filenames=image_filenames,
Expand Down
29 changes: 18 additions & 11 deletions nerfstudio/models/gaussian_splatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,18 @@ class GaussianSplattingModel(Model):

config: GaussianSplattingModelConfig

def __init__(self, *args, **kwargs):
if "seed_points" in kwargs:
self.seed_pts = kwargs["seed_points"]
else:
self.seed_pts = None
def __init__(
self,
*args,
seed_points: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
):
self.seed_points = seed_points
super().__init__(*args, **kwargs)

def populate_modules(self):
if self.seed_pts is not None and not self.config.random_init:
self.means = torch.nn.Parameter(self.seed_pts[0]) # (Location, Color)
if self.seed_points is not None and not self.config.random_init:
self.means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color)
else:
self.means = torch.nn.Parameter((torch.rand((500000, 3)) - 0.5) * 10)
self.xys_grad_norm = None
Expand All @@ -184,14 +186,19 @@ def populate_modules(self):
self.quats = torch.nn.Parameter(random_quat_tensor(self.num_points))
dim_sh = num_sh_bases(self.config.sh_degree)

if self.seed_pts is not None and not self.config.random_init:
shs = torch.zeros((self.seed_pts[1].shape[0], dim_sh, 3)).float().cuda()
if (
self.seed_points is not None
and not self.config.random_init
# We can have colors without points.
and self.seed_points[1].shape[0] > 0
):
shs = torch.zeros((self.seed_points[1].shape[0], dim_sh, 3)).float().cuda()
if self.config.sh_degree > 0:
shs[:, 0, :3] = RGB2SH(self.seed_pts[1] / 255)
shs[:, 0, :3] = RGB2SH(self.seed_points[1] / 255)
shs[:, 1:, 3:] = 0.0
else:
CONSOLE.log("use color only optimization with sigmoid activation")
shs[:, 0, :3] = torch.logit(self.seed_pts[1] / 255, eps=1e-10)
shs[:, 0, :3] = torch.logit(self.seed_points[1] / 255, eps=1e-10)
self.features_dc = torch.nn.Parameter(shs[:, 0, :])
self.features_rest = torch.nn.Parameter(shs[:, 1:, :])
else:
Expand Down
Loading
Loading