Skip to content

Commit

Permalink
Merge branch 'refactor/validate_data_input' into refactor/virtualfile_in
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman committed Feb 22, 2025
2 parents 80a178d + 9672f05 commit d7ca533
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 36 deletions.
12 changes: 6 additions & 6 deletions pygmt/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class GMTDataArrayAccessor:
(<GridRegistration.GRIDLINE: 0>, <GridType.GEOGRAPHIC: 1>)
"""

def __init__(self, xarray_obj):
def __init__(self, xarray_obj: xr.DataArray):
self._obj = xarray_obj

# Default to Gridline registration and Cartesian grid type
Expand All @@ -137,19 +137,19 @@ def __init__(self, xarray_obj):
# two columns of the shortened summary information of grdinfo.
if (_source := self._obj.encoding.get("source")) and Path(_source).exists():
with contextlib.suppress(ValueError):
self._registration, self._gtype = map(
self._registration, self._gtype = map( # type: ignore[assignment]
int, grdinfo(_source, per_column="n").split()[-2:]
)

@property
def registration(self):
def registration(self) -> GridRegistration:
"""
Grid registration type :class:`pygmt.enums.GridRegistration`.
"""
return self._registration

@registration.setter
def registration(self, value):
def registration(self, value: GridRegistration | int):
# TODO(Python>=3.12): Simplify to `if value not in GridRegistration`.
if value not in GridRegistration.__members__.values():
msg = (
Expand All @@ -160,14 +160,14 @@ def registration(self, value):
self._registration = GridRegistration(value)

@property
def gtype(self):
def gtype(self) -> GridType:
"""
Grid coordinate system type :class:`pygmt.enums.GridType`.
"""
return self._gtype

@gtype.setter
def gtype(self, value):
def gtype(self, value: GridType | int):
# TODO(Python>=3.12): Simplify to `if value not in GridType`.
if value not in GridType.__members__.values():
msg = (
Expand Down
74 changes: 44 additions & 30 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pathlib import Path
from typing import Any, Literal

import numpy as np
import xarray as xr
from pygmt.encodings import charset
from pygmt.exceptions import GMTInvalidInput
Expand All @@ -39,11 +40,21 @@
"ISO-8859-15",
"ISO-8859-16",
]
# Type hints for the list of possible data kinds.
Kind = Literal[
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
]


def _validate_data_input( # noqa: PLR0912
data=None, x=None, y=None, z=None, required_data=True, required_cols=2, kind=None
):
def _validate_data_input(
data=None,
x=None,
y=None,
z=None,
required_data: bool = True,
required_cols: int = 2,
kind: Kind | None = None,
) -> None:
"""
Check if the combination of data/x/y/z is valid.
Expand Down Expand Up @@ -76,23 +87,23 @@ def _validate_data_input( # noqa: PLR0912
>>> _validate_data_input(data=data, required_cols=3, kind="matrix")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given.
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
>>> _validate_data_input(
... data=pd.DataFrame(data, columns=["x", "y"]),
... required_cols=3,
... kind="vectors",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given.
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
>>> _validate_data_input(
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
... kind="vectors",
... required_cols=3,
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given.
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
>>> _validate_data_input(data="infile", x=[1, 2, 3])
Traceback (most recent call last):
...
Expand All @@ -115,42 +126,49 @@ def _validate_data_input( # noqa: PLR0912
GMTInvalidInput
If the data input is not valid.
"""
if kind is None:
kind = data_kind(data, required=required_data)

# Check if too much data is provided.
if data is not None and any(v is not None for v in (x, y, z)):
msg = "Too much data. Use either data or x/y/z."
raise GMTInvalidInput(msg)

# Determine the data kind if not provided.
kind = kind or data_kind(data, required=required_data)

# Check based on the data kind.
match kind:
case "empty":
if x is None and y is None: # Both x and y are None.
case "empty": # data is given via a series of vectors like x/y/z.
if x is None and y is None:
msg = "No input data provided."
raise GMTInvalidInput(msg)
if x is None or y is None: # Either x or y is None.
if x is None or y is None:
msg = "Must provide both x and y."
raise GMTInvalidInput(msg)
if required_cols >= 3 and z is None:
# Both x and y are not None, now check z.
msg = "Must provide x, y, and z."
raise GMTInvalidInput(msg)
case "matrix": # 2-D numpy.ndarray
if (actual_cols := data.shape[1]) < required_cols:
msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given."
msg = (
f"Need at least {required_cols} columns but {actual_cols} column(s) "
"are given."
)
raise GMTInvalidInput(msg)
case "vectors":
# The if-else block should match the code in the virtualfile_in function.
if hasattr(data, "items") and not hasattr(data, "to_frame"):
# Dict, pd.DataFrame, xr.Dataset
arrays = [array for _, array in data.items()]
if (actual_cols := len(arrays)) < required_cols:
msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given."
raise GMTInvalidInput(msg)

# Loop over columns to make sure they're not None
for idx, array in enumerate(arrays[:required_cols]):
if array is None:
msg = f"data needs {required_cols} columns but the {idx} column is None."
raise GMTInvalidInput(msg)
# Dict, pandas.DataFrame, or xarray.Dataset, but not pd.Series.
_data = [array for _, array in data.items()]
else:
# Python list, tuple, numpy.ndarray, and pandas.Series types
_data = np.atleast_2d(np.asanyarray(data).T)

# Check if the number of columns is sufficient.
if (actual_cols := len(_data)) < required_cols:
msg = (
f"Need at least {required_cols} columns but {actual_cols} "
"column(s) are given."
)
raise GMTInvalidInput(msg)


def _is_printable_ascii(argstr: str) -> bool:
Expand Down Expand Up @@ -269,11 +287,7 @@ def _check_encoding(argstr: str) -> Encoding:
return "ISOLatin1+"


def data_kind(
data: Any, required: bool = True
) -> Literal[
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
]:
def data_kind(data: Any, required: bool = True) -> Kind:
r"""
Check the kind of data that is provided to a module.
Expand Down

0 comments on commit d7ca533

Please sign in to comment.