Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mean and sum functions #643

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions skfda/representation/_functional_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,7 @@ def mean(
out: None = None,
keepdims: bool = False,
skipna: bool = False,
min_count: int = 0,
) -> T:
"""Compute the mean of all the samples.

Expand All @@ -891,6 +892,9 @@ def mean(
out: Used for compatibility with numpy. Must be None.
keepdims: Used for compatibility with numpy. Must be False.
skipna: Wether the NaNs are ignored or not.
min_count: Number of valid (non NaN) data to have in order
for the a variable to not be NaN when `skipna` is
`True`.

Returns:
A FData object with just one sample representing
Expand All @@ -902,10 +906,7 @@ def mean(
"Not implemented for that parameter combination",
)

return (
self.sum(axis=axis, out=out, keepdims=keepdims, skipna=skipna)
/ self.n_samples
)
return self

@abstractmethod
def to_grid(
Expand Down
49 changes: 39 additions & 10 deletions skfda/representation/basis/_fdatabasis.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,20 +427,49 @@ def sum( # noqa: WPS125
"""
super().sum(axis=axis, out=out, keepdims=keepdims, skipna=skipna)

coefs = (
np.nansum(self.coefficients, axis=0) if skipna
else np.sum(self.coefficients, axis=0)
)

if min_count > 0:
valid = ~np.isnan(self.coefficients)
n_valid = np.sum(valid, axis=0)
coefs[n_valid < min_count] = np.nan
valid_functions = ~self.isna()
valid_coefficients = self.coefficients[valid_functions]

coefs = np.sum(valid_coefficients, axis=0)

return self.copy(
coefficients=coefs,
sample_names=(None,),
)

def mean(  # noqa: WPS125
    self: T,
    *,
    axis: Optional[int] = None,
    dtype: None = None,
    out: None = None,
    keepdims: bool = False,
    skipna: bool = False,
    min_count: int = 0,
) -> T:
    """Return the mean of the samples as a single-sample object.

    The result is the sum of the samples divided by the number of
    samples that ``isna()`` does not flag as missing.

    Args:
        axis: Used for compatibility with numpy. Must be None or 0.
        dtype: Used for compatibility with numpy. Must be None.
        out: Used for compatibility with numpy. Must be None.
        keepdims: Used for compatibility with numpy. Must be False.
        skipna: Whether the NaNs are ignored or not. Forwarded to
            ``sum`` and to the parent implementation.
        min_count: Number of valid (non NaN) data required for a
            variable to not be NaN when `skipna` is `True`.
            NOTE(review): this value is not used by the body below;
            confirm whether it should be forwarded to ``sum``.

    Returns:
        A FDataBasis object with just one sample representing
        the mean of all the samples in the original object.
    """
    # The parent call is kept for its argument validation; its return
    # value is intentionally discarded.
    super().mean(
        axis=axis,
        dtype=dtype,
        out=out,
        keepdims=keepdims,
        skipna=skipna,
    )

    total = self.sum(
        axis=axis,
        out=out,
        keepdims=keepdims,
        skipna=skipna,
    )
    # NOTE(review): the divisor counts non-NaN samples even when
    # ``skipna`` is False -- confirm this is intended.
    n_valid_samples = np.sum(~self.isna())

    return total / n_valid_samples

def var(
self: T,
Expand Down Expand Up @@ -998,7 +1027,7 @@ def isna(self) -> NDArrayBool:
Returns:
na_values (np.ndarray): Positions of NA.
"""
return np.all( # type: ignore[no-any-return]
return np.any( # type: ignore[no-any-return]
np.isnan(self.coefficients),
axis=1,
)
Expand Down
98 changes: 86 additions & 12 deletions skfda/representation/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,60 @@ def _get_points_and_values(self: T) -> Tuple[NDArrayFloat, NDArrayFloat]:

def _get_input_points(self: T) -> GridPoints:
return self.grid_points

def _compute_aggregate(
    self: T,
    operation: str,
    *,
    skipna: bool = False,
    min_count: int = 0,
) -> T:
    """Compute an aggregation of the samples along the sample axis.

    Args:
        operation: Operation to be performed. Can be 'sum', 'mean' or
            'var'.
        skipna: Whether the NaNs are ignored or not. When True the
            NaN-aware numpy variant of the operation is used.
        min_count: Number of valid (non NaN) data required at a
            position for its result to be kept. Positions with fewer
            valid values are set to NaN. Values of 0 or less disable
            the check.

    Returns:
        A copy of the object with just one sample holding the
        aggregation of all the samples in the original object.

    Raises:
        ValueError: If ``operation`` is not one of the supported names.

    """
    plain_funcs = {
        'sum': np.sum,
        'mean': np.mean,
        'var': np.var,
    }
    nan_aware_funcs = {
        'sum': np.nansum,
        'mean': np.nanmean,
        'var': np.nanvar,
    }

    if operation not in plain_funcs:
        raise ValueError(
            "Invalid operation. "
            "Must be one of 'sum', 'mean', or 'var'.",
        )

    agg_func = (nan_aware_funcs if skipna else plain_funcs)[operation]
    data = agg_func(self.data_matrix, axis=0, keepdims=True)

    if min_count > 0:
        # Keep the reduced axis so the count has the same shape as
        # ``data``; a count without it cannot be used to mask ``data``.
        n_valid = np.sum(
            ~np.isnan(self.data_matrix),
            axis=0,
            keepdims=True,
        )
        # ``np.where`` yields a float result, so NaN can be stored even
        # if the aggregate itself came out with an integer dtype.
        data = np.where(n_valid < min_count, np.nan, data)

    return self.copy(
        data_matrix=data,
        sample_names=(None,),
    )

def sum( # noqa: WPS125
self: T,
Expand Down Expand Up @@ -583,20 +637,40 @@ def sum( # noqa: WPS125
"""
super().sum(axis=axis, out=out, keepdims=keepdims, skipna=skipna)

data = (
np.nansum(self.data_matrix, axis=0, keepdims=True) if skipna
else np.sum(self.data_matrix, axis=0, keepdims=True)
)
return self._compute_aggregate(operation='sum', skipna=skipna,
min_count=min_count)

if min_count > 0:
valid = ~np.isnan(self.data_matrix)
n_valid = np.sum(valid, axis=0)
data[n_valid < min_count] = np.nan
def mean( # noqa: WPS125
self: T,
*,
axis: Optional[int] = None,
dtype: None = None,
out: None = None,
keepdims: bool = False,
skipna: bool = False,
min_count: int = 0,
) -> T:
"""Compute the mean of all the samples.

return self.copy(
data_matrix=data,
sample_names=(None,),
)
Args:
axis: Used for compatibility with numpy. Must be None or 0.
dtype: Used for compatibility with numpy. Must be None.
out: Used for compatibility with numpy. Must be None.
keepdims: Used for compatibility with numpy. Must be False.
skipna: Wether the NaNs are ignored or not.
min_count: Number of valid (non NaN) data to have in order
for the a variable to not be NaN when `skipna` is
`True`.

Returns:
A FDataGrid object with just one sample representing
the mean of all the samples in the original object.
"""
super().mean(axis=axis, dtype=dtype, out=out, keepdims=keepdims,
skipna=skipna)

return self._compute_aggregate(operation='mean', skipna=skipna,
min_count=min_count)

def var(self: T, correction: int = 0) -> T:
"""Compute the variance of a set of samples in a FDataGrid object.
Expand Down
54 changes: 54 additions & 0 deletions skfda/representation/irregular.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,60 @@ def sum( # noqa: WPS125
values=sum_values,
sample_names=(None,),
)

def mean(  # noqa: WPS125
    self: T,
    *,
    axis: Optional[int] = None,
    dtype: None = None,
    out: None = None,
    keepdims: bool = False,
    skipna: bool = False,
    min_count: int = 0,
) -> T:
    """Compute the mean of all the samples.

    The mean is evaluated only at the points returned by
    ``_get_common_points_and_values``.

    Args:
        axis: Used for compatibility with numpy. Must be None or 0.
        dtype: Used for compatibility with numpy. Must be None.
        out: Used for compatibility with numpy. Must be None.
        keepdims: Used for compatibility with numpy. Must be False.
        skipna: Whether the NaNs are ignored or not.
        min_count: Number of valid (non NaN) data required for a
            variable to not be NaN when `skipna` is `True`. When
            `skipna` is `False` the count compared against it is the
            number of samples.

    Returns:
        An FDataIrregular object with just one sample representing
        the mean of all the samples in the original object.

    Raises:
        ValueError: If the samples share no common points.
    """
    super().mean(
        axis=axis,
        dtype=dtype,
        out=out,
        keepdims=keepdims,
        skipna=skipna,
    )

    common_points, common_values = self._get_common_points_and_values()

    if len(common_points) == 0:
        raise ValueError("No common points in FDataIrregular object")

    sum_function = np.nansum if skipna else np.sum
    sum_values = sum_function(common_values, axis=0)

    # Counts are held as floats: an integer array cannot store the NaN
    # that marks positions with too few valid values below.
    if skipna:
        count_values = np.sum(
            ~np.isnan(common_values),
            axis=0,
        ).astype(float)
    else:
        count_values = np.full(
            sum_values.shape,
            self.n_samples,
            dtype=float,
        )

    if min_count > 0:
        # Dividing by NaN propagates NaN into the mean at these positions.
        count_values[count_values < min_count] = np.nan

    mean_values = sum_values / count_values

    return FDataIrregular(
        start_indices=np.array([0]),
        points=common_points,
        values=mean_values,
        sample_names=(None,),
    )

def var(self: T, correction: int = 0) -> T:
"""Compute the variance of all the samples.
Expand Down