diff --git a/src/deep_image_matching/extractors/extractor_base.py b/src/deep_image_matching/extractors/extractor_base.py index a3600a3..dfca7a2 100644 --- a/src/deep_image_matching/extractors/extractor_base.py +++ b/src/deep_image_matching/extractors/extractor_base.py @@ -205,7 +205,7 @@ def extract(self, img: Union[Image, Path, str]) -> np.ndarray: save_features_h5( feature_path, features, - im_path.name, + img.name, as_half=self.features_as_half, ) @@ -435,4 +435,4 @@ def viz_keypoints( [int(cv2.IMWRITE_JPEG_QUALITY), jpg_quality], ) else: - cv2.imwrite(out_path, out) + cv2.imwrite(out_path, out) \ No newline at end of file diff --git a/src/deep_image_matching/extractors/no_extractor.py b/src/deep_image_matching/extractors/no_extractor.py index 1675966..ba06ec0 100644 --- a/src/deep_image_matching/extractors/no_extractor.py +++ b/src/deep_image_matching/extractors/no_extractor.py @@ -43,7 +43,6 @@ def extract(self, img: Union[Image, Path, str]) -> np.ndarray: output_dir = Path(self.config["general"]["output_dir"]) feature_path = output_dir / "features.h5" output_dir.mkdir(parents=True, exist_ok=True) - im_name = im_path.name # Build fake features features = {} @@ -51,6 +50,7 @@ def extract(self, img: Union[Image, Path, str]) -> np.ndarray: features["descriptors"] = np.array([]) features["scores"] = np.array([]) img_obj = Image(im_path) + im_name = img_obj.name # img_obj.read_exif() features["image_size"] = np.array(img_obj.size) features["tile_idx"] = np.array([]) @@ -99,4 +99,4 @@ def _frame2tensor(self, image: np.ndarray, device: str = "cuda"): if __name__ == "__main__": - pass + pass \ No newline at end of file diff --git a/src/deep_image_matching/image_matching.py b/src/deep_image_matching/image_matching.py index f7abe27..54001d8 100644 --- a/src/deep_image_matching/image_matching.py +++ b/src/deep_image_matching/image_matching.py @@ -420,6 +420,7 @@ def rotate_upright_images( pairs = [(item[0].name, item[1].name) for item in self.pairs] path_to_upright_dir = self.output_dir / "upright_images" os.makedirs(path_to_upright_dir, exist_ok=False) + # I guess will break here, use recursive folder iterator images = os.listdir(self.image_dir) logger.info(f"Copying images to {path_to_upright_dir}") @@ -745,4 +746,4 @@ def rotate_back_features(self, feature_path: Path) -> None: if isinstance(v, np.ndarray): grp.create_dataset(k, data=v) - logger.info("Features rotated back.") + logger.info("Features rotated back.") \ No newline at end of file diff --git a/src/deep_image_matching/io/h5_to_db.py b/src/deep_image_matching/io/h5_to_db.py index 4e9ed71..13c53fd 100644 --- a/src/deep_image_matching/io/h5_to_db.py +++ b/src/deep_image_matching/io/h5_to_db.py @@ -25,9 +25,11 @@ import h5py import numpy as np import yaml -from PIL import ExifTags, Image +from PIL import ExifTags +from PIL import Image as PIL_Image from tqdm import tqdm +from ..utils.image import Image from ..utils.database import COLMAPDatabase, image_ids_to_pair_id logger = logging.getLogger("dim") @@ -127,7 +129,7 @@ def get_focal(image_path: Path, err_on_default: bool = False) -> float: This function calculates the focal length based on the maximum size of the image and the EXIF data. If the focal length cannot be determined from the EXIF data, it uses a default prior value. """ - image = Image.open(image_path) + image = PIL_Image.open(image_path) max_size = max(image.size) exif = image.getexif() @@ -156,7 +158,7 @@ def get_focal(image_path: Path, err_on_default: bool = False) -> float: def create_camera(db: Path, image_path: Path, camera_model: str): - image = Image.open(image_path) + image = PIL_Image.open(image_path) width, height = image.size focal = get_focal(image_path) @@ -237,7 +239,7 @@ def add_keypoints(db: Path, h5_path: Path, image_path: Path, camera_options: dic with h5py.File(str(h5_path), "r") as keypoint_f: fname_to_id = {} - k = 0 + created_cameras = {} for filename in tqdm(list(keypoint_f.keys())): keypoints = keypoint_f[filename]["keypoints"].__array__() @@ -247,19 +249,31 @@ def add_keypoints(db: Path, h5_path: Path, image_path: Path, camera_options: dic if filename not in list(grouped_images.keys()): if camera_options["general"]["single_camera"] is False: - camera_id = create_camera(db, path, camera_options["general"]["camera_model"]) + image = Image(path) + if image.camera_id != None: + if image.camera_id not in created_cameras: + camera_id = create_camera( + db, path, camera_options[f"cam{image.camera_id}"]["camera_model"] + ) + created_cameras[image.camera_id] = camera_id + else: + camera_id = created_cameras[image.camera_id] + else: + camera_id = create_camera( + db, path, camera_options["general"]["camera_model"] + ) + created_cameras[camera_id] = camera_id elif camera_options["general"]["single_camera"] is True: - if k == 0: - camera_id = create_camera(db, path, camera_options["general"]["camera_model"]) + if len(created_cameras) == 0: + camera_id = create_camera( + db, path, camera_options["general"]["camera_model"] + ) single_camera_id = camera_id - k += 1 - elif k > 0: + created_cameras[camera_id] = camera_id + else: camera_id = single_camera_id - elif filename in list(grouped_images.keys()): - camera_id = grouped_images[filename]["camera_id"] else: - print('ERROR in h5_to_db.py') - quit() + camera_id = grouped_images[filename]["camera_id"] image_id = db.add_image(filename, camera_id) fname_to_id[filename] = image_id @@ -402,4 +416,4 @@ def add_matches(db, h5_path, fname_to_id): fname_to_id, ) - db.commit() + db.commit() \ No newline at end of file diff --git a/src/deep_image_matching/matchers/loftr.py b/src/deep_image_matching/matchers/loftr.py index bec1eeb..5c8720b 100644 --- a/src/deep_image_matching/matchers/loftr.py +++ b/src/deep_image_matching/matchers/loftr.py @@ -8,6 +8,7 @@ from ..constants import TileSelection, Timer from ..utils.tiling import Tiler +from ..utils.image import Image from .matcher_base import DetectorFreeMatcherBase, tile_selection logger = logging.getLogger("dim") @@ -92,13 +93,15 @@ def _match_pairs( Raises: torch.cuda.OutOfMemoryError: If an out-of-memory error occurs while matching images. """ - - img0_name = img0_path.name - img1_name = img1_path.name + # Could just rename args but they might be used as keyword args elsewhere + img0 = img0_path + img1 = img1_path + img0_name = img0.name + img1_name = img0.name # Load images - image0 = self._load_image_np(img0_path) - image1 = self._load_image_np(img1_path) + image0 = self._load_image_np(img0.path) + image1 = self._load_image_np(img1.path) # Resize images if needed image0_ = self._resize_image(self._quality, image0) @@ -282,4 +285,4 @@ def _frame2tensor(self, image: np.ndarray, device: str = "cpu") -> torch.Tensor: if image.shape[1] > 2: image = K.color.bgr_to_rgb(image) image = K.color.rgb_to_grayscale(image) - return image + return image \ No newline at end of file diff --git a/src/deep_image_matching/matchers/matcher_base.py b/src/deep_image_matching/matchers/matcher_base.py index 8f8fba5..f6615c4 100644 --- a/src/deep_image_matching/matchers/matcher_base.py +++ b/src/deep_image_matching/matchers/matcher_base.py @@ -17,7 +17,7 @@ from ..thirdparty.hloc.extractors.superpoint import SuperPoint from ..thirdparty.LightGlue.lightglue import LightGlue from ..utils.geometric_verification import geometric_verification -from ..utils.image import resize_image +from ..utils.image import resize_image, Image from ..utils.tiling import Tiler from ..visualization import viz_matches_cv2, viz_matches_mpl @@ -205,10 +205,12 @@ def match( self._feature_path = Path(feature_path) # Get features from h5 file + img0 = Image(img0) + img1 = Image(img1) img0_name = img0.name img1_name = img1.name - features0 = get_features(self._feature_path, img0.name) - features1 = get_features(self._feature_path, img1.name) + features0 = get_features(self._feature_path, img0_name) + features1 = get_features(self._feature_path, img1_name) timer_match.update("load h5 features") # Perform matching (on tiles or full images) @@ -328,8 +330,8 @@ def match( self.viz_matches( feature_path, matches_path, - img0, - img1, + img0.path, + img1.path, save_path=viz_dir / f"{img0_name}_{img1_name}.jpg", img_format="jpg", jpg_quality=70, @@ -466,14 +468,14 @@ def viz_matches( jpg_quality = kwargs.get("jpg_quality", 80) hide_matching_track = kwargs.get("hide_matching_track", False) - img0 = Path(img0) - img1 = Path(img1) + img0 = Image(img0) + img1 = Image(img1) img0_name = img0.name img1_name = img1.name # Load images - image0 = load_image_np(img0, as_float=False, grayscale=True) - image1 = load_image_np(img1, as_float=False, grayscale=True) + image0 = load_image_np(img0.path, as_float=False, grayscale=True) + image1 = load_image_np(img1.path, as_float=False, grayscale=True) # Load features and matches features0 = get_features(feature_path, img0_name) @@ -648,8 +650,8 @@ def match( else: self._feature_path = Path(feature_path) - img0 = Path(img0) - img1 = Path(img1) + img0 = Image(img0) + img1 = Image(img1) img0_name = img0.name img1_name = img1.name @@ -672,7 +674,7 @@ def match( features1 = get_features(feature_path, img1_name) # Rescale threshold according the image original image size - img_shape = cv2.imread(str(img0)).shape + img_shape = cv2.imread(img0.path).shape scale_fct = np.floor(max(img_shape) / self.max_tile_size / 2) gv_threshold = self.config["general"]["gv_threshold"] * scale_fct @@ -854,14 +856,14 @@ def viz_matches( logger.warning("interactive_viz is ignored if fast_viz is True") assert save_path is not None, "output_dir must be specified if fast_viz is True" - img0 = Path(img0) - img1 = Path(img1) + img0 = Image(img0) + img1 = Image(img1) img0_name = img0.name img1_name = img1.name # Load images - image0 = load_image_np(img0, self.as_float, self.grayscale) - image1 = load_image_np(img1, self.as_float, self.grayscale) + image0 = load_image_np(img0.path, self.as_float, self.grayscale) + image1 = load_image_np(img1.path, self.as_float, self.grayscale) # Load features and matches features0 = get_features(feature_path, img0_name) @@ -1181,4 +1183,4 @@ def sp2lg(feats: dict) -> dict: def rbd2np(data: dict) -> dict: """Remove batch dimension from elements in data""" - return {k: v[0].cpu().numpy() if isinstance(v, (torch.Tensor, np.ndarray, list)) else v for k, v in data.items()} + return {k: v[0].cpu().numpy() if isinstance(v, (torch.Tensor, np.ndarray, list)) else v for k, v in data.items()} \ No newline at end of file diff --git a/src/deep_image_matching/matchers/roma.py b/src/deep_image_matching/matchers/roma.py index 50cd6fe..84f2bcd 100644 --- a/src/deep_image_matching/matchers/roma.py +++ b/src/deep_image_matching/matchers/roma.py @@ -13,6 +13,7 @@ from ..io.h5 import get_features from ..thirdparty.RoMa.roma import roma_outdoor from ..utils.geometric_verification import geometric_verification +from ..utils.image import Image from ..utils.tiling import Tiler from ..visualization import viz_matches_cv2 from .matcher_base import DetectorFreeMatcherBase, tile_selection @@ -115,8 +116,8 @@ def match( self._feature_path = Path(feature_path) # Get features from h5 file - img0 = Path(img0) - img1 = Path(img1) + img0 = Image(img0) + img1 = Image(img1) img0_name = img0.name img1_name = img1.name @@ -139,7 +140,7 @@ def match( features1 = get_features(feature_path, img1_name) # Rescale threshold according the image original image size - img_shape = cv2.imread(str(img0)).shape + img_shape = cv2.imread(img0.path).shape tile_size = max(self.config["general"]["tile_size"]) scale_fct = np.floor(max(img_shape) / tile_size / 2) gv_threshold = self.config["general"]["gv_threshold"] * scale_fct @@ -172,8 +173,8 @@ def match( def _match_pairs( self, feature_path: Path, - img0_path: Path, - img1_path: Path, + img0: Image, + img1: Image ): """ Perform matching between feature pairs. @@ -187,12 +188,12 @@ def _match_pairs( np.ndarray: Array containing the indices of matched keypoints. """ - img0_name = img0_path.name - img1_name = img1_path.name + img0_name = img0.name + img1_name = img1.name # Run inference - W_A, H_A = Image.open(img0_path).size - W_B, H_B = Image.open(img1_path).size + W_A, H_A = Image.open(img0.path).size + W_B, H_B = Image.open(img1.path).size #for path in [str(img0_path), str(img1_path)]: # image = cv2.imread(path, cv2.IMREAD_UNCHANGED) @@ -200,7 +201,7 @@ def _match_pairs( # image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) # cv2.imwrite(path, image_rgb) - warp, certainty = self.matcher.match(str(img0_path), str(img1_path), device=self._device) + warp, certainty = self.matcher.match(img0.path, img1.path, device=self._device) matches, certainty = self.matcher.sample(warp, certainty) kptsA, kptsB = self.matcher.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B) kptsA, kptsB = kptsA.cpu().numpy(), kptsB.cpu().numpy() @@ -283,8 +284,8 @@ def write_tiles_disk(output_dir: Path, tiles: dict) -> None: timer.update("tile selection") # Read images and resize them if needed - image0 = cv2.imread(str(img0)) - image1 = cv2.imread(str(img1)) + image0 = cv2.imread(img0.path) + image1 = cv2.imread(img1.path) image0 = self._resize_image(self._quality, image0) image1 = self._resize_image(self._quality, image1) @@ -431,4 +432,4 @@ def kps_in_image(kp, img_size, border_thr=2): if not self.keep_tiles: shutil.rmtree(tiles_dir) - return matches + return matches \ No newline at end of file diff --git a/src/deep_image_matching/utils/geometric_verification.py b/src/deep_image_matching/utils/geometric_verification.py index 3e24e26..4bef73f 100644 --- a/src/deep_image_matching/utils/geometric_verification.py +++ b/src/deep_image_matching/utils/geometric_verification.py @@ -1,14 +1,10 @@ import importlib -import logging -from typing import Tuple, Union +from typing import Tuple import cv2 import numpy as np -import pytest -from ..constants import GeometricVerification - -logger = logging.getLogger("dim") +from deep_image_matching import GeometricVerification, logger pydegesac_default_params = { "laf_consistensy_coef": -1.0, @@ -17,7 +13,6 @@ "enable_degeneracy_check": True, } opencv_methods_mapping = { - "NONE": None, "LMEDS": cv2.LMEDS, "RANSAC": cv2.RANSAC, "RHO": cv2.RHO, @@ -32,7 +27,9 @@ def log_result(inlMask: np.ndarray, method: str) -> None: - logger.debug(f"{method} found {inlMask.sum()} inliers ({inlMask.sum()*100/len(inlMask):.2f}%)") + logger.debug( + f"{method} found {inlMask.sum()} inliers ({inlMask.sum()*100/len(inlMask):.2f}%)" + ) def log_error(err: Exception, method: str, fallback: bool = False) -> None: @@ -44,7 +41,7 @@ def log_error(err: Exception, method: str, fallback: bool = False) -> None: def geometric_verification( kpts0: np.ndarray = None, kpts1: np.ndarray = None, - method: Union[str, int, GeometricVerification] = "pydegensac", + method: GeometricVerification = GeometricVerification.PYDEGENSAC, threshold: float = 1, confidence: float = 0.9999, max_iters: int = 10000, @@ -68,24 +65,10 @@ def geometric_verification( - inlMask: a Boolean array that masks the correspondences that were identified as inliers. """ - gv_names = [gv.name for gv in GeometricVerification] - if isinstance(method, str): - try: - method = GeometricVerification[method.upper()] - except KeyError: - raise ValueError(f"Invalid Geometry Verification method. It must be one of {gv_names}") - elif isinstance(method, int): - try: - method = GeometricVerification(method) - except ValueError: - raise ValueError(f"Invalid Geometry Verification method. It must be one of {gv_names}") - if not isinstance(method, GeometricVerification): - raise ValueError( - f"Invalid Geometry Verification method. It must be a GeometricVerification enum, a string with the method name among {gv_names} or an integer corresponding to the method index." - ) - if method == GeometricVerification.NONE: - return None, np.ones(len(kpts0), dtype=bool) + assert isinstance( + method, GeometricVerification + ), "Invalid method. It must be a GeometricVerification enum in GeometricVerification.PYDEGENSAC or GeometricVerification.MAGSAC." fallback = False F = None @@ -100,7 +83,9 @@ def geometric_verification( try: pydegensac = importlib.import_module("pydegensac") except ImportError: - logger.warning("Pydegensac not available. Using RANSAC (OpenCV) for geometric verification.") + logger.warning( + "Pydegensac not available. Using RANSAC (OpenCV) for geometric verification." + ) fallback = True if method == GeometricVerification.PYDEGENSAC and not fallback: @@ -126,7 +111,9 @@ def geometric_verification( if method == GeometricVerification.MAGSAC: try: - F, inliers = cv2.findFundamentalMat(kpts0, kpts1, cv2.USAC_MAGSAC, threshold, confidence, max_iters) + F, inliers = cv2.findFundamentalMat( + kpts0, kpts1, cv2.USAC_MAGSAC, threshold, confidence, max_iters + ) inlMask = (inliers > 0).squeeze() if not quiet: log_result(inlMask, method.name) @@ -140,7 +127,9 @@ def geometric_verification( logger.debug(f"Method was set to {method}, trying to use it from OPENCV...") met = opencv_methods_mapping[method.name] try: - F, inliers = cv2.findFundamentalMat(kpts0, kpts1, met, threshold, confidence, max_iters) + F, inliers = cv2.findFundamentalMat( + kpts0, kpts1, met, threshold, confidence, max_iters + ) inlMask = (inliers > 0).squeeze() if not quiet: log_result(inlMask, method.name) @@ -152,7 +141,9 @@ def geometric_verification( # Use RANSAC as fallback if method == GeometricVerification.RANSAC or fallback: try: - F, inliers = cv2.findFundamentalMat(kpts0, kpts1, cv2.RANSAC, threshold, confidence, max_iters) + F, inliers = cv2.findFundamentalMat( + kpts0, kpts1, cv2.RANSAC, threshold, confidence, max_iters + ) inlMask = (inliers > 0).squeeze() if not quiet: log_result(inlMask, method.name) @@ -163,21 +154,4 @@ def geometric_verification( if not quiet: logger.debug(f"Estiamted Fundamental matrix: \n{F}") - return F, inlMask - - -if __name__ == "__main__": - # Generate random keypoints - rng = np.random.default_rng(12345) - kpts0 = rng.random((100, 2)) - kpts1 = rng.random((100, 2)) - method = GeometricVerification.PYDEGENSAC - F, mask = geometric_verification( - kpts0=kpts0, - kpts1=kpts1, - method=method, - threshold=1, - confidence=0.9999, - max_iters=10000, - ) - print(F) + return F, inlMask \ No newline at end of file diff --git a/src/deep_image_matching/utils/image.py b/src/deep_image_matching/utils/image.py index 5256322..bb5c085 100644 --- a/src/deep_image_matching/utils/image.py +++ b/src/deep_image_matching/utils/image.py @@ -111,20 +111,39 @@ def __init__(self, path: Union[str, Path], id: int = None) -> None: self._exif_data = None self._date_time = None self._focal_length = None - + self._camera_id = None + self._name = None + if "cam" in str(path): + for i, part in enumerate(path.parts): + if part.startswith("cam") and part[3:].isdigit(): + self._camera_id = eval(part[3:]) + rel_path = Path(*path.parts[i:]) + self._name = str(rel_path) + else: + self._name = path.name + try: self.read_exif() except Exception: img = PIL.Image.open(path) self._width, self._height = img.size + + img = PIL.Image.open(path) + self._width, self._height = img.size + def __repr__(self) -> str: """Returns a string representation of the image""" - return f"Image {self._path}" + return f"Image {self.name}" def __str__(self) -> str: """Returns a string representation of the image""" - return f"Image {self._path}" + return f"Image {self.name}" + + @property + def camera_id(self) -> int: + """Returns the camera_id of the image, if defined""" + return self._camera_id @property def id(self) -> int: @@ -136,12 +155,12 @@ def id(self) -> int: @property def name(self) -> str: """Returns the name of the image (including extension)""" - return self._path.name + return self._name @property def stem(self) -> str: """Returns the name of the image (excluding extension)""" - return self._path.stem + return self._name.stem @property def path(self) -> Path: @@ -408,7 +427,7 @@ def __init__(self, img_dir: Path): self.images = [] self.current_idx = 0 i = 0 - all_imgs = [image for image in img_dir.glob("*") if image.suffix in self.IMAGE_EXT] + all_imgs = [image for image in img_dir.rglob("*") if image.suffix in self.IMAGE_EXT] all_imgs.sort() if len(all_imgs) == 0: @@ -478,4 +497,4 @@ def img_paths(self): img_list = ImageList(image_dir) - print("done") + print("done") \ No newline at end of file