Skip to content

Commit

Permalink
Remove GDAL and RichDEM dependancy from tests (#675)
Browse files Browse the repository at this point in the history
  • Loading branch information
vschaffn authored Jan 23, 2025
1 parent c316ff1 commit 5c5639d
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 313 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ xdem/_version.py

# Example data downloaded/produced during tests
examples/data/
tests/test_data/

doc/source/basic_examples/
doc/source/advanced_examples/
Expand Down
32 changes: 9 additions & 23 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ ifndef VENV
VENV = "venv"
endif

# Python version requirement
PYTHON_VERSION_REQUIRED = 3.10

# Python global variables definition
PYTHON_VERSION_MIN = 3.10
# Set PYTHON if not defined in command line
# Example: PYTHON="python3.10" make venv to use python 3.10 for the venv
# By default the default python3 of the system.
ifndef PYTHON
# Try to find python version required
PYTHON = "python$(PYTHON_VERSION_REQUIRED)"
PYTHON = "python3"
endif
PYTHON_CMD=$(shell command -v $(PYTHON))

PYTHON_VERSION_CUR=$(shell $(PYTHON_CMD) -c 'import sys; print("%d.%d" % sys.version_info[0:2])')
PYTHON_VERSION_OK=$(shell $(PYTHON_CMD) -c 'import sys; req_ver = tuple(map(int, "$(PYTHON_VERSION_REQUIRED)".split("."))); cur_ver = sys.version_info[0:2]; print(int(cur_ver == req_ver))')
PYTHON_VERSION_CUR=$(shell $(PYTHON_CMD) -c 'import sys; print("%d.%d"% sys.version_info[0:2])')
PYTHON_VERSION_OK=$(shell $(PYTHON_CMD) -c 'import sys; cur_ver = sys.version_info[0:2]; min_ver = tuple(map(int, "$(PYTHON_VERSION_MIN)".split("."))); print(int(cur_ver >= min_ver))')

############### Check python version supported ############

Expand All @@ -30,7 +31,7 @@ ifeq (, $(PYTHON_CMD))
endif

ifeq ($(PYTHON_VERSION_OK), 0)
$(error "Requires Python version == $(PYTHON_VERSION_REQUIRED). Current version is $(PYTHON_VERSION_CUR)")
$(error "Requires Python version >= $(PYTHON_VERSION_MIN). Current version is $(PYTHON_VERSION_CUR)")
endif

################ MAKE Targets ######################
Expand All @@ -45,19 +46,6 @@ venv: ## Create a virtual environment in 'venv' directory if it doesn't exist
@touch ${VENV}/bin/activate
@${VENV}/bin/python -m pip install --upgrade wheel setuptools pip

.PHONY: install-gdal
install-gdal: ## Install GDAL version matching the system's GDAL via pip
@if command -v gdalinfo >/dev/null 2>&1; then \
GDAL_VERSION=$$(gdalinfo --version | awk '{print $$2}'); \
echo "System GDAL version: $$GDAL_VERSION"; \
${VENV}/bin/pip install gdal==$$GDAL_VERSION; \
else \
echo "Warning: GDAL not found on the system. Proceeding without GDAL."; \
echo "Try installing GDAL by running the following commands depending on your system:"; \
echo "Debian/Ubuntu: sudo apt-get install -y gdal-bin libgdal-dev"; \
echo "Red Hat/CentOS: sudo yum install -y gdal gdal-devel"; \
echo "Then run 'make install-gdal' to proceed with GDAL installation."; \
fi

.PHONY: install
install: venv ## Install xDEM for development (depends on venv)
Expand All @@ -66,8 +54,6 @@ install: venv ## Install xDEM for development (depends on venv)
@test -f .git/hooks/pre-commit || echo "Installing pre-commit hooks"
@test -f .git/hooks/pre-commit || ${VENV}/bin/pre-commit install -t pre-commit
@test -f .git/hooks/pre-push || ${VENV}/bin/pre-commit install -t pre-push
@echo "Attempting to install GDAL..."
@make install-gdal
@echo "xdem installed in development mode in virtualenv ${VENV}"
@echo "To use: source ${VENV}/bin/activate; xdem -h"

Expand Down
2 changes: 0 additions & 2 deletions dev-environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,11 @@ dependencies:
- scikit-learn

# Test dependencies
- gdal # To test against GDAL
- pytest
- pytest-xdist
- pyyaml
- flake8
- pylint
- richdem # To test against richdem

# Doc dependencies
- sphinx
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ test =
flake8
pylint
scikit-learn
richdem
doc =
sphinx
sphinx-book-theme
Expand Down
147 changes: 15 additions & 132 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,142 +1,25 @@
from typing import Callable, List, Union
import os
from typing import Callable

import geoutils as gu
import numpy as np
import pytest
import richdem as rd
from geoutils.raster import RasterType

from xdem._typing import NDArrayf
from xdem.examples import download_and_extract_tarball


@pytest.fixture(scope="session") # type: ignore
def raster_to_rda() -> Callable[[RasterType], rd.rdarray]:
def _raster_to_rda(rst: RasterType) -> rd.rdarray:
"""
Convert geoutils.Raster to richDEM rdarray.
"""
arr = rst.data.filled(rst.nodata).squeeze()
rda = rd.rdarray(arr, no_data=rst.nodata)
rda.geotransform = rst.transform.to_gdal()
return rda

return _raster_to_rda


@pytest.fixture(scope="session") # type: ignore
def get_terrainattr_richdem(raster_to_rda: Callable[[RasterType], rd.rdarray]) -> Callable[[RasterType, str], NDArrayf]:
def _get_terrainattr_richdem(rst: RasterType, attribute: str = "slope_radians") -> NDArrayf:
"""
Derive terrain attribute for DEM opened with geoutils.Raster using RichDEM.
"""
rda = raster_to_rda(rst)
terrattr = rd.TerrainAttribute(rda, attrib=attribute)
terrattr[terrattr == terrattr.no_data] = np.nan
return np.array(terrattr)

return _get_terrainattr_richdem
_TESTDATA_DIRECTORY = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "tests", "test_data"))


@pytest.fixture(scope="session") # type: ignore
def get_terrain_attribute_richdem(
get_terrainattr_richdem: Callable[[RasterType, str], NDArrayf]
) -> Callable[[RasterType, Union[str, list[str]], bool, float, float, float], Union[RasterType, list[RasterType]]]:
def _get_terrain_attribute_richdem(
dem: RasterType,
attribute: Union[str, List[str]],
degrees: bool = True,
hillshade_altitude: float = 45.0,
hillshade_azimuth: float = 315.0,
hillshade_z_factor: float = 1.0,
) -> Union[RasterType, List[RasterType]]:
"""
Derive one or multiple terrain attributes from a DEM using RichDEM.
"""
if isinstance(attribute, str):
attribute = [attribute]

if not isinstance(dem, gu.Raster):
raise ValueError("DEM must be a geoutils.Raster object.")

terrain_attributes = {}

# Check which products should be made to optimize the processing
make_aspect = any(attr in attribute for attr in ["aspect", "hillshade"])
make_slope = any(
attr in attribute
for attr in [
"slope",
"hillshade",
"planform_curvature",
"aspect",
"profile_curvature",
"maximum_curvature",
]
)
make_hillshade = "hillshade" in attribute
make_curvature = "curvature" in attribute
make_planform_curvature = "planform_curvature" in attribute or "maximum_curvature" in attribute
make_profile_curvature = "profile_curvature" in attribute or "maximum_curvature" in attribute

if make_slope:
terrain_attributes["slope"] = get_terrainattr_richdem(dem, "slope_radians")

if make_aspect:
# The aspect of RichDEM is returned in degrees, we convert to radians to match the others
terrain_attributes["aspect"] = np.deg2rad(get_terrainattr_richdem(dem, "aspect"))
# For flat slopes, RichDEM returns a 90° aspect by default, while GDAL return a 180° aspect
# We stay consistent with GDAL
slope_tmp = get_terrainattr_richdem(dem, "slope_radians")
terrain_attributes["aspect"][slope_tmp == 0] = np.pi

if make_hillshade:
# If a different z-factor was given, slopemap with exaggerated gradients.
if hillshade_z_factor != 1.0:
slopemap = np.arctan(np.tan(terrain_attributes["slope"]) * hillshade_z_factor)
else:
slopemap = terrain_attributes["slope"]

azimuth_rad = np.deg2rad(360 - hillshade_azimuth)
altitude_rad = np.deg2rad(hillshade_altitude)

# The operation below yielded the closest hillshade to GDAL (multiplying by 255 did not work)
# As 0 is generally no data for this uint8, we add 1 and then 0.5 for the rounding to occur between
# 1 and 255
terrain_attributes["hillshade"] = np.clip(
1.5
+ 254
* (
np.sin(altitude_rad) * np.cos(slopemap)
+ np.cos(altitude_rad) * np.sin(slopemap) * np.sin(azimuth_rad - terrain_attributes["aspect"])
),
0,
255,
).astype("float32")

if make_curvature:
terrain_attributes["curvature"] = get_terrainattr_richdem(dem, "curvature")

if make_planform_curvature:
terrain_attributes["planform_curvature"] = get_terrainattr_richdem(dem, "planform_curvature")

if make_profile_curvature:
terrain_attributes["profile_curvature"] = get_terrainattr_richdem(dem, "profile_curvature")

# Convert the unit if wanted.
if degrees:
for attr in ["slope", "aspect"]:
if attr not in terrain_attributes:
continue
terrain_attributes[attr] = np.rad2deg(terrain_attributes[attr])

