Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of raster functionality #446

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion wntr/gis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
and GIS formatted data and geospatial functions to snap data and find intersections.
"""
from wntr.gis.network import WaterNetworkGIS
from wntr.gis.geospatial import snap, intersect
from wntr.gis.geospatial import snap, intersect, sample_raster

55 changes: 54 additions & 1 deletion wntr/gis/geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import numpy as np


try:
from shapely.geometry import MultiPoint, LineString, Point, shape
has_shapely = True
Expand All @@ -18,6 +19,13 @@
except ModuleNotFoundError:
gpd = None
has_geopandas = False

try:
import rasterio as rio
has_rasterio = True
except ModuleNotFoundError:
rio = None
has_rasterio = False


def snap(A, B, tolerance):
Expand Down Expand Up @@ -280,4 +288,49 @@ def intersect(A, B, B_value=None, include_background=False, background_value=0):

stats.index.name = None

return stats
return stats


def sample_raster(A, filepath, indexes):
    """Sample a raster (e.g., GeoTIFF file) at the point locations given
    by the geometry of GeoDataFrame A.

    The raster at ``filepath`` (a raster file or a virtual raster (VRT), which
    combines multiple raster tiles into a single object) is opened with
    rasterio and sampled at the x/y coordinates of the point geometries in A.
    Coordinates are passed to the raster as-is; no reprojection is performed
    here, so A is expected to use the same coordinate reference system as the
    raster. Sampled values that match the raster's ``nodata`` attribute are
    replaced with NaN.

    Parameters
    ----------
    A : GeoDataFrame
        GeoDataFrame containing point geometries (lines and polygons not yet
        implemented)
    filepath : str
        Path to raster or alternatively a VRT
    indexes : int or list[int]
        Index or indices of the bands to sample, forwarded unchanged to
        rasterio's ``sample`` method

    Returns
    -------
    Series or DataFrame
        When a single band is sampled, a pandas Series of float values with
        an index matching A. When several bands are sampled, a pandas
        DataFrame with an index matching A and one column per sampled band.

    Raises
    ------
    ModuleNotFoundError
        If rasterio is not installed.
    AssertionError
        If A contains a geometry that is not a Point.
    """
    # Further functionality could include the implementation for other
    # geometries (line, polygon), and use of multiprocessing to speed up
    # querying.
    if not has_rasterio:
        raise ModuleNotFoundError('rasterio is required')

    # Raised explicitly (instead of a bare ``assert``) so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if not (A['geometry'].geom_type == "Point").all():
        raise AssertionError('A must contain only Point geometries')

    with rio.open(filepath) as raster:
        xys = zip(A.geometry.x, A.geometry.y)
        # One array per point, holding one value per requested band.
        samples = list(raster.sample(xys, indexes))
        # Read the nodata value while the dataset is still open.
        nodata = raster.nodata

    # Force to float to allow for conversion of nodata to nan.
    values = np.array(samples, dtype=float)

    # A raster without a declared nodata value has nothing to mask.
    if nodata is not None:
        values[values == nodata] = np.nan

    if values.ndim == 2 and values.shape[1] > 1:
        # Several bands were sampled: a Series cannot hold 2-D data, so
        # return one column per band instead.
        n_bands = values.shape[1]
        if isinstance(indexes, (list, tuple)) and len(indexes) == n_bands:
            columns = list(indexes)
        else:
            columns = list(range(1, n_bands + 1))
        return pd.DataFrame(values, index=A.index, columns=columns)

    # Single band (or no points at all): flatten to one value per point.
    return pd.Series(values.reshape(-1), index=A.index)


45 changes: 45 additions & 0 deletions wntr/tests/test_gis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
gpd = None
has_geopandas = False

try:
import rasterio as rio
has_rasterio = True
except ModuleNotFoundError:
rio = None
has_rasterio = False

testdir = dirname(abspath(str(__file__)))
datadir = join(testdir, "networks_for_testing")
ex_datadir = join(testdir, "..", "..", "examples", "networks")
Expand Down Expand Up @@ -70,6 +77,36 @@ def setUpClass(self):
df = pd.DataFrame(point_data)
self.points = gpd.GeoDataFrame(df, crs=None)

# raster testing
points = [
(-120.5, 38.5),
(-120.6, 38.6),
(-120.55, 38.65),
(-120.65, 38.55),
(-120.7, 38.7)
]
point_geometries = [Point(xy) for xy in points]
raster_points = gpd.GeoDataFrame(geometry=point_geometries, crs="EPSG:4326")
raster_points.index = ["A", "B", "C", "D", "E"]
self.raster_points = raster_points

# create example raster
minx, miny, maxx, maxy = raster_points.total_bounds
raster_width = 100
raster_height = 100

x = np.linspace(0, 1, raster_width)
y = np.linspace(0, 1, raster_height)
raster_data = np.cos(y)[:, np.newaxis] * np.sin(x) # arbitrary values

transform = rio.transform.from_bounds(minx, miny, maxx, maxy, raster_width, raster_height)
self.transform = transform

with rio.open(
"test_raster.tif", "w", driver="GTiff", height=raster_height, width=raster_width,
count=1, dtype=raster_data.dtype, crs="EPSG:4326", transform=transform) as dst:
dst.write(raster_data, 1)

@classmethod
def tearDownClass(self):
    """Delete the raster fixture file created for the raster sampling test.

    The fixture is written to the current working directory under the name
    "test_raster.tif"; removing it here keeps repeated test runs from leaving
    stray files behind. A missing file is not an error.
    """
    import os

    try:
        os.remove("test_raster.tif")
    except FileNotFoundError:
        # Nothing to clean up (e.g., setup failed before writing the file).
        pass
Expand Down Expand Up @@ -311,5 +348,13 @@ def test_snap_points_to_lines(self):

assert_frame_equal(pd.DataFrame(snapped_points), expected, check_dtype=False)

def test_sample_raster(self):
    """Sample band 1 of the fixture raster at the fixture points and
    compare the result against known reference values."""
    expected = np.array([0.000000, 0.423443, 0.665369, 0.174402, 0.000000])

    sampled = wntr.gis.sample_raster(self.raster_points, "test_raster.tif", 1)

    # The result must be indexed exactly like the input GeoDataFrame.
    assert np.all(sampled.index == self.raster_points.index)
    # Reference values are given to six decimals, hence the absolute tolerance.
    assert np.allclose(sampled.values, expected, atol=1e-5)

if __name__ == "__main__":
unittest.main()
Loading