From 49aff203450f6b2c4d2e2a4781e598cf99a8dc00 Mon Sep 17 00:00:00 2001
From: derrick chambers
Date: Sat, 29 Jul 2023 18:30:30 -0600
Subject: [PATCH] fix chunk

---
 dascore/core/attrs.py  | 10 ++++++++++
 dascore/core/coords.py |  6 +++---
 dascore/utils/chunk.py |  8 ++++----
 dascore/utils/patch.py |  9 +++++----
 dascore/utils/pd.py    |  5 +++--
 5 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/dascore/core/attrs.py b/dascore/core/attrs.py
index c220f60b..ea1f46e0 100644
--- a/dascore/core/attrs.py
+++ b/dascore/core/attrs.py
@@ -10,6 +10,7 @@
 from typing import Annotated, Literal
 from typing import Any
 
+import numpy as np
 import pandas as pd
 from pydantic import ConfigDict, PlainValidator
 from pydantic import Field, model_validator
@@ -277,6 +278,15 @@ def flat_dump(self) -> dict:
         for coord_name, coord in out.pop("coords").items():
             for name, val in coord.items():
                 out[f"{coord_name}_{name}"] = val
+            # ensure step has right type if nullish
+            step_name, start_name = f"{coord_name}_step", f"{coord_name}_min"
+            step, start = out[step_name], out[start_name]
+            if step is None:
+                is_time = isinstance(start, (np.datetime64, np.timedelta64))
+                if is_time:
+                    out[step_name] = np.timedelta64("NaT")
+                elif isinstance(start, (float, np.floating)):
+                    out[step_name] = np.NaN
         return out
 
 
diff --git a/dascore/core/coords.py b/dascore/core/coords.py
index 3fa5352f..d2dfa4e2 100644
--- a/dascore/core/coords.py
+++ b/dascore/core/coords.py
@@ -75,13 +75,13 @@ def ensure_consistent_dtype(cls, value, _info):
         # for some reason all ints are getting converted to floats. This
         # hack just fixes that. TODO: See if this is needed in a few version
         # after pydantic 2.1.1
-        if np.issubdtype(dtype, np.datetime64):
+        if pd.isnull(value):
+            return value
+        elif np.issubdtype(dtype, np.datetime64):
             if _info.field_name == "step":
                 value = dc.to_timedelta64(value)
             else:
                 value = dc.to_datetime64(value)
-        elif pd.isnull(value):
-            return value
         elif np.issubdtype(dtype, np.timedelta64):
             value = dc.to_timedelta64(value)
         # convert numpy numerics back to python
diff --git a/dascore/utils/chunk.py b/dascore/utils/chunk.py
index 56ef1a25..7ae1f908 100644
--- a/dascore/utils/chunk.py
+++ b/dascore/utils/chunk.py
@@ -185,7 +185,7 @@ def _get_continuity_group_number(self, start, stop, step) -> pd.Series:
         group_num = has_gap.astype(np.int64).cumsum()
         return group_num[start.index]
 
-    def _get_sampling_group_num(self, df, tolerance=0.05) -> pd.Series:
+    def _get_sampling_group_num(self, step, tolerance=0.05) -> pd.Series:
         """
         Because sampling can be off a little, this adds some tolerance for
         how sampling affects groups.
@@ -193,7 +193,7 @@ def _get_sampling_group_num(self, df, tolerance=0.05) -> pd.Series:
         Tolerance affects how close samples have to be in order to count as
         the same. 5% is used here.
         """
-        col = df[f"{self._name}_step"].values
+        col = step.values
         sort_args = np.argsort(col)
         sorted_col = col[sort_args]
         roll_forward = np.roll(sorted_col, shift=1)
@@ -201,7 +201,7 @@ def _get_sampling_group_num(self, df, tolerance=0.05) -> pd.Series:
         out_of_threshold = diff > tolerance
         group_number = numpy.cumsum(out_of_threshold)
         # undo sorting
-        out = pd.Series(group_number[np.argsort(sort_args)], index=df.index)
+        out = pd.Series(group_number[np.argsort(sort_args)], index=step.index)
         return out
 
     def _get_duration_overlap(self, duration, start, step, overlap=None):
@@ -351,7 +351,7 @@ def _get_group(self, df, start, stop, step):
         being consistent and group columns matching.
""" cont_g = self._get_continuity_group_number(start, stop, step) - samp_g = self._get_sampling_group_num(df) + samp_g = self._get_sampling_group_num(step) col_g = self._get_col_group(df, cont_g) group_series = [x.astype(str) for x in [samp_g, col_g, cont_g]] group = reduce(lambda x, y: x + "_" + y, group_series) diff --git a/dascore/utils/patch.py b/dascore/utils/patch.py index 74dd0e22..93a29031 100644 --- a/dascore/utils/patch.py +++ b/dascore/utils/patch.py @@ -218,13 +218,14 @@ def patches_to_df( A dataframe with the attrs of each patch converted to a columns plus a field called 'patch' which contains a reference to the patches. """ - if isinstance(patches, dc.BaseSpool): - df = patches._df + + if hasattr(patches, "_df"): + df = patches._df # noqa # Handle spool case elif hasattr(patches, "get_contents"): - return patches.get_contents() + df = patches.get_contents() elif isinstance(patches, pd.DataFrame): - return patches + df = patches else: df = pd.DataFrame([x.flat_dump() for x in scan_patches(patches)]) if df.empty: # create empty df with appropriate columns diff --git a/dascore/utils/pd.py b/dascore/utils/pd.py index 9048ee98..8ac89139 100644 --- a/dascore/utils/pd.py +++ b/dascore/utils/pd.py @@ -180,10 +180,11 @@ def get_interval_columns(df, name, arrays=False): if missing_cols: msg = f"Dataframe is missing {missing_cols} to chunk on {name}" raise KeyError(msg) + start, stop, step = df[names[0]], df[names[1]], df[names[2]] if not arrays: - return df[names[0]], df[names[1]], df[names[2]] + return start, stop, step else: - return df[names[0]].values, df[names[1]].values, df[names[2]].values + return start.values, stop.values, step.values def yield_slice_from_kwargs(df, kwargs) -> tuple[str, slice]: