Skip to content

Commit

Permalink
Merge branch 'develop' into feature/irregular_to_basis_mixed_effects_…
Browse files Browse the repository at this point in the history
…method
  • Loading branch information
pcuestas committed Jun 16, 2024
2 parents 6ad8998 + cc9ef22 commit 6ab190d
Show file tree
Hide file tree
Showing 4 changed files with 926 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/modules/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ The following functions are used to make synthetic functional datasets:
.. autosummary::
:toctree: autosummary

skfda.datasets.euler_maruyama
skfda.datasets.make_gaussian
skfda.datasets.make_gaussian_process
skfda.datasets.make_sinusoidal_process
Expand Down
2 changes: 2 additions & 0 deletions skfda/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"fetch_bone_density",
],
"_samples_generators": [
"euler_maruyama",
"make_gaussian",
"make_gaussian_process",
"make_multimodal_landmarks",
Expand Down Expand Up @@ -54,6 +55,7 @@
fetch_weather as fetch_weather,
)
from ._samples_generators import (
euler_maruyama as euler_maruyama,
make_gaussian as make_gaussian,
make_gaussian_process as make_gaussian_process,
make_multimodal_landmarks as make_multimodal_landmarks,
Expand Down
241 changes: 236 additions & 5 deletions skfda/datasets/_samples_generators.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,255 @@
from __future__ import annotations

import itertools
from typing import Callable, Sequence, Union
from typing import Any, Callable, Sequence, Union

import numpy as np
import scipy.integrate
from scipy.stats import multivariate_normal
from typing_extensions import Protocol

from .._utils import _cartesian_product, _to_grid_points, normalize_warping
from ..misc.covariances import Brownian, CovarianceLike, _execute_covariance
from ..misc.validation import validate_random_state
from ..representation import FDataGrid
from ..representation.interpolation import SplineInterpolation
from ..typing._base import DomainRangeLike, GridPointsLike, RandomStateLike
from ..typing._numpy import NDArrayFloat
from ..typing._numpy import ArrayLike, NDArrayFloat

# Callable taking one array and returning an array.
# NOTE(review): presumably evaluates a mean function at the given points;
# confirm against the callers that accept ``MeanLike``.
MeanCallable = Callable[[np.ndarray], np.ndarray]
# Callable taking two arrays and returning an array.
# NOTE(review): presumably a covariance kernel evaluated on two sets of
# points; confirm against its users.
CovarianceCallable = Callable[[np.ndarray, np.ndarray], np.ndarray]

# Accepted ways of giving a mean: a scalar, an array or a ``MeanCallable``.
MeanLike = Union[float, NDArrayFloat, MeanCallable]
# Signature of a callable SDE coefficient (drift or diffusion) accepted by
# ``euler_maruyama``: it is called as ``term(t, x)`` with the current time and
# the array holding the current state of every trajectory.
SDETerm = Callable[[float, NDArrayFloat], NDArrayFloat]


class InitialValueGenerator(Protocol):
    """Class to represent SDE initial value generators.

    This is intended to be an interface compatible with the rvs method of
    SciPy distributions.

    ``euler_maruyama`` invokes the generator with the keyword arguments
    ``size`` (the number of trajectories requested) and ``random_state``,
    and uses the returned array as the starting values of the trajectories.
    A one-dimensional result is interpreted as one scalar value per
    trajectory.
    """

    def __call__(
        self,
        size: int,
        random_state: RandomStateLike,
    ) -> NDArrayFloat:
        """Interface of initial value generator.

        Args:
            size: Number of initial values to generate.
            random_state: Random state used to draw the values.

        Returns:
            Array with the generated initial values.
        """


def euler_maruyama(  # noqa: WPS210
    initial_condition: ArrayLike | InitialValueGenerator,
    n_grid_points: int = 100,
    drift: SDETerm | ArrayLike | None = None,
    diffusion: SDETerm | ArrayLike | None = None,
    n_samples: int | None = None,
    start: float = 0.0,  # noqa: WPS358 -- Distinguish float from integer
    stop: float = 1.0,
    diffusion_matricial_term: bool = True,
    random_state: RandomStateLike = None,
) -> FDataGrid:
    r"""Numerical integration of an Itô SDE using the Euler-Maruyama scheme.

    An SDE can be expressed with the following formula

    .. math::

        d\mathbf{X}(t) = \mathbf{F}(t, \mathbf{X}(t))dt + \mathbf{G}(t,
        \mathbf{X}(t))d\mathbf{W}(t).

    In this equation, :math:`\mathbf{X} = (X^{(1)}, X^{(2)}, ... , X^{(q)})
    \in \mathbb{R}^q` is a vector that represents the state of the stochastic
    process. The function :math:`\mathbf{F}(t, \mathbf{X}) = (F^{(1)}(t,
    \mathbf{X}), ..., F^{(q)}(t, \mathbf{X}))` is called drift and refers
    to the deterministic component of the equation. The function
    :math:`\mathbf{G} (t, \mathbf{X}) = (G^{i, j}(t, \mathbf{X}))_{i=1, j=1}
    ^{q, m}` is denoted as the diffusion term and refers to the stochastic
    component of the evolution. :math:`\mathbf{W}(t)` refers to a Wiener
    process (Standard Brownian motion) of dimension :math:`m`. Finally,
    :math:`q` refers to the dimension of the variable :math:`\mathbf{X}`
    (dimension of the codomain) and :math:`m` to the dimension of the noise.

    Euler-Maruyama's method computes the approximated solution using the
    Markov chain

    .. math::

        X_{n + 1}^{(i)} = X_n^{(i)} + F^{(i)}(t_n, \mathbf{X}_n)\Delta t_n +
        \sum_{j=1}^m G^{i,j}(t_n, \mathbf{X}_n)\sqrt{\Delta t_n} Z_n^j,

    where :math:`X_n^{(i)}` is the approximated value of :math:`X^{(i)}(t_n)`
    and the :math:`\mathbf{Z}_n` are independent, identically distributed
    :math:`m`-dimensional standard normal random variables.

    Args:
        initial_condition: Initial condition of the SDE. It can have one of
            three formats: A starting initial value from which to
            calculate *n_samples* trajectories. An array of initial values.
            For each starting point a trajectory will be calculated. A
            function that generates random numbers or vectors. It should
            have two parameters called size and random_state and it should
            return an array.
        n_grid_points: The total number of points of evaluation. It must
            be at least 1.
        drift: Drift coefficient (:math:`F(t,\mathbf{X})` in the equation).
            Either a callable ``drift(t, x)`` or a constant. If omitted,
            the drift is zero.
        diffusion: Diffusion coefficient (:math:`G(t,\mathbf{X})` in the
            equation). Either a callable ``diffusion(t, x)`` or a
            constant. If omitted, it is the identity matrix when
            *diffusion_matricial_term* applies and ``1.0`` otherwise.
        n_samples: Number of trajectories integrated. It is mandatory when
            *initial_condition* is a function.
        start: Starting time of the trajectories.
        stop: Ending time of the trajectories.
        diffusion_matricial_term: True if the diffusion coefficient is a
            matrix, in which case it is multiplied by the noise vector as
            a matrix-vector product and the dimension of the noise is the
            size of its last axis. Otherwise the diffusion and the noise
            are multiplied elementwise. It is ignored (treated as False)
            when the codomain has dimension one.
        random_state: Random state.

    Returns:
        :class:`FDataGrid` object comprising all the trajectories.

    Raises:
        ValueError: If *n_grid_points* is smaller than one, if
            *initial_condition* is a function and *n_samples* is not
            given, or if the initial values have more than two
            dimensions.

    See also:
        :func:`make_gaussian_process`: Simpler function for generating
        Gaussian processes.

    Examples:
        Example of the use of euler_maruyama for an Ornstein-Uhlenbeck process
        that has the equation:

        .. math::

            dX(t) = -A(X(t) - \mu)dt + BdW(t)

        >>> from scipy.stats import norm
        >>> A = 1
        >>> mu = 3
        >>> B = 0.5
        >>> def ou_drift(t: float, x: np.ndarray) -> np.ndarray:
        ...     return -A * (x - mu)
        >>> initial_condition = norm().rvs
        >>>
        >>> trajectories = euler_maruyama(
        ...     initial_condition=initial_condition,
        ...     n_samples=10,
        ...     drift=ou_drift,
        ...     diffusion=B,
        ... )

    """
    # Fail early with a clear message instead of the obscure NumPy error
    # ("negative dimensions") that an empty grid would trigger later on.
    if n_grid_points < 1:
        raise ValueError(
            "Invalid n_grid_points. At least one grid point is required.",
        )

    random_state = validate_random_state(random_state)

    if callable(initial_condition):
        if n_samples is None:
            raise ValueError(
                "Invalid initial conditions. If a function is given, the "
                "n_samples argument must be included.",
            )

        # asarray lets generators that return plain sequences work too.
        initial_values = np.asarray(
            initial_condition(
                size=n_samples,
                random_state=random_state,
            ),
        )
    elif n_samples is None:
        # One trajectory per given starting point.
        initial_values = np.atleast_1d(initial_condition)
    else:
        # A single starting point replicated for every trajectory.
        starting_point = np.atleast_1d(initial_condition)
        initial_values = (
            starting_point
            * np.ones((n_samples, len(starting_point)))
        )

    if initial_values.ndim == 1:
        # Scalar-valued process: one value per trajectory.
        initial_values = initial_values[:, np.newaxis]
    elif initial_values.ndim > 2:
        raise ValueError(
            "Invalid initial conditions. Each of the starting points "
            "must be a flat array.",
        )
    (n_samples, dim_codomain) = initial_values.shape

    # With a scalar state the matrix product reduces to a plain product.
    if dim_codomain == 1:
        diffusion_matricial_term = False

    if callable(drift):
        drift_function = drift
    else:
        # Constant coefficients are converted once, outside the time loop.
        drift_value = np.atleast_1d(
            0.0 if drift is None else drift,  # noqa: WPS358
        )

        def constant_drift(  # noqa: WPS430 -- We need internal functions
            t: float,
            x: NDArrayFloat,
        ) -> NDArrayFloat:
            return drift_value

        drift_function = constant_drift

    if callable(diffusion):
        diffusion_function = diffusion
    else:
        if diffusion is None:
            diffusion = (
                np.eye(dim_codomain) if diffusion_matricial_term else 1.0
            )
        diffusion_value = np.atleast_1d(diffusion)

        def constant_diffusion(  # noqa: WPS430 -- We need internal functions
            t: float,
            x: NDArrayFloat,
        ) -> NDArrayFloat:
            return diffusion_value

        diffusion_function = constant_diffusion

    def vector_diffusion_times_noise(  # noqa: WPS430 We need internal functions
        t_n: float,
        x_n: NDArrayFloat,
        noise: NDArrayFloat,
    ) -> NDArrayFloat:
        # Elementwise product: one noise component per state component.
        return diffusion_function(t_n, x_n) * noise

    def matrix_diffusion_times_noise(  # noqa: WPS430 We need internal functions
        t_n: float,
        x_n: NDArrayFloat,
        noise: NDArrayFloat,
    ) -> Any:
        # Matrix-vector product over the last axes, broadcast over samples.
        return np.einsum(
            '...dj, ...j -> ...d',
            diffusion_function(t_n, x_n),
            noise,
        )

    if diffusion_matricial_term:
        diffusion_times_noise = matrix_diffusion_times_noise
        # The noise dimension is the number of columns of the diffusion
        # matrix, probed by evaluating it at the initial state.
        dim_noise = np.shape(diffusion_function(start, initial_values))[-1]
    else:
        diffusion_times_noise = vector_diffusion_times_noise
        dim_noise = dim_codomain

    data_matrix = np.zeros((n_samples, n_grid_points, dim_codomain))
    times = np.linspace(start, stop, n_grid_points)
    delta_t = times[1:] - times[:-1]
    sqrt_delta_t = np.sqrt(delta_t)
    # All the noise is drawn in a single call so that results for a given
    # random state do not depend on how the loop below is written.
    noise = random_state.standard_normal(
        size=(n_samples, n_grid_points - 1, dim_noise),
    )
    data_matrix[:, 0] = initial_values

    for n in range(n_grid_points - 1):
        t_n = times[n]
        x_n = data_matrix[:, n]

        data_matrix[:, n + 1] = (
            x_n
            + delta_t[n] * drift_function(t_n, x_n)
            + diffusion_times_noise(t_n, x_n, noise[:, n])
            * sqrt_delta_t[n]
        )

    return FDataGrid(
        grid_points=times,
        data_matrix=data_matrix,
    )


def make_gaussian(
Expand Down Expand Up @@ -148,7 +379,7 @@ def make_gaussian_process(
)


def make_sinusoidal_process(
def make_sinusoidal_process( # noqa: WPS211
n_samples: int = 15,
n_features: int = 100,
*,
Expand Down Expand Up @@ -273,7 +504,7 @@ def make_multimodal_landmarks(
return modes_location + variation


def make_multimodal_samples(
def make_multimodal_samples( # noqa: WPS211
n_samples: int = 15,
*,
n_modes: int = 1,
Expand Down Expand Up @@ -371,7 +602,7 @@ def make_multimodal_samples(
# Covariance matrix of the samples
cov = mode_std * np.eye(dim_domain)

for i, j, k in itertools.product(
for i, j, k in itertools.product( # noqa: WPS440
range(n_samples),
range(dim_codomain),
range(n_modes),
Expand Down
Loading

0 comments on commit 6ab190d

Please sign in to comment.