Skip to content

Commit

Permalink
Typing and test fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
vnmabus committed Jun 30, 2024
1 parent 4d0679d commit c67cfe4
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 39 deletions.
35 changes: 8 additions & 27 deletions skfda/_utils/ndfunction/_ndfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,18 @@
from ._array_api import Array, DType, Shape, array_namespace
from ._region import Region
from .evaluator import Evaluator
from .extrapolation import ExtrapolationLike, _parse_extrapolation
from .extrapolation import (
AcceptedExtrapolation,
ExtrapolationLike,
_parse_extrapolation,
)
from .typing import GridPointsLike
from .utils.validation import check_array_namespace, check_evaluation_points

T = TypeVar('T', bound='NDFunction')
A = TypeVar('A', bound=Array[Shape, DType])

EvalPointsType: TypeAlias = A | GridPointsLike[A] | Sequence[GridPointsLike[A]]

AcceptedExtrapolation: TypeAlias = (
ExtrapolationLike[A] | None | Literal["default"]
)


# When higher-kinded types are supported in Python, this should be generic on:
# - Array backend (e.g.: NumPy or CuPy arrays, or PyTorch tensors)
Expand Down Expand Up @@ -488,6 +487,9 @@ def __neg__(self) -> Self:
pass


T = TypeVar('T', bound='NDFunction[Any]')


def concatenate(functions: Iterable[T], as_coordinates: bool = False) -> T:
"""
Join samples from an iterable of similar FData objects.
Expand Down Expand Up @@ -518,24 +520,3 @@ def concatenate(functions: Iterable[T], as_coordinates: bool = False) -> T:
)

return first.concatenate(*functions, as_coordinates=as_coordinates)


F = TypeVar("F", covariant=True)


class _CoordinateSequence(Protocol[F]):
"""
Sequence of coordinates.
Note that this represents a sequence of coordinates, not a sequence of
FData objects.
"""

def __getitem__(
self,
key: Union[int, slice],
) -> F:
pass

def __len__(self) -> int:
pass
5 changes: 5 additions & 0 deletions skfda/_utils/ndfunction/extrapolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Any,
Literal,
Mapping,
TypeAlias,
TypeVar,
Union,
overload,
Expand All @@ -35,6 +36,10 @@
Literal["bounds", "exception", "nan", "none", "periodic", "zeros"],
]

AcceptedExtrapolation: TypeAlias = (
ExtrapolationLike[A] | None | Literal["default"]
)


class PeriodicExtrapolation(Evaluator[RealArray]):
"""
Expand Down
24 changes: 16 additions & 8 deletions skfda/representation/_functional_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module for functional data manipulation.
"""
Module for functional data manipulation.
Defines the abstract class that should be implemented by the funtional data
objects of the package and contains some commons methods.
Expand All @@ -8,6 +9,7 @@

from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Expand All @@ -29,6 +31,8 @@
from .._utils.ndfunction import NDFunction, concatenate as concatenate
from .._utils.ndfunction._array_api import Array, DType, Shape, numpy_namespace
from .._utils.ndfunction._region import Region
from .._utils.ndfunction.extrapolation import AcceptedExtrapolation
from .._utils.ndfunction.typing import GridPointsLike
from .._utils.ndfunction.utils.validation import check_grid_points
from ..typing._base import DomainRange, LabelTuple, LabelTupleLike
from ..typing._numpy import (
Expand All @@ -39,6 +43,10 @@
)
from .extrapolation import ExtrapolationLike

if TYPE_CHECKING:
from .basis import Basis, FDataBasis
from .grid import FDataGrid

A = TypeVar('A', bound=Array[Shape, DType])


Expand Down Expand Up @@ -207,8 +215,8 @@ def domain_range(self) -> DomainRange:
@override
@property
def domain(self) -> Region[A]:
lower = np.array([d[0] for d in self.domain_range])
upper = np.array([d[1] for d in self.domain_range])
lower = self.array_backend.asarray([d[0] for d in self.domain_range])
upper = self.array_backend.asarray([d[1] for d in self.domain_range])

return AxisAlignedBox(lower, upper)

Expand Down Expand Up @@ -258,8 +266,8 @@ def shift(
shifts: A | float,
*,
restrict_domain: bool = False,
extrapolation: ExtrapolationLike[A] = "default",
grid_points: GridPointsLike | None = None,
extrapolation: AcceptedExtrapolation[A] = "default",
grid_points: GridPointsLike[A] | None = None,
) -> FDataGrid:
r"""
Perform a shift of the curves.
Expand Down Expand Up @@ -497,7 +505,7 @@ def mean(
@abstractmethod
def to_grid(
self,
grid_points: GridPointsLike | None = None,
grid_points: GridPointsLike[A] | None = None,
) -> FDataGrid:
"""Return the discrete representation of the object.
Expand Down Expand Up @@ -709,7 +717,7 @@ def isna(self) -> NDArrayBool: # noqa: D102

def take( # noqa: WPS238
self,
indices: int | Sequence[int] | NDArrayInt,
indexer: Sequence[int] | NDArrayInt,
allow_fill: bool = False,
fill_value: Self | None = None,
axis: int = 0,
Expand Down Expand Up @@ -762,7 +770,7 @@ def take( # noqa: WPS238
if axis != 0:
raise ValueError(f"Axis must be 0, not {axis}")

arr_indices = np.atleast_1d(indices)
arr_indices = np.atleast_1d(indexer)

if fill_value is None:
fill_value = self.dtype.na_value
Expand Down
8 changes: 4 additions & 4 deletions skfda/tests/test_euler_maruyama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from scipy.stats import multivariate_normal, norm

from skfda._utils.ndfunction.utils._points import grid_points_equal
from skfda.datasets import euler_maruyama
from skfda.typing._numpy import NDArrayFloat

Expand Down Expand Up @@ -570,9 +571,8 @@ def test_grid_points() -> None:
start = 0
stop = 10
n_grid_points = 105
expected_grid_points = np.atleast_2d(
np.linspace(start, stop, n_grid_points),
)
expected_grid_points = np.empty(shape=(1,), dtype=object)
expected_grid_points[0] = np.linspace(start, stop, n_grid_points)

fd = euler_maruyama(
initial_condition,
Expand All @@ -583,7 +583,7 @@ def test_grid_points() -> None:
random_state=random_state,
)

np.testing.assert_array_equal(
assert grid_points_equal(
fd.grid_points,
expected_grid_points,
)
Expand Down

0 comments on commit c67cfe4

Please sign in to comment.