output_attributes = [terrain_attributes[key].reshape(dem.shape) for key in attribute]
def get_test_data_path() -> Callable[[str], str]:
def _get_test_data_path(filename: str, overwrite: bool = False) -> str:
"""Get file from test_data"""
download_and_extract_tarball(dir="test_data", target_dir=_TESTDATA_DIRECTORY, overwrite=overwrite)
file_path = os.path.join(_TESTDATA_DIRECTORY, filename)

if isinstance(dem, gu.Raster):
output_attributes = [
gu.Raster.from_array(attr, transform=dem.transform, crs=dem.crs, nodata=-99999)
for attr in output_attributes
]
if not os.path.exists(file_path):
if overwrite:
raise FileNotFoundError(f"The file {filename} was not found in the test_data directory.")
file_path = _get_test_data_path(filename, overwrite=True)

return output_attributes if len(output_attributes) > 1 else output_attributes[0]
return file_path

return _get_terrain_attribute_richdem
return _get_test_data_path
54 changes: 4 additions & 50 deletions tests/test_coreg/test_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import os.path
import warnings

import geopandas as gpd
Expand All @@ -11,7 +12,6 @@
import pytransform3d
import rasterio as rio
from geoutils import Raster, Vector
from geoutils._typing import NDArrayNum
from geoutils.raster import RasterType
from geoutils.raster.geotransformations import _translate
from scipy.ndimage import binary_dilation
Expand Down Expand Up @@ -42,53 +42,6 @@ def load_examples(crop: bool = True) -> tuple[RasterType, RasterType, Vector]:
return reference_dem, to_be_aligned_dem, glacier_mask


def gdal_reproject_horizontal_shift_samecrs(filepath_example: str, xoff: float, yoff: float) -> NDArrayNum:
"""
Reproject horizontal shift in same CRS with GDAL for testing purposes.
:param filepath_example: Path to raster file.
:param xoff: X shift in georeferenced unit.
:param yoff: Y shift in georeferenced unit.
:return: Reprojected shift array in the same CRS.
"""

from osgeo import gdal, gdalconst

# Open source raster from file
src = gdal.Open(filepath_example, gdalconst.GA_ReadOnly)

# Create output raster in memory
driver = "MEM"
method = gdal.GRA_Bilinear
drv = gdal.GetDriverByName(driver)
dest = drv.Create("", src.RasterXSize, src.RasterYSize, 1, gdal.GDT_Float32)
proj = src.GetProjection()
ndv = src.GetRasterBand(1).GetNoDataValue()
dest.SetProjection(proj)

# Shift the horizontally shifted geotransform
gt = src.GetGeoTransform()
gtl = list(gt)
gtl[0] += xoff
gtl[3] += yoff
dest.SetGeoTransform(tuple(gtl))

# Copy the raster metadata of the source to dest
dest.SetMetadata(src.GetMetadata())
dest.GetRasterBand(1).SetNoDataValue(ndv)
dest.GetRasterBand(1).Fill(ndv)

# Reproject with resampling
gdal.ReprojectImage(src, dest, proj, proj, method)

# Extract reprojected array
array = dest.GetRasterBand(1).ReadAsArray().astype("float32")
array[array == ndv] = np.nan

return array


class TestAffineCoreg:

ref, tba, outlines = load_examples() # Load example reference, to-be-aligned and mask.
Expand Down Expand Up @@ -121,7 +74,7 @@ class TestAffineCoreg:
"xoff_yoff",
[(ref.res[0], ref.res[1]), (10 * ref.res[0], 10 * ref.res[1]), (-1.2 * ref.res[0], -1.2 * ref.res[1])],
) # type: ignore
def test_reproject_horizontal_shift_samecrs__gdal(self, xoff_yoff: tuple[float, float]) -> None:
def test_reproject_horizontal_shift_samecrs__gdal(self, xoff_yoff: tuple[float, float], get_test_data_path) -> None:
"""Check that the same-CRS reprojection based on SciPy (replacing Rasterio due to subpixel errors)
is accurate by comparing to GDAL."""

Expand All @@ -135,7 +88,8 @@ def test_reproject_horizontal_shift_samecrs__gdal(self, xoff_yoff: tuple[float,
)

# Reproject with GDAL
output2 = gdal_reproject_horizontal_shift_samecrs(filepath_example=ref.filename, xoff=xoff, yoff=yoff)
path_output2 = get_test_data_path(os.path.join("gdal", f"shifted_reprojected_xoff{xoff}_yoff{yoff}.tif"))
output2 = Raster(path_output2).data.data

# Reproject and NaN propagation is exactly the same for shifts that are a multiple of pixel resolution
if xoff % ref.res[0] == 0 and yoff % ref.res[1] == 0:
Expand Down
Loading

0 comments on commit 5c5639d

Please sign in to comment.