Skip to content

Commit

Permalink
Gaussian splatting support for Aria
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jan 18, 2024
1 parent 368c9ec commit 51b54eb
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 123 deletions.
2 changes: 1 addition & 1 deletion nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def fisheye624_unproject_helper(uv, params, max_iters: int = 5):
function so this solves an optimization problem using Newton's method to get
the inverse.
Inputs:
uv: BxNx3 tensor of 2D pixels to be projected
uv: BxNx2 tensor of 2D pixels to be unprojected
params: Bx16 tensor of Fisheye624 parameters formatted like this:
[f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
or Bx15 tensor of Fisheye624 parameters formatted like this:
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/cameras/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def _compute_rays_for_vr180(

assert distortion_params is not None
masked_coords = pcoord_stack[coord_mask, :]
# The fisheye unprojection does not rely on planar/pinhold unprojection, thus the method needs
# The fisheye unprojection does not rely on planar/pinhole unprojection, thus the method needs
# to access the focal length and principal points directly.
camera_params = torch.cat(
[
Expand Down
255 changes: 145 additions & 110 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torch.nn import Parameter
from tqdm import tqdm

from nerfstudio.cameras.camera_utils import fisheye624_project
from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion
from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig, TDataset
Expand Down Expand Up @@ -133,70 +134,20 @@ def cache_images(self, cache_images_option):
continue
distortion_params = camera.distortion_params.numpy()
image = data["image"].numpy()
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
distortion_params = np.array(
[
distortion_params[0],
distortion_params[1],
distortion_params[4],
distortion_params[5],
distortion_params[2],
distortion_params[3],
0,
0,
]
)
if np.any(distortion_params):
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
else:
newK = K
roi = 0, 0, image.shape[1], image.shape[0]
# crop the image and update the intrinsics accordingly
x, y, w, h = roi
image = image[y : y + h, x : x + w]
if "depth_image" in data:
data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
# update the width, height
self.train_dataset.cameras.width[i] = w
self.train_dataset.cameras.height[i] = h
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
if np.any(distortion_params):
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
mask = mask[y : y + h, x : x + w]
data["mask"] = torch.from_numpy(mask).bool()
K = newK

elif camera.camera_type.item() == CameraType.FISHEYE.value:
distortion_params = np.array(
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
)
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
)
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
)
# and then remap:
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
data["mask"] = torch.from_numpy(mask).bool()
K = newK
else:
raise NotImplementedError("Only perspective and fisheye cameras are supported")

K, image, mask = _undistort_image(camera, distortion_params, data, image, K)
data["image"] = torch.from_numpy(image)
if mask is not None:
data["mask"] = mask

cached_train.append(data)

self.train_dataset.cameras.fx[i] = float(K[0, 0])
self.train_dataset.cameras.fy[i] = float(K[1, 1])
self.train_dataset.cameras.cx[i] = float(K[0, 2])
self.train_dataset.cameras.cy[i] = float(K[1, 2])
self.train_dataset.cameras.width[i] = image.shape[1]
self.train_dataset.cameras.height[i] = image.shape[0]

CONSOLE.log("Caching / undistorting eval images")
for i in tqdm(range(len(self.eval_dataset)), leave=False):
Expand All @@ -208,61 +159,11 @@ def cache_images(self, cache_images_option):
continue
distortion_params = camera.distortion_params.numpy()
image = data["image"].numpy()
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
distortion_params = np.array(
[
distortion_params[0],
distortion_params[1],
distortion_params[4],
distortion_params[5],
distortion_params[2],
distortion_params[3],
0,
0,
]
)
if np.any(distortion_params):
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
else:
newK = K
roi = 0, 0, image.shape[1], image.shape[0]
# crop the image and update the intrinsics accordingly
x, y, w, h = roi
image = image[y : y + h, x : x + w]
# update the width, height
self.eval_dataset.cameras.width[i] = w
self.eval_dataset.cameras.height[i] = h
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
if np.any(distortion_params):
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
mask = mask[y : y + h, x : x + w]
data["mask"] = torch.from_numpy(mask).bool()
K = newK

elif camera.camera_type.item() == CameraType.FISHEYE.value:
distortion_params = np.array(
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
)
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
)
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
)
# and then remap:
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
data["mask"] = torch.from_numpy(mask).bool()
K = newK
else:
raise NotImplementedError("Only perspective and fisheye cameras are supported")

