From 8c20c58c5d037db64522f12011407b4684e90fca Mon Sep 17 00:00:00 2001
From: lcmrl
Date: Sat, 21 Sep 2024 16:22:30 +0200
Subject: [PATCH] Convert LoFTR and RoMa matches to multi-view directly in main.py

---
 main.py                                    |  12 +
 src/deep_image_matching/io/h5_to_db.py     |  13 +-
 src/deep_image_matching/utils/database.py  |  20 ++
 .../utils/loftr_roma_to_multiview.py       | 248 ++++++++++++++++++
 4 files changed, 285 insertions(+), 8 deletions(-)
 create mode 100644 src/deep_image_matching/utils/loftr_roma_to_multiview.py

diff --git a/main.py b/main.py
index 04597f4..e51fccc 100644
--- a/main.py
+++ b/main.py
@@ -1,7 +1,10 @@
+import os
 import logging
 from importlib import import_module
+from pathlib import Path
 
 import deep_image_matching as dim
+from deep_image_matching.utils.loftr_roma_to_multiview import LoftrRomaToMultiview
 import yaml
 
 logger = dim.setup_logger("dim")
@@ -30,6 +33,15 @@
     camera_config_path=config.general["camera_options"],
 )
 
+if matcher.matching in ["loftr", "se2loftr", "roma"]:
+    images = os.listdir(imgs_dir)
+    image_format = Path(images[0]).suffix
+    LoftrRomaToMultiview(
+        input_dir=feature_path.parent,
+        output_dir=feature_path.parent,
+        image_dir=imgs_dir,
+        img_ext=image_format)
+
 # Visualize view graph
 if config.general["graph"]:
     try:
diff --git a/src/deep_image_matching/io/h5_to_db.py b/src/deep_image_matching/io/h5_to_db.py
index 406dcdf..4e9ed71 100644
--- a/src/deep_image_matching/io/h5_to_db.py
+++ b/src/deep_image_matching/io/h5_to_db.py
@@ -236,7 +236,6 @@ def add_keypoints(db: Path, h5_path: Path, image_path: Path, camera_options: dic
     grouped_images = parse_camera_options(camera_options, db, image_path)
 
     with h5py.File(str(h5_path), "r") as keypoint_f:
-        # camera_id = None
         fname_to_id = {}
         k = 0
         for filename in tqdm(list(keypoint_f.keys())):
@@ -256,18 +255,16 @@ def add_keypoints(db: Path, h5_path: Path, image_path: Path, camera_options: dic
                     k += 1
                 elif k > 0:
                     camera_id = single_camera_id
-            else:
+            elif filename in grouped_images:
                 camera_id = grouped_images[filename]["camera_id"]
+            else:
+                raise ValueError(
+                    f"No camera configuration found for image '{filename}'")
 
             image_id = db.add_image(filename, camera_id)
             fname_to_id[filename] = image_id
-            # print('keypoints')
-            # print(keypoints)
-            # print('image_id', image_id)
+
             if len(keypoints.shape) >= 2:
                 db.add_keypoints(image_id, keypoints)
-            # else:
-            #     keypoints =
-            #     db.add_keypoints(image_id, keypoints)
 
     return fname_to_id
diff --git a/src/deep_image_matching/utils/database.py b/src/deep_image_matching/utils/database.py
index ccebea3..e692686 100644
--- a/src/deep_image_matching/utils/database.py
+++ b/src/deep_image_matching/utils/database.py
@@ -363,3 +363,23 @@ def get_matches(self) -> tuple:
             matches[im_ids] = mtc
 
         return matches, images
+
+    def get_images(self) -> dict:
+        query = "SELECT image_id, name, camera_id FROM images"
+        data = self.execute(query).fetchall()
+        images = {}
+        for d in data:
+            image_id = d[0]
+            image_name = d[1]
+            camera_id = d[2]
+            images[image_name] = (image_id, camera_id)
+        return images
+
+    def clean_keypoints(self):
+        self.execute("DELETE FROM keypoints")
+
+    def clean_matches(self):
+        self.execute("DELETE FROM matches")
+
+    def clean_two_view_geometries(self):
+        self.execute("DELETE FROM two_view_geometries")
\ No newline at end of file
diff --git a/src/deep_image_matching/utils/loftr_roma_to_multiview.py b/src/deep_image_matching/utils/loftr_roma_to_multiview.py
new file mode 100644
index 0000000..52353bf
--- /dev/null
+++ b/src/deep_image_matching/utils/loftr_roma_to_multiview.py
@@ -0,0 +1,248 @@
+import argparse
+import os
+import warnings
+from collections import defaultdict
+from copy import deepcopy
+from pathlib import Path
+
+import h5py
+import numpy as np
+import torch
+from PIL import Image, ExifTags
+from tqdm import tqdm
+
+from deep_image_matching.io.h5_to_db import COLMAPDatabase, image_ids_to_pair_id
+
+# Credit to: https://github.com/ducha-aiki/imc2023-kornia-starter-pack/blob/main/loftr-pycolmap-3dreconstruction.ipynb
+
+
+def get_focal(image_path, err_on_default=False):
+    image = Image.open(image_path)
+    max_size = max(image.size)
+
+    exif = image.getexif()
+    focal = None
+    if exif is not None:
+        focal_35mm = None
+        # https://github.com/colmap/colmap/blob/d3a29e203ab69e91eda938d6e56e1c7339d62a99/src/util/bitmap.cc#L299
+        for tag, value in exif.items():
+            if ExifTags.TAGS.get(tag, None) == 'FocalLengthIn35mmFilm':
+                focal_35mm = float(value)
+                break
+
+        if focal_35mm is not None:
+            focal = focal_35mm / 35. * max_size
+
+    if focal is None:
+        if err_on_default:
+            raise RuntimeError("Failed to find focal length")
+
+        # Failed to find the focal length in the EXIF data: use a prior
+        FOCAL_PRIOR = 1.2
+        focal = FOCAL_PRIOR * max_size
+
+    return focal
+
+
+def create_camera(db, image_path, camera_model):
+    image = Image.open(image_path)
+    width, height = image.size
+
+    focal = get_focal(image_path)
+
+    if camera_model == 'simple-pinhole':
+        model = 0  # simple pinhole
+        param_arr = np.array([focal, width / 2, height / 2])
+    elif camera_model == 'pinhole':
+        model = 1  # pinhole
+        param_arr = np.array([focal, focal, width / 2, height / 2])
+    elif camera_model == 'simple-radial':
+        model = 2  # simple radial
+        param_arr = np.array([focal, width / 2, height / 2, 0.1])
+    elif camera_model == 'opencv':
+        model = 4  # opencv
+        param_arr = np.array([focal, focal, width / 2, height / 2, 0., 0., 0., 0.])
+
+    return db.add_camera(model, width, height, param_arr)
+
+
+def add_keypoints(db, h5_path, image_path, camera_model, single_camera=True):
+    # The image and camera ids created by the standard export are reused;
+    # only the keypoints are replaced with the multi-view ones.
+    keypoint_f = h5py.File(os.path.join(h5_path, 'keypoints.h5'), 'r')
+    fname_to_id = {}
+    db.clean_keypoints()
+    images = db.get_images()
+
+    for filename in tqdm(list(keypoint_f.keys())):
+        keypoints = keypoint_f[filename][()]
+        path = os.path.join(image_path, filename)
+        if not os.path.isfile(path):
+            raise IOError(f'Invalid image path {path}')
+        image_id, camera_id = images[filename]
+        fname_to_id[filename] = image_id
+        db.add_keypoints(image_id, keypoints)
+
+    return fname_to_id
+
+
+def add_matches(db, h5_path, fname_to_id):
+    db.clean_matches()
+    db.clean_two_view_geometries()
+    match_file = h5py.File(os.path.join(h5_path, 'matches.h5'), 'r')
+
+    added = set()
+    n_keys = len(match_file.keys())
+    n_total = (n_keys * (n_keys - 1)) // 2
+
+    with tqdm(total=n_total) as pbar:
+        for key_1 in match_file.keys():
+            group = match_file[key_1]
+            for key_2 in group.keys():
+                id_1 = fname_to_id[key_1]
+                id_2 = fname_to_id[key_2]
+
+                pair_id = image_ids_to_pair_id(id_1, id_2)
+                if pair_id in added:
+                    warnings.warn(f'Pair {pair_id} ({id_1}, {id_2}) already added!')
+                    continue
+
+                matches = group[key_2][()]
+                # db.add_matches(id_1, id_2, matches)
+                db.add_two_view_geometry(id_1, id_2, matches)
+
+                added.add(pair_id)
+
+                pbar.update(1)
+
+
+def get_unique_idxs(A, dim=1):
+    # https://stackoverflow.com/questions/72001505/how-to-get-unique-elements-and-their-firstly-appeared-indices-of-a-pytorch-tenso
+    unique, idx, counts = torch.unique(A, dim=dim, sorted=True, return_inverse=True, return_counts=True)
+    _, ind_sorted = torch.sort(idx, stable=True)
+    cum_sum = counts.cumsum(0)
+    cum_sum = torch.cat((torch.tensor([0], device=cum_sum.device), cum_sum[:-1]))
+    first_indices = ind_sorted[cum_sum]
+    return first_indices
+
+
+def import_into_colmap(img_dir,
+                       feature_dir='.featureout',
+                       database_path='colmap.db'):
+    # The database is expected to exist already (created by the standard
+    # export), so the tables are not re-created here.
+    db = COLMAPDatabase.connect(database_path)
+    # db.create_tables()
+    single_camera = False
+    fname_to_id = add_keypoints(db, feature_dir, img_dir, 'simple-radial', single_camera)
+    add_matches(
+        db,
+        feature_dir,
+        fname_to_id,
+    )
+
+    db.commit()
+    return
+
+
+def LoftrRomaToMultiview(
+    input_dir: Path,
+    output_dir: Path,
+    image_dir: Path,
+    img_ext: str,
+) -> None:
+    # Accept both str and Path arguments and build the paths in a
+    # platform-independent way.
+    input_dir = Path(input_dir)
+    output_dir = Path(output_dir)
+
+    # Expand the pairwise matches into per-pair arrays of matched keypoint
+    # coordinates: one row per match, [x1, y1, x2, y2].
+    with h5py.File(input_dir / 'features.h5', mode='r') as h5_feats, \
+         h5py.File(input_dir / 'matches.h5', mode='r') as h5_matches, \
+         h5py.File(input_dir / 'matches_loftr.h5', mode='w') as h5_out:
+
+        for img1 in h5_matches.keys():
+            print(img1)
+            kpts1 = h5_feats[img1]['keypoints'][...]
+            group_match = h5_matches[img1]
+            group_out = h5_out.require_group(img1)
+            for img2 in group_match.keys():
+                print(f"--- {img2}")
+                kpts2 = h5_feats[img2]['keypoints'][...]
+                matches = group_match[img2][...]
+                group_out[img2] = np.hstack((kpts1[matches[:, 0], :], kpts2[matches[:, 1], :]))
+
+    kpts = defaultdict(list)
+    match_indexes = defaultdict(dict)
+    total_kpts = defaultdict(int)
+
+    # Collect all matched keypoints per image and build per-pair index arrays.
+    with h5py.File(input_dir / 'matches_loftr.h5', mode='r') as f_match:
+        for k1 in f_match.keys():
+            group = f_match[k1]
+            for k2 in group.keys():
+                matches = group[k2][...]
+                kpts[k1].append(matches[:, :2])
+                kpts[k2].append(matches[:, 2:])
+                current_match = torch.arange(len(matches)).reshape(-1, 1).repeat(1, 2)
+                current_match[:, 0] += total_kpts[k1]
+                current_match[:, 1] += total_kpts[k2]
+                total_kpts[k1] += len(matches)
+                total_kpts[k2] += len(matches)
+                match_indexes[k1][k2] = current_match
+
+    # Round keypoints to the nearest pixel so that detections closer than
+    # 1 px collapse to the same multi-view keypoint.
+    for k in kpts.keys():
+        kpts[k] = np.round(np.concatenate(kpts[k], axis=0))
+
+    unique_kpts = {}
+    unique_match_idxs = {}
+    out_match = defaultdict(dict)
+    for k in kpts.keys():
+        uniq_kps, uniq_reverse_idxs = torch.unique(torch.from_numpy(kpts[k]), dim=0, return_inverse=True)
+        unique_match_idxs[k] = uniq_reverse_idxs
+        unique_kpts[k] = uniq_kps.numpy()
+
+    # Remap the pairwise matches onto the unique keypoint indices and drop
+    # duplicated correspondences.
+    for k1, group in match_indexes.items():
+        for k2, m in group.items():
+            m2 = deepcopy(m)
+            m2[:, 0] = unique_match_idxs[k1][m2[:, 0]]
+            m2[:, 1] = unique_match_idxs[k2][m2[:, 1]]
+            mkpts = np.concatenate([unique_kpts[k1][m2[:, 0]],
+                                    unique_kpts[k2][m2[:, 1]],
+                                    ],
+                                   axis=1)
+            unique_idxs_current = get_unique_idxs(torch.from_numpy(mkpts), dim=0)
+            m2_semiclean = m2[unique_idxs_current]
+            unique_idxs_current1 = get_unique_idxs(m2_semiclean[:, 0], dim=0)
+            m2_semiclean = m2_semiclean[unique_idxs_current1]
+            unique_idxs_current2 = get_unique_idxs(m2_semiclean[:, 1], dim=0)
+            m2_semiclean2 = m2_semiclean[unique_idxs_current2]
+            out_match[k1][k2] = m2_semiclean2.numpy()
+
+    # Overwrite keypoints.h5 and matches.h5 with the multi-view version.
+    with h5py.File(output_dir / 'keypoints.h5', mode='w') as f_kp:
+        for k, kpts1 in unique_kpts.items():
+            f_kp[k] = kpts1
+
+    with h5py.File(output_dir / 'matches.h5', mode='w') as f_match:
+        for k1, gr in out_match.items():
+            group = f_match.require_group(k1)
+            for k2, match in gr.items():
+                group[k2] = match
+
+    # Re-import keypoints and matches into the existing COLMAP database
+    # (exported earlier by the pipeline), which already contains the
+    # cameras and images.
+    import_into_colmap(
+        image_dir,
+        feature_dir=str(output_dir),
+        database_path=str(output_dir / "database.db"))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='LoFTR and RoMa matches to multi-view: keypoints closer '
+                    'than 1 px are assigned the same index.')
+    parser.add_argument('-i', '--input_dir', type=str,
+                        help='Path to the directory containing features.h5 and matches.h5')
+    parser.add_argument('-o', '--output_dir', type=str, help='Output directory')
+    parser.add_argument('-d', '--image_dir', type=str, help='Image directory')
+    parser.add_argument('-e', '--img_ext', type=str, default='.jpg', help='Image extension')
+
+    args = parser.parse_args()
+
+    LoftrRomaToMultiview(
+        args.input_dir,
+        args.output_dir,
+        args.image_dir,
+        args.img_ext,
+    )
\ No newline at end of file
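
Usage sketch: besides the call added in main.py, the new module exposes an argparse CLI,
so it could also be run standalone on an existing results folder that already contains
features.h5, matches.h5 and the exported database.db. The invocation below is illustrative
only; it assumes the package is installed so that python -m resolves the module, and the
directory names are placeholders:

    python -m deep_image_matching.utils.loftr_roma_to_multiview \
        -i results_folder -o results_folder -d images_folder -e .jpg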