Skip to content

Commit

Permalink
Refactor _validate_data_input
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman committed Feb 22, 2025
1 parent 3d4baa6 commit 9672f05
Showing 1 changed file with 59 additions and 34 deletions.
93 changes: 59 additions & 34 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pathlib import Path
from typing import Any, Literal

import numpy as np
import xarray as xr
from pygmt.encodings import charset
from pygmt.exceptions import GMTInvalidInput
Expand All @@ -39,10 +40,20 @@
"ISO-8859-15",
"ISO-8859-16",
]
# Type hints for the list of possible data kinds.
Kind = Literal[
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
]


def _validate_data_input(
data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None
data=None,
x=None,
y=None,
z=None,
required_z: bool = False,
required_data: bool = True,
kind: Kind | None = None,
) -> None:
"""
Check if the combination of data/x/y/z is valid.
Expand Down Expand Up @@ -76,23 +87,23 @@ def _validate_data_input(
>>> _validate_data_input(data=data, required_z=True, kind="matrix")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
>>> _validate_data_input(
... data=pd.DataFrame(data, columns=["x", "y"]),
... required_z=True,
... kind="vectors",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
>>> _validate_data_input(
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
... required_z=True,
... kind="vectors",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
>>> _validate_data_input(data="infile", x=[1, 2, 3])
Traceback (most recent call last):
...
Expand All @@ -115,34 +126,52 @@ def _validate_data_input(
GMTInvalidInput
If the data input is not valid.
"""
if data is None: # data is None
if x is None and y is None: # both x and y are None
if required_data: # data is not optional
# Check if too much data is provided.
if data is not None and any(v is not None for v in (x, y, z)):
msg = "Too much data. Use either data or x/y/z."
raise GMTInvalidInput(msg)

# Determine the data kind if not provided.
kind = kind or data_kind(data, required=required_data)

# Determine the required number of columns based on the required_z flag.
required_cols = 3 if required_z else 1

# Check based on the data kind.
match kind:
case "empty": # data is given via a series vectors like x/y/z.
if x is None and y is None:
msg = "No input data provided."
raise GMTInvalidInput(msg)
elif x is None or y is None: # either x or y is None
msg = "Must provide both x and y."
raise GMTInvalidInput(msg)
if required_z and z is None: # both x and y are not None, now check z
msg = "Must provide x, y, and z."
raise GMTInvalidInput(msg)
else: # data is not None
if x is not None or y is not None or z is not None:
msg = "Too much data. Use either data or x/y/z."
raise GMTInvalidInput(msg)
# check if data has the required z column
if required_z:
msg = "data must provide x, y, and z columns."
if kind == "matrix" and data.shape[1] < 3:
if x is None or y is None:
msg = "Must provide both x and y."
raise GMTInvalidInput(msg)
if required_z and z is None:
msg = "Must provide x, y, and z."
raise GMTInvalidInput(msg)
case "matrix": # 2-D numpy.ndarray
if (actual_cols := data.shape[1]) < required_cols:
msg = (
f"Need at least {required_cols} columns but {actual_cols} column(s) "
"are given."
)
raise GMTInvalidInput(msg)
case "vectors":
# The if-else block should match the codes in the virtualfile_in function.
if hasattr(data, "items") and not hasattr(data, "to_frame"):
# Dict, pandas.DataFrame, or xarray.Dataset, but not pd.Series.
_data = [array for _, array in data.items()]
else:
# Python list, tuple, numpy.ndarray, and pandas.Series types
_data = np.atleast_2d(np.asanyarray(data).T)

# Check if the number of columns is sufficient.
if (actual_cols := len(_data)) < required_cols:
msg = (
f"Need at least {required_cols} columns but {actual_cols} "
"column(s) are given."
)
raise GMTInvalidInput(msg)
if kind == "vectors":
if hasattr(data, "shape") and (
(len(data.shape) == 1 and data.shape[0] < 3)
or (len(data.shape) > 1 and data.shape[1] < 3)
): # np.ndarray or pd.DataFrame
raise GMTInvalidInput(msg)
if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset
raise GMTInvalidInput(msg)


def _is_printable_ascii(argstr: str) -> bool:
Expand Down Expand Up @@ -261,11 +290,7 @@ def _check_encoding(argstr: str) -> Encoding:
return "ISOLatin1+"


def data_kind(
data: Any, required: bool = True
) -> Literal[
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
]:
def data_kind(data: Any, required: bool = True) -> Kind:
r"""
Check the kind of data that is provided to a module.
Expand Down

0 comments on commit 9672f05

Please sign in to comment.