From 38613aaf7c9aeddfb7cfa089e13522a81ad96951 Mon Sep 17 00:00:00 2001 From: Kamil Raczycki Date: Fri, 25 Oct 2024 20:34:16 +0200 Subject: [PATCH] refactor: move osm index cache location from local to global (#174) --- CHANGELOG.md | 4 ++++ quackosm/osm_extracts/extract.py | 32 ++++++++++++++++++++++++-------- tests/base/test_osm_extracts.py | 6 +++--- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d43d3d..e1d3b0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- Moved the location of the OSM extracts providers' index cache from the local `cache` directory to the global user cache directory [#173](https://github.com/kraina-ai/quackosm/issues/173) + ## [0.11.2] - 2024-10-14 ### Added diff --git a/quackosm/osm_extracts/extract.py b/quackosm/osm_extracts/extract.py index 0a4e848..0d567ac 100644 --- a/quackosm/osm_extracts/extract.py +++ b/quackosm/osm_extracts/extract.py @@ -6,6 +6,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, cast +import platformdirs + from quackosm._constants import WGS84_CRS if TYPE_CHECKING: # pragma: no cover @@ -56,14 +58,21 @@ def load_index_decorator( def inner(function: Callable[[], "GeoDataFrame"]) -> Callable[[], "GeoDataFrame"]: def wrapper() -> "GeoDataFrame": - cache_file_path = _get_cache_file_path(extract_source) + global_cache_file_path = _get_global_cache_file_path(extract_source) expected_columns = ["id", "name", "file_name", "parent", "geometry", "area", "url"] # Check if index exists in cache - if cache_file_path.exists(): + if global_cache_file_path.exists(): + import geopandas as gpd + + index_gdf = gpd.read_file(global_cache_file_path) + elif (local_cache_file_path := _get_local_cache_file_path(extract_source)).exists(): + import shutil + import geopandas as gpd - index_gdf = gpd.read_file(cache_file_path) + shutil.copy(local_cache_file_path, global_cache_file_path) + index_gdf = 
gpd.read_file(global_cache_file_path) # Download index else: # pragma: no cover index_gdf = function() @@ -87,14 +96,14 @@ def wrapper() -> "GeoDataFrame": stacklevel=0, ) # Invalidate previous cached index - cache_file_path.replace(cache_file_path.with_suffix(".geojson.old")) + global_cache_file_path.replace(global_cache_file_path.with_suffix(".geojson.old")) # Download index again index_gdf = wrapper() # Save index to cache - if not cache_file_path.exists(): - cache_file_path.parent.mkdir(parents=True, exist_ok=True) - index_gdf[expected_columns].to_file(cache_file_path, driver="GeoJSON") + if not global_cache_file_path.exists(): + global_cache_file_path.parent.mkdir(parents=True, exist_ok=True) + index_gdf[expected_columns].to_file(global_cache_file_path, driver="GeoJSON") return index_gdf @@ -112,7 +121,14 @@ def extracts_to_geodataframe(extracts: list[OpenStreetMapExtract]) -> "GeoDataFr ).set_crs(WGS84_CRS) -def _get_cache_file_path(extract_source: OsmExtractSource) -> Path: +def _get_global_cache_file_path(extract_source: OsmExtractSource) -> Path: + return ( + Path(platformdirs.user_cache_dir("QuackOSM")) + / f"{extract_source.value.lower()}_index.geojson" + ) + + +def _get_local_cache_file_path(extract_source: OsmExtractSource) -> Path: return Path(f"cache/{extract_source.value.lower()}_index.geojson") diff --git a/tests/base/test_osm_extracts.py b/tests/base/test_osm_extracts.py index f5f8676..c8d55fc 100644 --- a/tests/base/test_osm_extracts.py +++ b/tests/base/test_osm_extracts.py @@ -27,7 +27,7 @@ find_smallest_containing_extracts_total, get_extract_by_query, ) -from quackosm.osm_extracts.extract import _get_cache_file_path, _get_full_file_name_function +from quackosm.osm_extracts.extract import _get_full_file_name_function, _get_global_cache_file_path from quackosm.osm_extracts.geofabrik import _load_geofabrik_index ut = TestCase() @@ -224,7 +224,7 @@ def test_uncovered_geometry_extract( def test_proper_cache_saving() -> None: """Test if file is saved 
in cache properly.""" - save_path = _get_cache_file_path(OsmExtractSource.geofabrik) + save_path = _get_global_cache_file_path(OsmExtractSource.geofabrik) loaded_index = _load_geofabrik_index() assert save_path.exists() assert len(loaded_index.columns) == 7 @@ -232,7 +232,7 @@ def test_proper_cache_saving() -> None: def test_wrong_cached_index() -> None: """Test if cached file with missing columns is redownloaded again.""" - save_path = _get_cache_file_path(OsmExtractSource.geofabrik) + save_path = _get_global_cache_file_path(OsmExtractSource.geofabrik) column_to_remove = "id" # load index first time