Skip to content

Commit

Permalink
loftr and roma to multiview directly in main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lcmrl committed Sep 21, 2024
1 parent 50bd87e commit 8c20c58
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 8 deletions.
12 changes: 12 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
import logging
from importlib import import_module
from pathlib import Path

import deep_image_matching as dim
from deep_image_matching.utils.loftr_roma_to_multiview import LoftrRomaToMultiview
import yaml

logger = dim.setup_logger("dim")
Expand Down Expand Up @@ -30,6 +33,15 @@
camera_config_path=config.general["camera_options"],
)

if matcher.matching in ["loftr", "se2loftr", "roma"]:
images = os.listdir(imgs_dir)
image_format = Path(images[0]).suffix
LoftrRomaToMultiview(
input_dir=feature_path.parent,
output_dir=feature_path.parent,
image_dir=imgs_dir,
img_ext=image_format)

# Visualize view graph
if config.general["graph"]:
try:
Expand Down
13 changes: 5 additions & 8 deletions src/deep_image_matching/io/h5_to_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ def add_keypoints(db: Path, h5_path: Path, image_path: Path, camera_options: dic
grouped_images = parse_camera_options(camera_options, db, image_path)

with h5py.File(str(h5_path), "r") as keypoint_f:
# camera_id = None
fname_to_id = {}
k = 0
for filename in tqdm(list(keypoint_f.keys())):
Expand All @@ -256,18 +255,16 @@ def add_keypoints(db: Path, h5_path: Path, image_path: Path, camera_options: dic
k += 1
elif k > 0:
camera_id = single_camera_id
else:
elif filename in list(grouped_images.keys()):
camera_id = grouped_images[filename]["camera_id"]
else:
print('ERROR in h5_to_db.py')
quit()
image_id = db.add_image(filename, camera_id)
fname_to_id[filename] = image_id
# print('keypoints')
# print(keypoints)
# print('image_id', image_id)

if len(keypoints.shape) >= 2:
db.add_keypoints(image_id, keypoints)
# else:
# keypoints =
# db.add_keypoints(image_id, keypoints)

return fname_to_id

Expand Down
20 changes: 20 additions & 0 deletions src/deep_image_matching/utils/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,23 @@ def get_matches(self) -> tuple:
matches[im_ids] = mtc

return matches, images

def get_images(self) -> dict:
query = "SELECT image_id, name, camera_id FROM images"
data = self.execute(query).fetchall()
images = {}
for d in data:
image_id = d[0]
image_name = d[1]
camera_id = d[2]
images[image_name] = (image_id, camera_id)
return images

def clean_keypoints(self):
self.execute("DELETE FROM keypoints")

def clean_matches(self):
self.execute("DELETE FROM matches")

def clean_two_view_geometries(self):
self.execute("DELETE FROM two_view_geometries")
248 changes: 248 additions & 0 deletions src/deep_image_matching/utils/loftr_roma_to_multiview.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
import h5py
import torch
import numpy as np
from pathlib import Path
from copy import deepcopy
from collections import defaultdict
from deep_image_matching.io.h5_to_db import COLMAPDatabase, image_ids_to_pair_id

import os, h5py, warnings
import numpy as np
from tqdm import tqdm
from PIL import Image, ExifTags
import argparse

# Credit to: https://github.com/ducha-aiki/imc2023-kornia-starter-pack/blob/main/loftr-pycolmap-3dreconstruction.ipynb

def get_focal(image_path, err_on_default=False):
image = Image.open(image_path)
max_size = max(image.size)

exif = image.getexif()
focal = None
if exif is not None:
focal_35mm = None
# https://github.com/colmap/colmap/blob/d3a29e203ab69e91eda938d6e56e1c7339d62a99/src/util/bitmap.cc#L299
for tag, value in exif.items():
focal_35mm = None
if ExifTags.TAGS.get(tag, None) == 'FocalLengthIn35mmFilm':
focal_35mm = float(value)
break

if focal_35mm is not None:
focal = focal_35mm / 35. * max_size

if focal is None:
if err_on_default:
raise RuntimeError("Failed to find focal length")

# failed to find it in exif, use prior
FOCAL_PRIOR = 1.2
focal = FOCAL_PRIOR * max_size

return focal

def create_camera(db, image_path, camera_model):
image = Image.open(image_path)
width, height = image.size

focal = get_focal(image_path)

if camera_model == 'simple-pinhole':
model = 0 # simple pinhole
param_arr = np.array([focal, width / 2, height / 2])
if camera_model == 'pinhole':
model = 1 # pinhole
param_arr = np.array([focal, focal, width / 2, height / 2])
elif camera_model == 'simple-radial':
model = 2 # simple radial
param_arr = np.array([focal, width / 2, height / 2, 0.1])
elif camera_model == 'opencv':
model = 4 # opencv
param_arr = np.array([focal, focal, width / 2, height / 2, 0., 0., 0., 0.])

return db.add_camera(model, width, height, param_arr)


def add_keypoints(db, h5_path, image_path, camera_model, single_camera = True):
keypoint_f = h5py.File(os.path.join(h5_path, 'keypoints.h5'), 'r')
fname_to_id = {}
db.clean_keypoints()

for filename in tqdm(list(keypoint_f.keys())):
keypoints = keypoint_f[filename][()]
fname_with_ext = filename
path = os.path.join(image_path, fname_with_ext)
if not os.path.isfile(path):
raise IOError(f'Invalid image path {path}')
images = db.get_images()
image_id, camera_id = images[filename]
fname_to_id[filename] = image_id
db.add_keypoints(image_id, keypoints)

return fname_to_id

def add_matches(db, h5_path, fname_to_id):
db.clean_matches()
db.clean_two_view_geometries()
match_file = h5py.File(os.path.join(h5_path, 'matches.h5'), 'r')

added = set()
n_keys = len(match_file.keys())
n_total = (n_keys * (n_keys - 1)) // 2

with tqdm(total=n_total) as pbar:
for key_1 in match_file.keys():
group = match_file[key_1]
for key_2 in group.keys():
id_1 = fname_to_id[key_1]
id_2 = fname_to_id[key_2]

pair_id = image_ids_to_pair_id(id_1, id_2)
if pair_id in added:
warnings.warn(f'Pair {pair_id} ({id_1}, {id_2}) already added!')
continue

matches = group[key_2][()]
#db.add_matches(id_1, id_2, matches)
db.add_two_view_geometry(id_1, id_2, matches)

added.add(pair_id)

pbar.update(1)

def get_unique_idxs(A, dim=1):
# https://stackoverflow.com/questions/72001505/how-to-get-unique-elements-and-their-firstly-appeared-indices-of-a-pytorch-tenso
unique, idx, counts = torch.unique(A, dim=dim, sorted=True, return_inverse=True, return_counts=True)
_, ind_sorted = torch.sort(idx, stable=True)
cum_sum = counts.cumsum(0)
cum_sum = torch.cat((torch.tensor([0],device=cum_sum.device), cum_sum[:-1]))
first_indicies = ind_sorted[cum_sum]
return first_indicies

def import_into_colmap(img_dir,
feature_dir ='.featureout',
database_path = 'colmap.db'
):
db = COLMAPDatabase.connect(database_path)
#db.create_tables()
single_camera = False
fname_to_id = add_keypoints(db, feature_dir, img_dir, 'simple-radial', single_camera)
add_matches(
db,
feature_dir,
fname_to_id,
)

db.commit()
return

def LoftrRomaToMultiview(
input_dir: Path,
output_dir: Path,
image_dir: Path,
img_ext: Path,
) -> None:

with h5py.File(fr'{input_dir}\features.h5', mode='r') as h5_feats, \
h5py.File(fr'{input_dir}\matches.h5', mode='r') as h5_matches, \
h5py.File(fr'{input_dir}\matches_loftr.h5', mode='w') as h5_out:

for img1 in h5_matches.keys():
print(img1)
kpts1 = h5_feats[img1]['keypoints'][...]
group_match = h5_matches[img1]
group_out = h5_out.require_group(img1)
for img2 in group_match.keys():
print(f"--- {img2}")
kpts2 = h5_feats[img2]['keypoints'][...]
matches = group_match[img2][...]
h5_out[img1][img2] = np.hstack((kpts1[matches[:,0],:], kpts2[matches[:,1],:]))

kpts = defaultdict(list)
match_indexes = defaultdict(dict)
total_kpts=defaultdict(int)

with h5py.File(fr'{input_dir}\matches_loftr.h5', mode='r') as f_match:
for k1 in f_match.keys():
group = f_match[k1]
for k2 in group.keys():
matches = group[k2][...]
total_kpts[k1]
kpts[k1].append(matches[:, :2])
kpts[k2].append(matches[:, 2:])
current_match = torch.arange(len(matches)).reshape(-1, 1).repeat(1, 2)
current_match[:, 0]+=total_kpts[k1]
current_match[:, 1]+=total_kpts[k2]
total_kpts[k1]+=len(matches)
total_kpts[k2]+=len(matches)
match_indexes[k1][k2]=current_match

for k in kpts.keys():
kpts[k] = np.round(np.concatenate(kpts[k], axis=0))

unique_kpts = {}
unique_match_idxs = {}
out_match = defaultdict(dict)
for k in kpts.keys():
uniq_kps, uniq_reverse_idxs = torch.unique(torch.from_numpy(kpts[k]),dim=0, return_inverse=True)
unique_match_idxs[k] = uniq_reverse_idxs
unique_kpts[k] = uniq_kps.numpy()
for k1, group in match_indexes.items():
for k2, m in group.items():
m2 = deepcopy(m)
m2[:,0] = unique_match_idxs[k1][m2[:,0]]
m2[:,1] = unique_match_idxs[k2][m2[:,1]]
mkpts = np.concatenate([unique_kpts[k1][ m2[:,0]],
unique_kpts[k2][ m2[:,1]],
],
axis=1)
unique_idxs_current = get_unique_idxs(torch.from_numpy(mkpts), dim=0)
m2_semiclean = m2[unique_idxs_current]
unique_idxs_current1 = get_unique_idxs(m2_semiclean[:, 0], dim=0)
m2_semiclean = m2_semiclean[unique_idxs_current1]
unique_idxs_current2 = get_unique_idxs(m2_semiclean[:, 1], dim=0)
m2_semiclean2 = m2_semiclean[unique_idxs_current2]
out_match[k1][k2] = m2_semiclean2.numpy()

with h5py.File(fr'{output_dir}\keypoints.h5', mode='w') as f_kp:
for k, kpts1 in unique_kpts.items():
f_kp[k] = kpts1

with h5py.File(fr'{output_dir}\matches.h5', mode='w') as f_match:
for k1, gr in out_match.items():
group = f_match.require_group(k1)
for k2, match in gr.items():
group[k2] = match

try:
os.remove(f"{output_dir}/database.db")
except:
pass

import_into_colmap(
image_dir,
feature_dir=f"{output_dir}",
database_path=f"{output_dir}/database.db")


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='LOFTR and Roma matchers to multi-view. Assign same index to keypoints in different images with distance < 1 px')
parser.add_argument('-i', '--input_dir', type=str, help='Path to directory containing databases features.h5 and matches.h5')
parser.add_argument('-o', '--output_dir', type=str, help='Output directory')
parser.add_argument('-d', '--image_dir', type=str, help='Image directory')
parser.add_argument('-e', '--img_ext', type=str, default='.jpg', help='Image extension')

args = parser.parse_args()

input_dir = args.input_dir
output_dir = args.output_dir
image_dir = args.image_dir
img_ext = args.img_ext

LoftrRomaToMultiview(
input_dir,
output_dir,
image_dir,
img_ext,
)

0 comments on commit 8c20c58

Please sign in to comment.