From 3d4baa653639280f96b8b277e34370be48bf145c Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Sat, 22 Feb 2025 10:18:24 +0800 Subject: [PATCH 1/2] TYP: Add type hints to the GMT accessors (#3816) --- pygmt/accessors.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pygmt/accessors.py b/pygmt/accessors.py index 711778b7d11..d0891076d2a 100644 --- a/pygmt/accessors.py +++ b/pygmt/accessors.py @@ -126,7 +126,7 @@ class GMTDataArrayAccessor: (, ) """ - def __init__(self, xarray_obj): + def __init__(self, xarray_obj: xr.DataArray): self._obj = xarray_obj # Default to Gridline registration and Cartesian grid type @@ -137,19 +137,19 @@ def __init__(self, xarray_obj): # two columns of the shortened summary information of grdinfo. if (_source := self._obj.encoding.get("source")) and Path(_source).exists(): with contextlib.suppress(ValueError): - self._registration, self._gtype = map( + self._registration, self._gtype = map( # type: ignore[assignment] int, grdinfo(_source, per_column="n").split()[-2:] ) @property - def registration(self): + def registration(self) -> GridRegistration: """ Grid registration type :class:`pygmt.enums.GridRegistration`. """ return self._registration @registration.setter - def registration(self, value): + def registration(self, value: GridRegistration | int): # TODO(Python>=3.12): Simplify to `if value not in GridRegistration`. if value not in GridRegistration.__members__.values(): msg = ( @@ -160,14 +160,14 @@ def registration(self, value): self._registration = GridRegistration(value) @property - def gtype(self): + def gtype(self) -> GridType: """ Grid coordinate system type :class:`pygmt.enums.GridType`. """ return self._gtype @gtype.setter - def gtype(self, value): + def gtype(self, value: GridType | int): # TODO(Python>=3.12): Simplify to `if value not in GridType`. if value not in GridType.__members__.values(): msg = ( From 9672f05d01f267696ed6fdd46cc26007bf0c1cee Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Fri, 21 Feb 2025 23:46:59 +0800 Subject: [PATCH 2/2] Refactor _validate_data_input --- pygmt/helpers/utils.py | 93 +++++++++++++++++++++++++++--------------- 1 file changed, 59 insertions(+), 34 deletions(-) diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index da9b4ad026a..ff7be026d9d 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -15,6 +15,7 @@ from pathlib import Path from typing import Any, Literal +import numpy as np import xarray as xr from pygmt.encodings import charset from pygmt.exceptions import GMTInvalidInput @@ -39,10 +40,20 @@ "ISO-8859-15", "ISO-8859-16", ] +# Type hints for the list of possible data kinds. +Kind = Literal[ + "arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors" +] def _validate_data_input( - data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None + data=None, + x=None, + y=None, + z=None, + required_z: bool = False, + required_data: bool = True, + kind: Kind | None = None, ) -> None: """ Check if the combination of data/x/y/z is valid. @@ -76,7 +87,7 @@ def _validate_data_input( >>> _validate_data_input(data=data, required_z=True, kind="matrix") Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. + pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given. >>> _validate_data_input( ... data=pd.DataFrame(data, columns=["x", "y"]), ... required_z=True, @@ -84,7 +95,7 @@ def _validate_data_input( ... ) Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. + pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given. >>> _validate_data_input( ... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])), ... required_z=True, @@ -92,7 +103,7 @@ def _validate_data_input( ... ) Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. + pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given. >>> _validate_data_input(data="infile", x=[1, 2, 3]) Traceback (most recent call last): ... @@ -115,34 +126,52 @@ def _validate_data_input( GMTInvalidInput If the data input is not valid. """ - if data is None: # data is None - if x is None and y is None: # both x and y are None - if required_data: # data is not optional + # Check if too much data is provided. + if data is not None and any(v is not None for v in (x, y, z)): + msg = "Too much data. Use either data or x/y/z." + raise GMTInvalidInput(msg) + + # Determine the data kind if not provided. + kind = kind or data_kind(data, required=required_data) + + # Determine the required number of columns based on the required_z flag. + required_cols = 3 if required_z else 1 + + # Check based on the data kind. + match kind: + case "empty": # data is given via a series vectors like x/y/z. + if x is None and y is None: msg = "No input data provided." raise GMTInvalidInput(msg) - elif x is None or y is None: # either x or y is None - msg = "Must provide both x and y." - raise GMTInvalidInput(msg) - if required_z and z is None: # both x and y are not None, now check z - msg = "Must provide x, y, and z." - raise GMTInvalidInput(msg) - else: # data is not None - if x is not None or y is not None or z is not None: - msg = "Too much data. Use either data or x/y/z." - raise GMTInvalidInput(msg) - # check if data has the required z column - if required_z: - msg = "data must provide x, y, and z columns." - if kind == "matrix" and data.shape[1] < 3: + if x is None or y is None: + msg = "Must provide both x and y." + raise GMTInvalidInput(msg) + if required_z and z is None: + msg = "Must provide x, y, and z." + raise GMTInvalidInput(msg) + case "matrix": # 2-D numpy.ndarray + if (actual_cols := data.shape[1]) < required_cols: + msg = ( + f"Need at least {required_cols} columns but {actual_cols} column(s) " + "are given." + ) + raise GMTInvalidInput(msg) + case "vectors": + # The if-else block should match the codes in the virtualfile_in function. + if hasattr(data, "items") and not hasattr(data, "to_frame"): + # Dict, pandas.DataFrame, or xarray.Dataset, but not pd.Series. + _data = [array for _, array in data.items()] + else: + # Python list, tuple, numpy.ndarray, and pandas.Series types + _data = np.atleast_2d(np.asanyarray(data).T) + + # Check if the number of columns is sufficient. + if (actual_cols := len(_data)) < required_cols: + msg = ( + f"Need at least {required_cols} columns but {actual_cols} " + "column(s) are given." + ) raise GMTInvalidInput(msg) - if kind == "vectors": - if hasattr(data, "shape") and ( - (len(data.shape) == 1 and data.shape[0] < 3) - or (len(data.shape) > 1 and data.shape[1] < 3) - ): # np.ndarray or pd.DataFrame - raise GMTInvalidInput(msg) - if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset - raise GMTInvalidInput(msg) def _is_printable_ascii(argstr: str) -> bool: @@ -261,11 +290,7 @@ def _check_encoding(argstr: str) -> Encoding: return "ISOLatin1+" -def data_kind( - data: Any, required: bool = True -) -> Literal[ - "arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors" -]: +def data_kind(data: Any, required: bool = True) -> Kind: r""" Check the kind of data that is provided to a module.