Skip to content

Commit

Permalink
[MINOR] fix bug of mesh recon after slam process
Browse files Browse the repository at this point in the history
  • Loading branch information
Yue Pan committed Aug 21, 2024
1 parent c72f1b0 commit 75aff77
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 54 deletions.
18 changes: 10 additions & 8 deletions dataset/slam_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,10 @@ def read_frame_ros(self, msg):
self.get_point_ts(point_ts)

# read frame with specific data loader (partially borrow from kiss-icp: https://github.com/PRBonn/kiss-icp)
def read_frame_with_loader(self, frame_id):

self.set_ref_pose(frame_id)
def read_frame_with_loader(self, frame_id, init_pose: bool = True):

if init_pose:
self.set_ref_pose(frame_id)

frame_id_in_folder = self.config.begin_frame + frame_id * self.config.step_frame
frame_data = self.loader[frame_id_in_folder]
Expand Down Expand Up @@ -243,9 +244,10 @@ def read_frame_with_loader(self, frame_id):
if self.config.deskew:
self.get_point_ts(point_ts)

def read_frame(self, frame_id):

self.set_ref_pose(frame_id)
def read_frame(self, frame_id, init_pose: bool = True):

if init_pose:
self.set_ref_pose(frame_id)

point_ts = None

Expand Down Expand Up @@ -852,9 +854,9 @@ def write_merged_point_cloud(self, down_vox_m=None,
range(0, self.total_pc_count, frame_step)
): # frame id as the idx of the frame in the data folder without skipping
if self.config.use_dataloader:
self.read_frame_with_loader(frame_id)
self.read_frame_with_loader(frame_id, False)
else:
self.read_frame(frame_id)
self.read_frame(frame_id, False)

if self.config.kitti_correction_on:
self.cur_point_cloud_torch = intrinsic_correct(
Expand Down
92 changes: 54 additions & 38 deletions model/neural_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, config: Config) -> None:

self.buffer_size = config.buffer_size

