diff --git a/skfda/_utils/ndfunction/_ndfunction.py b/skfda/_utils/ndfunction/_ndfunction.py index 5b5f9b4ba..80c819f16 100644 --- a/skfda/_utils/ndfunction/_ndfunction.py +++ b/skfda/_utils/ndfunction/_ndfunction.py @@ -30,19 +30,18 @@ from ._array_api import Array, DType, Shape, array_namespace from ._region import Region from .evaluator import Evaluator -from .extrapolation import ExtrapolationLike, _parse_extrapolation +from .extrapolation import ( + AcceptedExtrapolation, + ExtrapolationLike, + _parse_extrapolation, +) from .typing import GridPointsLike from .utils.validation import check_array_namespace, check_evaluation_points -T = TypeVar('T', bound='NDFunction') A = TypeVar('A', bound=Array[Shape, DType]) EvalPointsType: TypeAlias = A | GridPointsLike[A] | Sequence[GridPointsLike[A]] -AcceptedExtrapolation: TypeAlias = ( - ExtrapolationLike[A] | None | Literal["default"] -) - # When higher-kinded types are supported in Python, this should be generic on: # - Array backend (e.g.: NumPy or CuPy arrays, or PyTorch tensors) @@ -488,6 +487,9 @@ def __neg__(self) -> Self: pass +T = TypeVar('T', bound='NDFunction[Any]') + + def concatenate(functions: Iterable[T], as_coordinates: bool = False) -> T: """ Join samples from an iterable of similar FData objects. @@ -518,24 +520,3 @@ def concatenate(functions: Iterable[T], as_coordinates: bool = False) -> T: ) return first.concatenate(*functions, as_coordinates=as_coordinates) - - -F = TypeVar("F", covariant=True) - - -class _CoordinateSequence(Protocol[F]): - """ - Sequence of coordinates. - - Note that this represents a sequence of coordinates, not a sequence of - FData objects. - """ - - def __getitem__( - self, - key: Union[int, slice], - ) -> F: - pass - - def __len__(self) -> int: - pass diff --git a/skfda/_utils/ndfunction/extrapolation.py b/skfda/_utils/ndfunction/extrapolation.py index 1591d6b70..b8c3b61a4 100644 --- a/skfda/_utils/ndfunction/extrapolation.py +++ b/skfda/_utils/ndfunction/extrapolation.py @@ -12,6 +12,7 @@ Any, Literal, Mapping, + TypeAlias, TypeVar, Union, overload, @@ -35,6 +36,10 @@ Literal["bounds", "exception", "nan", "none", "periodic", "zeros"], ] +AcceptedExtrapolation: TypeAlias = ( + ExtrapolationLike[A] | None | Literal["default"] +) + class PeriodicExtrapolation(Evaluator[RealArray]): """ diff --git a/skfda/representation/_functional_data.py b/skfda/representation/_functional_data.py index 8c3558101..2014c333d 100644 --- a/skfda/representation/_functional_data.py +++ b/skfda/representation/_functional_data.py @@ -1,4 +1,5 @@ -"""Module for functional data manipulation. +""" +Module for functional data manipulation. Defines the abstract class that should be implemented by the funtional data objects of the package and contains some commons methods. @@ -8,6 +9,7 @@ from abc import ABC, abstractmethod from typing import ( + TYPE_CHECKING, Any, Callable, Iterable, @@ -29,6 +31,8 @@ from .._utils.ndfunction import NDFunction, concatenate as concatenate from .._utils.ndfunction._array_api import Array, DType, Shape, numpy_namespace from .._utils.ndfunction._region import Region +from .._utils.ndfunction.extrapolation import AcceptedExtrapolation +from .._utils.ndfunction.typing import GridPointsLike from .._utils.ndfunction.utils.validation import check_grid_points from ..typing._base import DomainRange, LabelTuple, LabelTupleLike from ..typing._numpy import ( @@ -39,6 +43,10 @@ ) from .extrapolation import ExtrapolationLike +if TYPE_CHECKING: + from .basis import Basis, FDataBasis + from .grid import FDataGrid + A = TypeVar('A', bound=Array[Shape, DType]) @@ -207,8 +215,8 @@ def domain_range(self) -> DomainRange: @override @property def domain(self) -> Region[A]: - lower = np.array([d[0] for d in self.domain_range]) - upper = np.array([d[1] for d in self.domain_range]) + lower = self.array_backend.asarray([d[0] for d in self.domain_range]) + upper = self.array_backend.asarray([d[1] for d in self.domain_range]) return AxisAlignedBox(lower, upper) @@ -258,8 +266,8 @@ def shift( shifts: A | float, *, restrict_domain: bool = False, - extrapolation: ExtrapolationLike[A] = "default", - grid_points: GridPointsLike | None = None, + extrapolation: AcceptedExtrapolation[A] = "default", + grid_points: GridPointsLike[A] | None = None, ) -> FDataGrid: r""" Perform a shift of the curves. @@ -497,7 +505,7 @@ def mean( @abstractmethod def to_grid( self, - grid_points: GridPointsLike | None = None, + grid_points: GridPointsLike[A] | None = None, ) -> FDataGrid: """Return the discrete representation of the object. @@ -709,7 +717,7 @@ def isna(self) -> NDArrayBool: # noqa: D102 def take( # noqa: WPS238 self, - indices: int | Sequence[int] | NDArrayInt, + indexer: Sequence[int] | NDArrayInt, allow_fill: bool = False, fill_value: Self | None = None, axis: int = 0, @@ -762,7 +770,7 @@ def take( # noqa: WPS238 if axis != 0: raise ValueError(f"Axis must be 0, not {axis}") - arr_indices = np.atleast_1d(indices) + arr_indices = np.atleast_1d(indexer) if fill_value is None: fill_value = self.dtype.na_value diff --git a/skfda/tests/test_euler_maruyama.py b/skfda/tests/test_euler_maruyama.py index e6684add2..9b0fc8a62 100644 --- a/skfda/tests/test_euler_maruyama.py +++ b/skfda/tests/test_euler_maruyama.py @@ -3,6 +3,7 @@ import numpy as np from scipy.stats import multivariate_normal, norm +from skfda._utils.ndfunction.utils._points import grid_points_equal from skfda.datasets import euler_maruyama from skfda.typing._numpy import NDArrayFloat @@ -570,9 +571,8 @@ def test_grid_points() -> None: start = 0 stop = 10 n_grid_points = 105 - expected_grid_points = np.atleast_2d( - np.linspace(start, stop, n_grid_points), - ) + expected_grid_points = np.empty(shape=(1,), dtype=object) + expected_grid_points[0] = np.linspace(start, stop, n_grid_points) fd = euler_maruyama( initial_condition, @@ -583,7 +583,7 @@ def test_grid_points() -> None: random_state=random_state, ) - np.testing.assert_array_equal( + assert grid_points_equal( fd.grid_points, expected_grid_points, )