Skip to content

Commit

Permalink
Documentation updates, try to add geopandas
Browse files Browse the repository at this point in the history
  • Loading branch information
sfalkena committed Sep 18, 2024
1 parent c9df4e4 commit eaf22dc
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
'torch': ('https://pytorch.org/docs/stable', None),
'torchmetrics': ('https://lightning.ai/docs/torchmetrics/stable/', None),
'torchvision': ('https://pytorch.org/vision/stable', None),
'geopandas': ('https://geopandas.org/en/stable/', None),
}

# nbsphinx
Expand Down
24 changes: 12 additions & 12 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def load_file(path: str | GeoDataFrame) -> GeoDataFrame:
"""Load a file from the given path.
Parameters:
path (str or GeoDataFrame): The path to the file or a GeoDataFrame object.
path (str or :class:`GeoDataFrame`): The path to the file or a :class:`GeoDataFrame` object.
Returns:
GeoDataFrame: The loaded file as a GeoDataFrame.
:class:`GeoDataFrame`: The loaded file as a :class:`GeoDataFrame`.
Raises:
None
Expand Down Expand Up @@ -79,11 +79,11 @@ def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None:
def __save_as_gpd_or_feather(
path: str, gdf: GeoDataFrame, driver: str = 'ESRI Shapefile'
) -> None:
"""Save a GeoDataFrame as a file supported by any geopandas driver or as a feather file.
"""Save a :class:`GeoDataFrame` as a file supported by any geopandas driver or as a feather file.
Parameters:
path (str): The path to save the file.
gdf (GeoDataFrame): The GeoDataFrame to be saved.
gdf (:class:`GeoDataFrame`): The :class:`GeoDataFrame` to be saved.
driver (str, optional): The driver to be used for saving the file. Defaults to 'ESRI Shapefile'.
Returns:
Expand All @@ -98,10 +98,10 @@ def __save_as_gpd_or_feather(
def get_chips(self) -> GeoDataFrame:
"""Determines the way to get the extent of the chips (samples) of the dataset.
Should return a GeoDataFrame with the extend of the chips with the columns
Should return a :class:`GeoDataFrame` with the extend of the chips with the columns
geometry, minx, miny, maxx, maxy, mint, maxt, fid. Each row is a chip. It is
expected that every sampler calls this method to get the chips as one of the
last steps in the __init__ method.
last steps in the `__init__` method.
"""

def filter_chips(
Expand All @@ -113,10 +113,10 @@ def filter_chips(
"""Filter the default set of chips in the sampler down to a specific subset.
Args:
filter_by: The file or geodataframe for which the geometries will be used during filtering
filter_by: The file or :class:`GeoDataFrame` for which the geometries will be used during filtering
predicate: Predicate as used in Geopandas sindex.query_bulk
action: What to do with the chips that satisfy the condition by the predicacte.
Can either be "drop" or "keep".
Can either be ``'drop'``or ``'keep'``.
"""
prefilter_leng = len(self.chips)
filtering_gdf = load_file(filter_by).to_crs(self.dataset.crs)
Expand Down Expand Up @@ -281,7 +281,7 @@ def get_chips(self) -> GeoDataFrame:
"""Generate chips from the dataset.
Returns:
GeoDataFrame: A GeoDataFrame containing the generated chips.
:class:`GeoDataFrame`: A :class:`GeoDataFrame` containing the generated chips.
"""
chips = []
for _ in tqdm(range(self.length)):
Expand Down Expand Up @@ -380,7 +380,7 @@ def get_chips(self) -> GeoDataFrame:
"""Generates chips from the given hits.
Returns:
GeoDataFrame: A GeoDataFrame containing the generated chips.
:class:`GeoDataFrame`: A :class:`GeoDataFrame` containing the generated chips.
"""
print('generating samples... ')
self.length = 0
Expand Down Expand Up @@ -459,10 +459,10 @@ def __init__(
self.chips = self.get_chips()

def get_chips(self) -> GeoDataFrame:
"""Generate chips from the hits and return them as a GeoDataFrame.
"""Generate chips from the hits and return them as a :class:`GeoDataFrame`.
Returns:
GeoDataFrame: A GeoDataFrame containing the generated chips.
:class:`GeoDataFrame`: A :class:`GeoDataFrame` containing the generated chips.
"""
generator: Callable[[int], Iterable[int]] = range
if self.shuffle:
Expand Down

0 comments on commit eaf22dc

Please sign in to comment.