self.temporal_local_map_on = True
self.temporal_local_map_on = True # false for the pure localization mode
self.local_map_radius = self.config.local_map_radius
self.diff_travel_dist_local = (
self.config.local_map_radius * self.config.local_map_travel_dist_ratio
Expand All @@ -73,7 +73,7 @@ def __init__(self, config: Config) -> None:
self.cur_ts = 0 # current frame No. or the current timestamp
self.max_ts = 0

self.travel_dist = None # for determine the local map, update from the dataset class for each frame
self.travel_dist = None # for determine the local map, update from the dataset class for each frame # this is not saved
self.est_poses = None
self.after_pgo = False

Expand Down Expand Up @@ -309,27 +309,29 @@ def update(
vec_points = self.neural_points[hash_idx] - sample_points
dist2 = torch.sum(vec_points**2, dim=-1)

# the voxel is not occupied before or the case when hash collision happens
# delta_t = (cur_ts - self.point_ts_create[hash_idx]) # use time diff
delta_travel_dist = (
self.travel_dist[cur_ts]
- self.travel_dist[self.point_ts_update[hash_idx]]
) # use travel dist diff

# the last time mask is necessary (but better change to the accumulated distance or the pose uncertainty), done
update_mask = (
(hash_idx == -1)
| (dist2 > 3 * cur_resolution**2)
| (delta_travel_dist > self.diff_travel_dist_local)
)
update_mask = (hash_idx == -1) | (dist2 > 3 * cur_resolution**2)

if self.temporal_local_map_on: # only done for the slam mode
# the voxel is not occupied before or the case when hash collision happens
# delta_t = (cur_ts - self.point_ts_create[hash_idx]) # use time diff
delta_travel_dist = (
self.travel_dist[cur_ts]
- self.travel_dist[self.point_ts_update[hash_idx]]
) # use travel dist diff

# the last time mask is necessary
update_mask = update_mask | (delta_travel_dist > self.diff_travel_dist_local)
else:
update_mask = torch.ones(
hash_idx.shape, dtype=torch.bool, device=self.device
)

added_pt = sample_points[update_mask]

new_point_count = added_pt.shape[0]

new_point_ratio = new_point_count / sample_points.shape[0]

cur_pt_idx = self.buffer_pt_index[hash]
# allocate new neural points
cur_pt_count = self.neural_points.shape[0]
Expand Down Expand Up @@ -384,6 +386,8 @@ def update(
sensor_position, sensor_orientation, cur_ts
) # no need to recreate hash

return new_point_ratio

def reset_local_map(
self,
sensor_position: torch.Tensor,
Expand All @@ -393,32 +397,37 @@ def reset_local_map(
diff_ts_local: int = 50,
):
# TODO: not very efficient, optimize the code

self.cur_ts = cur_ts
self.max_ts = max(self.max_ts, cur_ts)

if self.config.use_mid_ts:
point_ts_used = (
(self.point_ts_create + self.point_ts_update) / 2
).int()
else:
point_ts_used = self.point_ts_create
if self.temporal_local_map_on:
if self.config.use_mid_ts:
point_ts_used = (
(self.point_ts_create + self.point_ts_update) / 2
).int()
else:
point_ts_used = self.point_ts_create

if use_travel_dist: # self.travel_dist as torch tensor
delta_travel_dist = torch.abs(
self.travel_dist[cur_ts] - self.travel_dist[point_ts_used]
)
time_mask = (delta_travel_dist < self.diff_travel_dist_local)
else: # use delta_t
delta_t = torch.abs(cur_ts - point_ts_used)
time_mask = (delta_t < diff_ts_local)
if use_travel_dist: # self.travel_dist as torch tensor
delta_travel_dist = torch.abs(
self.travel_dist[cur_ts] - self.travel_dist[point_ts_used]
)
time_mask = (delta_travel_dist < self.diff_travel_dist_local)
else: # use delta_t
delta_t = torch.abs(cur_ts - point_ts_used)
time_mask = (delta_t < diff_ts_local)

else:
time_mask = torch.ones(self.count(), dtype=torch.bool, device=self.device)

# speed up by calulating distance only with the t filtered points
masked_vec2sensor = self.neural_points[time_mask] - sensor_position
masked_dist2sensor = torch.sum(masked_vec2sensor**2, dim=-1) # dist square

dist_mask = (masked_dist2sensor < self.local_map_radius**2)
time_mask_idx = torch.nonzero(time_mask).squeeze() # True index

local_mask_idx = time_mask_idx[dist_mask] # True index

local_mask = torch.full((time_mask.shape), False, dtype=torch.bool, device=self.device)
Expand Down Expand Up @@ -555,6 +564,8 @@ def query_feature(

N, K = valid_mask.shape # K = nn_k here

# print(self.local_point_certainties)

if query_locally:
certainty = self.local_point_certainties[idx] # [N, K]
neighb_vector = (
Expand Down Expand Up @@ -674,18 +685,22 @@ def query_feature(
)

# prune inactive uncertain neural points
def prune_map(self, prune_certainty_thre, min_prune_count = 500):
def prune_map(self, prune_certainty_thre, min_prune_count = 500, global_prune = False):

diff_travel_dist = torch.abs(
self.travel_dist[self.cur_ts] - self.travel_dist[self.point_ts_update]
)
inactive_mask = diff_travel_dist > self.diff_travel_dist_local
certainty_mask = self.point_certainties < prune_certainty_thre

prune_mask = inactive_mask & (
self.point_certainties < prune_certainty_thre
) # True for prune
if global_prune:
prune_mask = certainty_mask
else:
diff_travel_dist = torch.abs(
self.travel_dist[self.cur_ts] - self.travel_dist[self.point_ts_update]
)
inactive_mask = diff_travel_dist > self.diff_travel_dist_local
prune_mask = inactive_mask & certainty_mask
# True for prune

prune_count = torch.sum(prune_mask).item()

if prune_count > min_prune_count:
if not self.silence:
print("# Prune neural points: ", prune_count)
Expand Down Expand Up @@ -961,6 +976,7 @@ def get_map_o3d_bbx(self):
# Borrow from Louis's LocNDF
# https://github.com/PRBonn/LocNDF
class PositionalEncoder(nn.Module):

# out_dim = in_dimnesionality * (2 * bands + 1)
def __init__(self, config: Config):
super().__init__()
Expand Down
13 changes: 7 additions & 6 deletions pin_slam.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,10 @@ def run_pin_slam(config_path=None, dataset_name=None, sequence_name=None, seed=N
loop_reg_failed_count = 0

# save merged point cloud map from gt pose as a reference map
if config.save_merged_pc and dataset.gt_pose_provided:
print("Load and merge the map point cloud with the reference (GT) poses ... ...")
dataset.write_merged_point_cloud(use_gt_pose=True, out_file_name='merged_gt_pc',
frame_step=5, merged_downsample=True)
# if config.save_merged_pc and dataset.gt_pose_provided:
# print("Load and merge the map point cloud with the reference (GT) poses ... ...")
# dataset.write_merged_point_cloud(use_gt_pose=True, out_file_name='merged_gt_pc',
# frame_step=5, merged_downsample=True)

# for each frame
# frame id as the processed frame, possible skipping done in data loader
Expand Down Expand Up @@ -431,8 +431,9 @@ def run_pin_slam(config_path=None, dataset_name=None, sequence_name=None, seed=N
if config.o3d_vis_on:
pgm.plot_loops(os.path.join(run_path, "loop_plot.png"), vis_now=False)

neural_points.recreate_hash(None, None, False, False) # merge the final neural point map
neural_points.prune_map(config.max_prune_certainty, 0) # prune uncertain points for the final output
neural_points.prune_map(config.max_prune_certainty, 0, True) # prune uncertain points for the final output
neural_points.recreate_hash(None, None, False, False) # merge the final neural point map

neural_pcd = neural_points.get_neural_points_o3d(query_global=True, color_mode = 0)
if config.save_map:
o3d.io.write_point_cloud(os.path.join(run_path, "map", "neural_points.ply"), neural_pcd) # write the neural point cloud
Expand Down
4 changes: 2 additions & 2 deletions pin_slam_ros.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,9 @@ def save_results(self, terminate: bool = False):
self.pgm.plot_loops(os.path.join(self.run_path, "loop_plot.png"), vis_now=False)

if terminate:
self.neural_points.recreate_hash(None, None, False, False) # merge the final neural point map
self.neural_points.prune_map(self.config.max_prune_certainty, 0) # prune uncertain points for the final output

self.neural_points.recreate_hash(None, None, False, False) # merge the final neural point map

if self.config.save_map:
neural_pcd = self.neural_points.get_neural_points_o3d(query_global=True, color_mode=0)
o3d.io.write_point_cloud(os.path.join(self.run_path, "map", "neural_points.ply"), neural_pcd)
Expand Down

0 comments on commit 75aff77

Please sign in to comment.