Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix & speedup show_gui #112

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ jupyter
scipy
pymcubes
trimesh
dearpygui
dearpygui
taichi
209 changes: 108 additions & 101 deletions show_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from opt import get_opts
import numpy as np
from einops import rearrange
import dearpygui.dearpygui as dpg
from scipy.spatial.transform import Rotation as R
import time

Expand All @@ -15,14 +14,27 @@

import warnings; warnings.filterwarnings("ignore")

import taichi as ti

@ti.kernel
def write_buffer(W:ti.i32, H:ti.i32, x: ti.types.ndarray(), final_pixel:ti.template()):
for i, j in ti.ndrange(W, H):
for p in ti.static(range(3)):
final_pixel[i, j][p] = x[H-j, i, p]

class OrbitCamera:
def __init__(self, K, img_wh, r):
def __init__(self, K, img_wh, poses, r):
self.K = K
self.W, self.H = img_wh
self.radius = r
self.center = np.zeros(3)
self.rot = np.eye(3)

pose_np = poses.cpu().numpy()
# choose a pose as the initial rotation
self.rot = pose_np[0][:3, :3]

self.rotate_speed = 0.8
self.res_defalut = pose_np[0]

@property
def pose(self):
Expand All @@ -37,9 +49,16 @@ def pose(self):
res[:3, 3] -= self.center
return res

def reset(self, pose=None):
self.rot = np.eye(3)
self.center = np.zeros(3)
self.radius = 2.0
if pose is not None:
self.rot = pose.cpu().numpy()[:3, :3]

def orbit(self, dx, dy):
rotvec_x = self.rot[:, 1] * np.radians(0.05 * dx)
rotvec_y = self.rot[:, 0] * np.radians(-0.05 * dy)
rotvec_x = self.rot[:, 1] * np.radians(100*self.rotate_speed * dx)
rotvec_y = self.rot[:, 0] * np.radians(-100*self.rotate_speed * dy)
self.rot = R.from_rotvec(rotvec_y).as_matrix() @ \
R.from_rotvec(rotvec_x).as_matrix() @ \
self.rot
Expand All @@ -52,27 +71,29 @@ def pan(self, dx, dy, dz=0):


class NGPGUI:
def __init__(self, hparams, K, img_wh, radius=2.5):
def __init__(self, hparams, K, img_wh, poses, radius=2.5):
self.hparams = hparams
rgb_act = 'None' if self.hparams.use_exposure else 'Sigmoid'
self.model = NGP(scale=hparams.scale, rgb_act=rgb_act).cuda()
load_ckpt(self.model, hparams.ckpt_path)

self.cam = OrbitCamera(K, img_wh, r=radius)
self.poses = poses

self.cam = OrbitCamera(K, img_wh, poses, r=radius)
self.W, self.H = img_wh
self.render_buffer = np.ones((self.W, self.H, 3), dtype=np.float32)
# self.render_buffer = np.ones((self.W, self.H, 3), dtype=np.float32)
self.render_buffer = ti.Vector.field(3, dtype=float, shape=(self.W, self.H))

# placeholders
self.dt = 0
self.mean_samples = 0
self.img_mode = 0
self.exposure = 0.2

self.register_dpg()

def render_cam(self, cam):
def render_cam(self):
t = time.time()
directions = get_ray_directions(cam.H, cam.W, cam.K, device='cuda')
rays_o, rays_d = get_rays(directions, torch.cuda.FloatTensor(cam.pose))
directions = get_ray_directions(self.cam.H, self.cam.W, self.cam.K, device='cuda')
rays_o, rays_d = get_rays(directions, torch.cuda.FloatTensor(self.cam.pose))

# TODO: set these attributes by gui
if self.hparams.dataset_name in ['colmap', 'nerfpp']:
Expand All @@ -81,113 +102,99 @@ def render_cam(self, cam):

results = render(self.model, rays_o, rays_d,
**{'test_time': True,
'to_cpu': True, 'to_numpy': True,
'to_cpu': False, 'to_numpy': False,
'T_threshold': 1e-2,
'exposure': torch.cuda.FloatTensor([dpg.get_value('_exposure')]),
'exposure': torch.cuda.FloatTensor([self.exposure]),
'max_samples': 100,
'exp_step_factor': exp_step_factor})

rgb = rearrange(results["rgb"], "(h w) c -> h w c", h=self.H)
depth = rearrange(results["depth"], "(h w) -> h w", h=self.H)
torch.cuda.synchronize()
# torch.cuda.synchronize()
self.dt = time.time()-t
self.mean_samples = results['total_samples']/len(rays_o)

if self.img_mode == 0:
return rgb
elif self.img_mode == 1:
return depth2img(depth).astype(np.float32)/255.0

def register_dpg(self):
dpg.create_context()
dpg.create_viewport(title="ngp_pl", width=self.W, height=self.H, resizable=False)

## register texture ##
with dpg.texture_registry(show=False):
dpg.add_raw_texture(
self.W,
self.H,
self.render_buffer,
format=dpg.mvFormat_Float_rgb,
tag="_texture")