K, image, mask = _undistort_image(camera, distortion_params, data, image, K)
data["image"] = torch.from_numpy(image)
if mask is not None:
data["mask"] = mask

cached_eval.append(data)

Expand Down Expand Up @@ -414,3 +315,137 @@ def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
assert len(self.eval_dataset.cameras.shape) == 1, "Assumes single batch dimension"
camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device)
return camera, data


def _undistort_image(
    camera: Cameras, distortion_params: np.ndarray, data: dict, image: np.ndarray, K: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Optional[torch.Tensor]]:
    """Undistort one image (and its mask, if any) and return the matching intrinsics.

    The distortion model is selected by ``camera.camera_type``:

    * ``PERSPECTIVE``: ``cv2.undistort`` followed by a crop to the valid ROI.
    * ``FISHEYE``: ``cv2.fisheye`` rectification maps followed by ``cv2.remap``.
    * ``FISHEYE624``: every pixel of a fixed-size pinhole image is projected through
      ``fisheye624_project`` to find where to sample the distorted image.

    Args:
        camera: Camera the image belongs to. Besides ``camera_type``, the fisheye624
            branch reads ``fx``, ``fy``, ``cx``, ``cy``, ``height``, ``width`` and
            ``metadata["fisheye_crop_radius"]``.
        distortion_params: 1-D array of distortion coefficients for ``camera``.
        data: Per-image dict. ``data["mask"]`` is read when present. In the perspective
            branch ``data["depth_image"]`` is cropped in place when present. The
            undistorted mask is returned, not written back into ``data``.
        image: Distorted image, indexed ``[row, col, ...]``.
        K: 3x3 intrinsics matrix of the distorted image. It is not modified.

    Returns:
        A ``(K, image, mask)`` tuple: the intrinsics of the undistorted image, the
        undistorted image, and the undistorted boolean mask (``None`` when no mask was
        provided or generated).

    Raises:
        NotImplementedError: If the camera type is not one of the three listed above.
        AssertionError: In the fisheye624 branch, if the assembled parameter vector does
            not have 16 entries, if ``data`` already contains a mask, or if
            ``camera.metadata`` lacks a float ``"fisheye_crop_radius"``.
    """
    mask = None
    camera_type = camera.camera_type.item()

    if camera_type == CameraType.PERSPECTIVE.value:
        # Reorder into the 8-coefficient layout handed to OpenCV.
        # NOTE(review): presumably this maps nerfstudio's [k1, k2, k3, k4, p1, p2] onto
        # OpenCV's (k1, k2, p1, p2, k3, k4, k5, k6) - confirm against the dataparser.
        distortion_params = np.array(
            [
                distortion_params[0],
                distortion_params[1],
                distortion_params[4],
                distortion_params[5],
                distortion_params[2],
                distortion_params[3],
                0,
                0,
            ]
        )
        has_distortion = bool(np.any(distortion_params))
        if has_distortion:
            newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
            image = cv2.undistort(image, K, distortion_params, None, newK)  # type: ignore
        else:
            # Copy so that the crop offset applied below never writes into the caller's matrix.
            newK = K.copy()
            roi = 0, 0, image.shape[1], image.shape[0]
        # crop the image and update the intrinsics accordingly
        x, y, w, h = roi
        image = image[y : y + h, x : x + w]
        if "depth_image" in data:
            data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
        if "mask" in data:
            mask = data["mask"].numpy()
            mask = mask.astype(np.uint8) * 255
            if has_distortion:
                # Must use the un-shifted newK: the mask is cropped after undistortion, like the image.
                mask = cv2.undistort(mask, K, distortion_params, None, newK)  # type: ignore
            mask = mask[y : y + h, x : x + w]
            mask = torch.from_numpy(mask).bool()
        # Cropping moves the image origin to (x, y), so the principal point has to move
        # with it; otherwise the returned intrinsics describe the uncropped image.
        newK[0, 2] -= x
        newK[1, 2] -= y
        K = newK

    elif camera_type == CameraType.FISHEYE.value:
        # Only the first four coefficients are passed to the cv2.fisheye functions.
        distortion_params = np.array(
            [distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
        )
        newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
            K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
        )
        map1, map2 = cv2.fisheye.initUndistortRectifyMap(
            K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
        )
        # and then remap:
        image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)
        if "mask" in data:
            mask = data["mask"].numpy()
            mask = mask.astype(np.uint8) * 255
            mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
            mask = torch.from_numpy(mask).bool()
        K = newK

    elif camera_type == CameraType.FISHEYE624.value:
        fisheye624_params = np.concatenate([camera.fx, camera.fy, camera.cx, camera.cy, distortion_params], axis=0)
        assert fisheye624_params.shape == (16,)

        # Check the mask preconditions up front, before any remapping work is done.
        assert (
            "mask" not in data
            and camera.metadata is not None
            and "fisheye_crop_radius" in camera.metadata
            and isinstance(camera.metadata["fisheye_crop_radius"], float)
        ), "Fisheye624 undistortion needs a float 'fisheye_crop_radius' in camera.metadata and no precomputed mask."

        # Desired parameters of the undistorted image.
        import warnings

        warnings.warn("Fisheye624 support in the full images datamanager currently assumes data from Project Aria.")
        undist_h = 500
        undist_w = 500
        undist_K = torch.eye(3)
        undist_K[0, 0] = 150.0  # fx
        undist_K[1, 1] = 150.0  # fy
        undist_K[0, 2] = (undist_w - 1) / 2.0  # cx; for a 1x1 image, center should be at (0, 0).
        undist_K[1, 2] = (undist_h - 1) / 2.0  # cy

        # Homogeneous pixel coordinates (u, v, 1) of the undistorted image, laid out as
        # [u, v]. indexing="ij" is the layout the shape assert below relies on; passing it
        # explicitly avoids torch's deprecation warning for the implicit default.
        undistort_uv_homog = torch.stack(
            [
                *torch.meshgrid(
                    torch.arange(undist_w, dtype=torch.float32),
                    torch.arange(undist_h, dtype=torch.float32),
                    indexing="ij",
                ),
                torch.ones((undist_w, undist_h), dtype=torch.float32),
            ],
            dim=-1,
        )
        assert undistort_uv_homog.shape == (undist_w, undist_h, 3)

        # Back-project every undistorted pixel with inv(undist_K), then project it with the
        # fisheye624 model to find where to sample the distorted image.
        unproj = (
            fisheye624_project(
                xyz=(
                    torch.einsum(
                        "ij,bj->bi",
                        torch.linalg.inv(undist_K),
                        undistort_uv_homog.reshape((undist_w * undist_h, 3)),
                    )[None]
                ),
                params=torch.from_numpy(fisheye624_params[None, :]),
            )
            .reshape((undist_w, undist_h, 2))
            .numpy()
        )
        # NOTE(review): the grid is indexed [u, v] and the two projected components are
        # swapped when building the cv2.remap maps. This only lines up for a square
        # output (undist_w == undist_h) - confirm before changing the output size.
        map1 = unproj[..., 1]
        map2 = unproj[..., 0]

        image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)

        dist_h = camera.height.item()
        dist_w = camera.width.item()

        # Build a circular validity mask on the distorted image (centered on the image,
        # radius fisheye_crop_radius), then warp it with the same maps as the image.
        mask = np.mgrid[:dist_h, :dist_w]
        mask[0, ...] -= dist_h // 2
        mask[1, ...] -= dist_w // 2
        mask = np.linalg.norm(mask, axis=0) < camera.metadata["fisheye_crop_radius"]
        mask = torch.from_numpy(
            cv2.remap(
                mask.astype(np.uint8) * 255,
                map1,
                map2,
                interpolation=cv2.INTER_LINEAR,
                borderMode=cv2.BORDER_CONSTANT,
                borderValue=0,
            )
            / 255.0
        ).bool()
        K = undist_K.numpy()

    else:
        raise NotImplementedError("Only perspective, fisheye, and fisheye624 cameras are supported")

    return K, image, mask
