Skip to content

Commit

Permalink
refactor: move osm index cache location from local to global (#174)
Browse files Browse the repository at this point in the history
  • Loading branch information
RaczeQ authored Oct 25, 2024
1 parent 9d41249 commit 38613aa
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 11 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed

- Moved the cached index files of the OSM extract providers from the local `cache` directory to the global user cache directory [#173](https://github.com/kraina-ai/quackosm/issues/173)

## [0.11.2] - 2024-10-14

### Added
Expand Down
32 changes: 24 additions & 8 deletions quackosm/osm_extracts/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Callable, cast

import platformdirs

from quackosm._constants import WGS84_CRS

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -56,14 +58,21 @@ def load_index_decorator(

def inner(function: Callable[[], "GeoDataFrame"]) -> Callable[[], "GeoDataFrame"]:
def wrapper() -> "GeoDataFrame":
cache_file_path = _get_cache_file_path(extract_source)
global_cache_file_path = _get_global_cache_file_path(extract_source)
expected_columns = ["id", "name", "file_name", "parent", "geometry", "area", "url"]

# Check if index exists in cache
if cache_file_path.exists():
if global_cache_file_path.exists():
import geopandas as gpd

index_gdf = gpd.read_file(global_cache_file_path)
elif (local_cache_file_path := _get_local_cache_file_path(extract_source)).exists():
import shutil

import geopandas as gpd

index_gdf = gpd.read_file(cache_file_path)
shutil.copy(local_cache_file_path, global_cache_file_path)
index_gdf = gpd.read_file(global_cache_file_path)
# Download index
else: # pragma: no cover
index_gdf = function()
Expand All @@ -87,14 +96,14 @@ def wrapper() -> "GeoDataFrame":
stacklevel=0,
)
# Invalidate previous cached index
cache_file_path.replace(cache_file_path.with_suffix(".geojson.old"))
global_cache_file_path.replace(global_cache_file_path.with_suffix(".geojson.old"))
# Download index again
index_gdf = wrapper()

# Save index to cache
if not cache_file_path.exists():
cache_file_path.parent.mkdir(parents=True, exist_ok=True)
index_gdf[expected_columns].to_file(cache_file_path, driver="GeoJSON")
if not global_cache_file_path.exists():
global_cache_file_path.parent.mkdir(parents=True, exist_ok=True)
index_gdf[expected_columns].to_file(global_cache_file_path, driver="GeoJSON")

return index_gdf

Expand All @@ -112,7 +121,14 @@ def extracts_to_geodataframe(extracts: list[OpenStreetMapExtract]) -> "GeoDataFr
).set_crs(WGS84_CRS)


def _get_cache_file_path(extract_source: OsmExtractSource) -> Path:
def _get_global_cache_file_path(extract_source: OsmExtractSource) -> Path:
    """Build the path of the index file kept in the global QuackOSM cache.

    The directory is the one ``platformdirs.user_cache_dir`` reports for the
    "QuackOSM" application; the file name is derived from the lower-cased
    value of the given extract source.
    """
    cache_directory = Path(platformdirs.user_cache_dir("QuackOSM"))
    index_file_name = f"{extract_source.value.lower()}_index.geojson"
    return cache_directory / index_file_name


def _get_local_cache_file_path(extract_source: OsmExtractSource) -> Path:
    """Build the path of the index file in the local ``cache`` directory.

    The returned path is relative, with a file name derived from the
    lower-cased value of the given extract source.
    """
    relative_location = f"cache/{extract_source.value.lower()}_index.geojson"
    return Path(relative_location)


Expand Down
6 changes: 3 additions & 3 deletions tests/base/test_osm_extracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
find_smallest_containing_extracts_total,
get_extract_by_query,
)
from quackosm.osm_extracts.extract import _get_cache_file_path, _get_full_file_name_function
from quackosm.osm_extracts.extract import _get_full_file_name_function, _get_global_cache_file_path
from quackosm.osm_extracts.geofabrik import _load_geofabrik_index

ut = TestCase()
Expand Down Expand Up @@ -224,15 +224,15 @@ def test_uncovered_geometry_extract(

def test_proper_cache_saving() -> None:
"""Test if file is saved in cache properly."""
save_path = _get_cache_file_path(OsmExtractSource.geofabrik)
save_path = _get_global_cache_file_path(OsmExtractSource.geofabrik)
loaded_index = _load_geofabrik_index()
assert save_path.exists()
assert len(loaded_index.columns) == 7


def test_wrong_cached_index() -> None:
"""Test if cached file with missing columns is redownloaded again."""
save_path = _get_cache_file_path(OsmExtractSource.geofabrik)
save_path = _get_global_cache_file_path(OsmExtractSource.geofabrik)
column_to_remove = "id"

# load index first time
Expand Down

0 comments on commit 38613aa

Please sign in to comment.