## register window ##
with dpg.window(tag="_primary_window", width=self.W, height=self.H):
dpg.add_image("_texture")
dpg.set_primary_window("_primary_window", True)

def callback_depth(sender, app_data):
self.img_mode = 1-self.img_mode

## control window ##
with dpg.window(label="Control", tag="_control_window", width=200, height=150):
dpg.add_slider_float(label="exposure", default_value=0.2,
min_value=1/60, max_value=32, tag="_exposure")
dpg.add_button(label="show depth", tag="_button_depth",
callback=callback_depth)
dpg.add_separator()
dpg.add_text('no data', tag="_log_time")
dpg.add_text('no data', tag="_samples_per_ray")

## register camera handler ##
def callback_camera_drag_rotate(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
self.cam.orbit(app_data[1], app_data[2])

def callback_camera_wheel_scale(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
self.cam.scale(app_data)

def callback_camera_drag_pan(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
self.cam.pan(app_data[1], app_data[2])

with dpg.handler_registry():
dpg.add_mouse_drag_handler(
button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate
)
dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
dpg.add_mouse_drag_handler(
button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
)

## Avoid scroll bar in the window ##
with dpg.theme() as theme_no_padding:
with dpg.theme_component(dpg.mvAll):
dpg.add_theme_style(
dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
)
dpg.add_theme_style(
dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
)
dpg.add_theme_style(
dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
)
dpg.bind_item_theme("_primary_window", theme_no_padding)

## Launch the gui ##
dpg.setup_dearpygui()
dpg.set_viewport_small_icon("assets/icon.png")
dpg.set_viewport_large_icon("assets/icon.png")
dpg.show_viewport()
return depth2img(depth.cpu().numpy()).astype(np.float32)/255.0

def check_cam_rotate(self, window, last_orbit_x, last_orbit_y):
if window.is_pressed(ti.ui.RMB):
curr_mouse_x, curr_mouse_y = window.get_cursor_pos()
if last_orbit_x is None or last_orbit_y is None:
last_orbit_x, last_orbit_y = curr_mouse_x, curr_mouse_y
else:
dx = curr_mouse_x - last_orbit_x
dy = curr_mouse_y - last_orbit_y
self.cam.orbit(dx, -dy)
last_orbit_x, last_orbit_y = curr_mouse_x, curr_mouse_y
else:
last_orbit_x = None
last_orbit_y = None

return last_orbit_x, last_orbit_y

def check_key_press(self, window):
if window.is_pressed('w'):
self.cam.scale(0.2)
if window.is_pressed('s'):
self.cam.scale(-0.2)
if window.is_pressed('a'):
self.cam.pan(100, 0.)
if window.is_pressed('d'):
self.cam.pan(-100, 0.)
if window.is_pressed('e'):
self.cam.pan(0., -100)
if window.is_pressed('q'):
self.cam.pan(0., 100)

def render(self):
while dpg.is_dearpygui_running():
dpg.set_value("_texture", self.render_cam(self.cam))
dpg.set_value("_log_time", f'Render time: {1000*self.dt:.2f} ms')
dpg.set_value("_samples_per_ray", f'Samples/ray: {self.mean_samples:.2f}')
dpg.render_dearpygui_frame()

window = ti.ui.Window('ngp_pl', (self.W, self.H),)
canvas = window.get_canvas()
gui = window.get_gui()

# GUI controls variables
last_orbit_x = None
last_orbit_y = None

view_id = 0
last_view_id = 0

views_size = self.poses.shape[0]-1

while window.running:
self.check_key_press(window)
last_orbit_x, last_orbit_y = self.check_cam_rotate(window, last_orbit_x, last_orbit_y)
with gui.sub_window("Control", 0.01, 0.01, 0.4, 0.2) as w:
self.cam.rotate_speed = w.slider_float('rotate speed', self.cam.rotate_speed, 0.1, 1.)
self.exposure = w.slider_float('exposure', self.exposure, 1/60, 32)

self.img_mode = w.checkbox("show depth", self.img_mode)

view_id = w.slider_int('train view', view_id, 0, views_size)

if last_view_id != view_id:
last_view_id = view_id
self.cam.reset(self.poses[view_id])

w.text(f'samples per rays: {self.mean_samples:.2f} s/r')
w.text(f'render times: {1000*self.dt:.2f} ms')

ngp_buffer = self.render_cam()
write_buffer(self.W, self.H, ngp_buffer, self.render_buffer)
canvas.set_image(self.render_buffer)
window.show()


if __name__ == "__main__":
ti.init(arch=ti.cuda)

hparams = get_opts()
kwargs = {'root_dir': hparams.root_dir,
'downsample': hparams.downsample,
'read_meta': False}
'read_meta': True}
dataset = dataset_dict[hparams.dataset_name](**kwargs)

NGPGUI(hparams, dataset.K, dataset.img_wh).render()
dpg.destroy_context()
NGPGUI(hparams, dataset.K, dataset.img_wh, dataset.poses).render()