1 change: 1 addition & 0 deletions nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ def _generate_dataparser_outputs(self, split="train"):
if "ply_file_path" in meta:
ply_file_path = data_dir / meta["ply_file_path"]
metadata.update(self._load_3D_points(ply_file_path, transform_matrix, scale_factor))
print("loaded points!")

dataparser_outputs = DataparserOutputs(
image_filenames=image_filenames,
Expand Down
29 changes: 18 additions & 11 deletions nerfstudio/models/gaussian_splatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,18 @@ class GaussianSplattingModel(Model):

config: GaussianSplattingModelConfig

def __init__(self, *args, **kwargs):
if "seed_points" in kwargs:
self.seed_pts = kwargs["seed_points"]
else:
self.seed_pts = None
def __init__(
    self,
    *args,
    seed_points: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    **kwargs,
):
    """Store the optional seed points, then delegate to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to the parent constructor.
        seed_points: Optional ``(locations, colors)`` pair of tensors. ``populate_modules``
            reads element 0 as the initial means and element 1 as colors (scaled by 1/255).
        **kwargs: Keyword arguments forwarded unchanged to the parent constructor.
    """
    # Assigned before super().__init__() - presumably because the parent constructor
    # triggers populate_modules(), which reads self.seed_points (confirm in the base class).
    self.seed_points = seed_points
    super().__init__(*args, **kwargs)

def populate_modules(self):
if self.seed_pts is not None and not self.config.random_init:
self.means = torch.nn.Parameter(self.seed_pts[0]) # (Location, Color)
if self.seed_points is not None and not self.config.random_init:
self.means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color)
else:
self.means = torch.nn.Parameter((torch.rand((500000, 3)) - 0.5) * 10)
self.xys_grad_norm = None
Expand All @@ -184,14 +186,19 @@ def populate_modules(self):
self.quats = torch.nn.Parameter(random_quat_tensor(self.num_points))
dim_sh = num_sh_bases(self.config.sh_degree)

