diff --git a/requirements.txt b/requirements.txt index 3ab14b29..ec04018e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ jupyter scipy pymcubes trimesh -dearpygui \ No newline at end of file +dearpygui +taichi \ No newline at end of file diff --git a/show_gui.py b/show_gui.py index ce675378..68e2f396 100644 --- a/show_gui.py +++ b/show_gui.py @@ -2,7 +2,6 @@ from opt import get_opts import numpy as np from einops import rearrange -import dearpygui.dearpygui as dpg from scipy.spatial.transform import Rotation as R import time @@ -15,14 +14,27 @@ import warnings; warnings.filterwarnings("ignore") +import taichi as ti + +@ti.kernel +def write_buffer(W:ti.i32, H:ti.i32, x: ti.types.ndarray(), final_pixel:ti.template()): + for i, j in ti.ndrange(W, H): + for p in ti.static(range(3)): + final_pixel[i, j][p] = x[H-j, i, p] class OrbitCamera: - def __init__(self, K, img_wh, r): + def __init__(self, K, img_wh, poses, r): self.K = K self.W, self.H = img_wh self.radius = r self.center = np.zeros(3) - self.rot = np.eye(3) + + pose_np = poses.cpu().numpy() + # choose a pose as the initial rotation + self.rot = pose_np[0][:3, :3] + + self.rotate_speed = 0.8 + self.res_defalut = pose_np[0] @property def pose(self): @@ -37,9 +49,16 @@ def pose(self): res[:3, 3] -= self.center return res + def reset(self, pose=None): + self.rot = np.eye(3) + self.center = np.zeros(3) + self.radius = 2.0 + if pose is not None: + self.rot = pose.cpu().numpy()[:3, :3] + def orbit(self, dx, dy): - rotvec_x = self.rot[:, 1] * np.radians(0.05 * dx) - rotvec_y = self.rot[:, 0] * np.radians(-0.05 * dy) + rotvec_x = self.rot[:, 1] * np.radians(100*self.rotate_speed * dx) + rotvec_y = self.rot[:, 0] * np.radians(-100*self.rotate_speed * dy) self.rot = R.from_rotvec(rotvec_y).as_matrix() @ \ R.from_rotvec(rotvec_x).as_matrix() @ \ self.rot @@ -52,27 +71,29 @@ def pan(self, dx, dy, dz=0): class NGPGUI: - def __init__(self, hparams, K, img_wh, radius=2.5): + def __init__(self, hparams, K, img_wh, poses, radius=2.5): self.hparams = hparams rgb_act = 'None' if self.hparams.use_exposure else 'Sigmoid' self.model = NGP(scale=hparams.scale, rgb_act=rgb_act).cuda() load_ckpt(self.model, hparams.ckpt_path) - self.cam = OrbitCamera(K, img_wh, r=radius) + self.poses = poses + + self.cam = OrbitCamera(K, img_wh, poses, r=radius) self.W, self.H = img_wh - self.render_buffer = np.ones((self.W, self.H, 3), dtype=np.float32) + # self.render_buffer = np.ones((self.W, self.H, 3), dtype=np.float32) + self.render_buffer = ti.Vector.field(3, dtype=float, shape=(self.W, self.H)) # placeholders self.dt = 0 self.mean_samples = 0 self.img_mode = 0 + self.exposure = 0.2 - self.register_dpg() - - def render_cam(self, cam): + def render_cam(self): t = time.time() - directions = get_ray_directions(cam.H, cam.W, cam.K, device='cuda') - rays_o, rays_d = get_rays(directions, torch.cuda.FloatTensor(cam.pose)) + directions = get_ray_directions(self.cam.H, self.cam.W, self.cam.K, device='cuda') + rays_o, rays_d = get_rays(directions, torch.cuda.FloatTensor(self.cam.pose)) # TODO: set these attributes by gui if self.hparams.dataset_name in ['colmap', 'nerfpp']: @@ -81,113 +102,99 @@ def render_cam(self, cam): results = render(self.model, rays_o, rays_d, **{'test_time': True, - 'to_cpu': True, 'to_numpy': True, + 'to_cpu': False, 'to_numpy': False, 'T_threshold': 1e-2, - 'exposure': torch.cuda.FloatTensor([dpg.get_value('_exposure')]), + 'exposure': torch.cuda.FloatTensor([self.exposure]), 'max_samples': 100, 'exp_step_factor': exp_step_factor}) rgb = rearrange(results["rgb"], "(h w) c -> h w c", h=self.H) depth = rearrange(results["depth"], "(h w) -> h w", h=self.H) - torch.cuda.synchronize() + # torch.cuda.synchronize() self.dt = time.time()-t self.mean_samples = results['total_samples']/len(rays_o) if self.img_mode == 0: return rgb elif self.img_mode == 1: - return depth2img(depth).astype(np.float32)/255.0 - - def register_dpg(self): - dpg.create_context() - dpg.create_viewport(title="ngp_pl", width=self.W, height=self.H, resizable=False) - - ## register texture ## - with dpg.texture_registry(show=False): - dpg.add_raw_texture( - self.W, - self.H, - self.render_buffer, - format=dpg.mvFormat_Float_rgb, - tag="_texture") - - ## register window ## - with dpg.window(tag="_primary_window", width=self.W, height=self.H): - dpg.add_image("_texture") - dpg.set_primary_window("_primary_window", True) - - def callback_depth(sender, app_data): - self.img_mode = 1-self.img_mode - - ## control window ## - with dpg.window(label="Control", tag="_control_window", width=200, height=150): - dpg.add_slider_float(label="exposure", default_value=0.2, - min_value=1/60, max_value=32, tag="_exposure") - dpg.add_button(label="show depth", tag="_button_depth", - callback=callback_depth) - dpg.add_separator() - dpg.add_text('no data', tag="_log_time") - dpg.add_text('no data', tag="_samples_per_ray") - - ## register camera handler ## - def callback_camera_drag_rotate(sender, app_data): - if not dpg.is_item_focused("_primary_window"): - return - self.cam.orbit(app_data[1], app_data[2]) - - def callback_camera_wheel_scale(sender, app_data): - if not dpg.is_item_focused("_primary_window"): - return - self.cam.scale(app_data) - - def callback_camera_drag_pan(sender, app_data): - if not dpg.is_item_focused("_primary_window"): - return - self.cam.pan(app_data[1], app_data[2]) - - with dpg.handler_registry(): - dpg.add_mouse_drag_handler( - button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate - ) - dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) - dpg.add_mouse_drag_handler( - button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan - ) - - ## Avoid scroll bar in the window ## - with dpg.theme() as theme_no_padding: - with dpg.theme_component(dpg.mvAll): - dpg.add_theme_style( - dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core - ) - dpg.add_theme_style( - dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core - ) - dpg.add_theme_style( - dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core - ) - dpg.bind_item_theme("_primary_window", theme_no_padding) - - ## Launch the gui ## - dpg.setup_dearpygui() - dpg.set_viewport_small_icon("assets/icon.png") - dpg.set_viewport_large_icon("assets/icon.png") - dpg.show_viewport() + return depth2img(depth.cpu().numpy()).astype(np.float32)/255.0 + + def check_cam_rotate(self, window, last_orbit_x, last_orbit_y): + if window.is_pressed(ti.ui.RMB): + curr_mouse_x, curr_mouse_y = window.get_cursor_pos() + if last_orbit_x is None or last_orbit_y is None: + last_orbit_x, last_orbit_y = curr_mouse_x, curr_mouse_y + else: + dx = curr_mouse_x - last_orbit_x + dy = curr_mouse_y - last_orbit_y + self.cam.orbit(dx, -dy) + last_orbit_x, last_orbit_y = curr_mouse_x, curr_mouse_y + else: + last_orbit_x = None + last_orbit_y = None + + return last_orbit_x, last_orbit_y + + def check_key_press(self, window): + if window.is_pressed('w'): + self.cam.scale(0.2) + if window.is_pressed('s'): + self.cam.scale(-0.2) + if window.is_pressed('a'): + self.cam.pan(100, 0.) + if window.is_pressed('d'): + self.cam.pan(-100, 0.) + if window.is_pressed('e'): + self.cam.pan(0., -100) + if window.is_pressed('q'): + self.cam.pan(0., 100) def render(self): - while dpg.is_dearpygui_running(): - dpg.set_value("_texture", self.render_cam(self.cam)) - dpg.set_value("_log_time", f'Render time: {1000*self.dt:.2f} ms') - dpg.set_value("_samples_per_ray", f'Samples/ray: {self.mean_samples:.2f}') - dpg.render_dearpygui_frame() + + window = ti.ui.Window('ngp_pl', (self.W, self.H),) + canvas = window.get_canvas() + gui = window.get_gui() + + # GUI controls variables + last_orbit_x = None + last_orbit_y = None + + view_id = 0 + last_view_id = 0 + + views_size = self.poses.shape[0]-1 + + while window.running: + self.check_key_press(window) + last_orbit_x, last_orbit_y = self.check_cam_rotate(window, last_orbit_x, last_orbit_y) + with gui.sub_window("Control", 0.01, 0.01, 0.4, 0.2) as w: + self.cam.rotate_speed = w.slider_float('rotate speed', self.cam.rotate_speed, 0.1, 1.) + self.exposure = w.slider_float('exposure', self.exposure, 1/60, 32) + + self.img_mode = w.checkbox("show depth", self.img_mode) + + view_id = w.slider_int('train view', view_id, 0, views_size) + + if last_view_id != view_id: + last_view_id = view_id + self.cam.reset(self.poses[view_id]) + + w.text(f'samples per rays: {self.mean_samples:.2f} s/r') + w.text(f'render times: {1000*self.dt:.2f} ms') + + ngp_buffer = self.render_cam() + write_buffer(self.W, self.H, ngp_buffer, self.render_buffer) + canvas.set_image(self.render_buffer) + window.show() if __name__ == "__main__": + ti.init(arch=ti.cuda) + hparams = get_opts() kwargs = {'root_dir': hparams.root_dir, 'downsample': hparams.downsample, - 'read_meta': False} + 'read_meta': True} dataset = dataset_dict[hparams.dataset_name](**kwargs) - NGPGUI(hparams, dataset.K, dataset.img_wh).render() - dpg.destroy_context() + NGPGUI(hparams, dataset.K, dataset.img_wh, dataset.poses).render()