Skip to content

Commit

Permalink
Fix pyright errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye authored and Jianbo Ye committed Dec 28, 2023
1 parent a288bf9 commit e1d28e0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 17 deletions.
19 changes: 10 additions & 9 deletions nerfstudio/models/gaussian_splatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def refinement_after(self, optimizers: Optimizers, step):

# After a guassian is split into two new gaussians, the original one should be also pruned.
splits_mask = torch.cat(
(splits, torch.zeros(nsamps * splits.sum() + dups.sum(), device="cuda", dtype=bool))
(splits, torch.zeros(nsamps * splits.sum() + dups.sum(), device="cuda", dtype=torch.bool))
)
deleted_mask = self.cull_gaussians(splits_mask)
param_groups = self.get_gaussian_param_groups()
Expand Down Expand Up @@ -649,7 +649,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
H,
W,
tile_bounds,
)
) # type: ignore
if (self.radii).sum() == 0:
return {"rgb": background.repeat(int(camera.height.item()), int(camera.width.item()), 1)}

Expand All @@ -661,9 +661,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
viewdirs = means_crop.detach() - camera.camera_to_worlds.detach()[..., :3, 3] # (N, 3)
viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
n_coefs_use = num_sh_bases(n)
rgbs = SphericalHarmonics.apply(n, viewdirs, colors_crop[:, :n_coefs_use, :])
rgbs = torch.clamp(rgbs + 0.5, min=0.0)
rgbs = SphericalHarmonics.apply(n, viewdirs, colors_crop)
rgbs = torch.clamp(rgbs + 0.5, min=0.0) # type: ignore
else:
rgbs = torch.sigmoid(colors_crop[:, 0, :])

Expand All @@ -685,8 +684,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
H,
W,
background,
)
rgb = torch.clamp(rgb, max=1.0)
) # type: ignore
rgb = torch.clamp(rgb, max=1.0) # type: ignore
depth_im = None
if not self.training:
depth_im = RasterizeGaussians.apply( # type: ignore
Expand All @@ -700,9 +699,11 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
H,
W,
torch.ones(3, device=self.device) * 10,
)[..., 0:1]
)[
..., 0:1
] # type: ignore

return {"rgb": rgb, "depth": depth_im}
return {"rgb": rgb, "depth": depth_im} # type: ignore

def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
"""Compute and returns metrics.
Expand Down
1 change: 1 addition & 0 deletions nerfstudio/scripts/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from nerfstudio.exporter import texture_utils, tsdf_utils
from nerfstudio.exporter.exporter_utils import collect_camera_poses, generate_point_cloud, get_mesh_from_filename
from nerfstudio.exporter.marching_cubes import generate_mesh_with_multires_marching_cubes
from nerfstudio.fields.sdf_field import SDFField # noqa
from nerfstudio.models.gaussian_splatting import GaussianSplattingModel
from nerfstudio.pipelines.base_pipeline import Pipeline, VanillaPipeline
from nerfstudio.utils.eval_utils import eval_setup
Expand Down
16 changes: 8 additions & 8 deletions nerfstudio/scripts/gsplat_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class CameraInfo(NamedTuple):
uid: int
viewmat: torch.Tensor
projmat: torch.Tensor
FovY: np.array
FovX: np.array
FovY: float
FovX: float
image_name: str
width: int
height: int
Expand Down Expand Up @@ -149,15 +149,15 @@ def render(self, cam_info: CameraInfo, background: torch.Tensor):
cam_info.height,
cam_info.width,
tile_bounds,
)
) # type: ignore
torch.cuda.synchronize()

if self.n_sh_level > 0:
c2w = torch.inverse(viewmat)
viewdirs = self.xyz - c2w[None, :3, 3] # (N, 3)
viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
rgbs = SphericalHarmonics.apply(self.n_sh_level, viewdirs, self.sh_coefs)
rgbs = torch.clamp(rgbs + 0.5, min=0.0)
rgbs = torch.clamp(rgbs + 0.5, min=0.0) # type: ignore
else:
rgbs = self.colors
rgb = RasterizeGaussians.apply(
Expand All @@ -171,10 +171,10 @@ def render(self, cam_info: CameraInfo, background: torch.Tensor):
cam_info.height,
cam_info.width,
background,
)
) # type: ignore
torch.cuda.synchronize()

return torch.clamp(rgb, max=1.0)
return torch.clamp(rgb, max=1.0) # type: ignore


def projection_matrix(znear, zfar, fovx, fovy, device="cpu"):
Expand Down Expand Up @@ -203,7 +203,7 @@ def focal2fov(focal, pixels):
return 2 * math.atan(pixels / (2 * focal))


def load_trajectory(filename: Path) -> List[CameraInfo]:
def load_trajectory(filename: Path):
"""Loads the cameras from a trajectory.json file like the ones used for rendering in nerfstudio
Parameters
Expand Down Expand Up @@ -351,7 +351,7 @@ def main(
assert len(bg_color) == 3
bg = bg_color

render_set(render_path, point_cloud, cam_infos, bg)
render_set(render_path=render_path, point_cloud=point_cloud, cam_infos=cam_infos, bg_color=bg)


def entrypoint():
Expand Down

0 comments on commit e1d28e0

Please sign in to comment.