if self.seed_pts is not None and not self.config.random_init:
shs = torch.zeros((self.seed_pts[1].shape[0], dim_sh, 3)).float().cuda()
if (
self.seed_points is not None
and not self.config.random_init
# We can have colors without points.
and self.seed_points[1].shape[0] > 0
):
shs = torch.zeros((self.seed_points[1].shape[0], dim_sh, 3)).float().cuda()
if self.config.sh_degree > 0:
shs[:, 0, :3] = RGB2SH(self.seed_pts[1] / 255)
shs[:, 0, :3] = RGB2SH(self.seed_points[1] / 255)
shs[:, 1:, 3:] = 0.0
else:
CONSOLE.log("use color only optimization with sigmoid activation")
shs[:, 0, :3] = torch.logit(self.seed_pts[1] / 255, eps=1e-10)
shs[:, 0, :3] = torch.logit(self.seed_points[1] / 255, eps=1e-10)
self.features_dc = torch.nn.Parameter(shs[:, 0, :])
self.features_rest = torch.nn.Parameter(shs[:, 1:, :])
else:
Expand Down
17 changes: 17 additions & 0 deletions nerfstudio/scripts/datasets/process_project_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from typing import Dict, List

import numpy as np
import open3d as o3d
import pandas as pd
import tyro
from PIL import Image

Expand Down Expand Up @@ -219,6 +221,21 @@ def main(self) -> None:
"fisheye_crop_radius": rgb_valid_radius,
}

# save global point cloud, which is useful for Gaussian Splatting.
points_path = self.mps_data_dir / "global_points.csv.gz"
if points_path.exists():
print("Found global points, saving to PLY...")
df = pd.read_csv(points_path, compression="gzip")
points = df[["px_world", "py_world", "pz_world"]].values
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)
ply_file_path = self.output_dir / "global_points.ply"
o3d.io.write_point_cloud(str(ply_file_path), pcd)

nerfstudio_frames["ply_file_path"] = "global_points.ply"
else:
print("No global points found!")

# write the json out to disk as transforms.json
print("Writing transforms.json")
transform_file = self.output_dir / "transforms.json"
Expand Down
Loading

0 comments on commit 51b54eb

Please sign in to comment.