
Commit

fix chunk
d-chambers committed Jul 30, 2023
1 parent 2c8ce6d commit 49aff20
Showing 5 changed files with 25 additions and 13 deletions.
dascore/core/attrs.py: 10 additions & 0 deletions
@@ -10,6 +10,7 @@
 from typing import Annotated, Literal
 from typing import Any

+import numpy as np
 import pandas as pd
 from pydantic import ConfigDict, PlainValidator
 from pydantic import Field, model_validator
@@ -277,6 +278,15 @@ def flat_dump(self) -> dict:
         for coord_name, coord in out.pop("coords").items():
             for name, val in coord.items():
                 out[f"{coord_name}_{name}"] = val
+            # ensure step has right type if nullish
+            step_name, start_name = f"{coord_name}_step", f"{coord_name}_min"
+            step, start = out[step_name], out[start_name]
+            if step is None:
+                is_time = isinstance(start, (np.datetime64, np.timedelta64))
+                if is_time:
+                    out[step_name] = np.timedelta64("NaT")
+                elif isinstance(start, (float, np.floating)):
+                    out[step_name] = np.NaN
         return out

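The new block in `flat_dump` gives a missing (`None`) step a typed null so each flattened attrs column keeps a consistent dtype: `NaT` for time-like coords, `NaN` for float coords. A minimal standalone sketch of that rule (the function name is hypothetical, not part of dascore):

```python
import numpy as np

def normalize_null_step(start, step):
    """Replace a None step with a null matching the coord's dtype."""
    if step is None:
        if isinstance(start, (np.datetime64, np.timedelta64)):
            return np.timedelta64("NaT")  # typed null for time coords
        if isinstance(start, (float, np.floating)):
            return np.nan  # typed null for float coords
    return step

print(normalize_null_step(np.datetime64("2023-01-01"), None))  # NaT
print(normalize_null_step(1.5, None))                          # nan
```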
dascore/core/coords.py: 3 additions & 3 deletions
@@ -75,13 +75,13 @@ def ensure_consistent_dtype(cls, value, _info):
         # for some reason all ints are getting converted to floats. This
         # hack just fixes that. TODO: See if this is needed in a few version
         # after pydantic 2.1.1
-        if np.issubdtype(dtype, np.datetime64):
+        if pd.isnull(value):
+            return value
+        elif np.issubdtype(dtype, np.datetime64):
             if _info.field_name == "step":
                 value = dc.to_timedelta64(value)
             else:
                 value = dc.to_datetime64(value)
-        elif pd.isnull(value):
-            return value
         elif np.issubdtype(dtype, np.timedelta64):
             value = dc.to_timedelta64(value)
         # convert numpy numerics back to python
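The coords.py change reorders the validator so the null check runs before the dtype dispatch; previously a null value on a datetime coord was routed into the datetime/timedelta converters before the null check could catch it. A rough sketch of the fixed control flow (hypothetical converter, not dascore's `dc.to_datetime64`):

```python
import numpy as np
import pandas as pd

def coerce(value, dtype):
    # Fixed ordering: bail out on any null first, then dispatch on dtype.
    if pd.isnull(value):
        return value
    if np.issubdtype(dtype, np.datetime64):
        return np.datetime64(value)
    return value

print(coerce(None, np.dtype("datetime64[ns]")))          # None, untouched
print(coerce("2023-07-30", np.dtype("datetime64[ns]")))  # 2023-07-30
```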
dascore/utils/chunk.py: 4 additions & 4 deletions
@@ -185,23 +185,23 @@ def _get_continuity_group_number(self, start, stop, step) -> pd.Series:
         group_num = has_gap.astype(np.int64).cumsum()
         return group_num[start.index]

-    def _get_sampling_group_num(self, df, tolerance=0.05) -> pd.Series:
+    def _get_sampling_group_num(self, step, tolerance=0.05) -> pd.Series:
         """
         Because sampling can be off a little, this adds some tolerance for
         how sampling affects groups.

         Tolerance affects how close samples have to be in order to count as
         the same. 5% is used here.
         """
-        col = df[f"{self._name}_step"].values
+        col = step.values
         sort_args = np.argsort(col)
         sorted_col = col[sort_args]
         roll_forward = np.roll(sorted_col, shift=1)
         diff = (sorted_col - roll_forward) / sorted_col
         out_of_threshold = diff > tolerance
         group_number = numpy.cumsum(out_of_threshold)
         # undo sorting
-        out = pd.Series(group_number[np.argsort(sort_args)], index=df.index)
+        out = pd.Series(group_number[np.argsort(sort_args)], index=step.index)
         return out

     def _get_duration_overlap(self, duration, start, step, overlap=None):
@@ -351,7 +351,7 @@ def _get_group(self, df, start, stop, step):
         being consistent and group columns matching.
         """
         cont_g = self._get_continuity_group_number(start, stop, step)
-        samp_g = self._get_sampling_group_num(df)
+        samp_g = self._get_sampling_group_num(step)
         col_g = self._get_col_group(df, cont_g)
         group_series = [x.astype(str) for x in [samp_g, col_g, cont_g]]
         group = reduce(lambda x, y: x + "_" + y, group_series)
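`_get_sampling_group_num` now takes the step series directly instead of the whole dataframe. Its grouping trick, sort the steps, take relative differences between neighbors, then cumulative-sum the over-tolerance jumps, can be sketched standalone (example values are hypothetical; this mirrors the logic above rather than importing dascore):

```python
import numpy as np
import pandas as pd

def sampling_group_numbers(step: pd.Series, tolerance: float = 0.05) -> pd.Series:
    """Group rows whose sampling steps agree within a relative tolerance."""
    col = step.values
    sort_args = np.argsort(col)
    sorted_col = col[sort_args]
    # relative change between neighboring (sorted) steps
    diff = (sorted_col - np.roll(sorted_col, shift=1)) / sorted_col
    # each jump larger than the tolerance starts a new group
    group_number = np.cumsum(diff > tolerance)
    # undo the sort so group numbers line up with the original index
    return pd.Series(group_number[np.argsort(sort_args)], index=step.index)

steps = pd.Series([0.001, 0.00101, 0.002, 0.00199])  # two ~1%-tight clusters
print(sampling_group_numbers(steps).tolist())  # [0, 0, 1, 1]
```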
dascore/utils/patch.py: 5 additions & 4 deletions
@@ -218,13 +218,14 @@ def patches_to_df(
     A dataframe with the attrs of each patch converted to a columns
     plus a field called 'patch' which contains a reference to the patches.
     """
-    if isinstance(patches, dc.BaseSpool):
-        df = patches._df
+    if hasattr(patches, "_df"):
+        df = patches._df  # noqa

+    # Handle spool case
     elif hasattr(patches, "get_contents"):
-        return patches.get_contents()
+        df = patches.get_contents()
     elif isinstance(patches, pd.DataFrame):
-        return patches
+        df = patches
     else:
         df = pd.DataFrame([x.flat_dump() for x in scan_patches(patches)])
     if df.empty:  # create empty df with appropriate columns
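The patch.py fix replaces early `return`s with assignments to `df`, so every input shape (spool with a cached `_df`, spool exposing `get_contents()`, plain DataFrame, or an iterable of patches) falls through to the shared empty-frame handling at the bottom of the function. A toy version of the pattern (names other than `get_contents` are hypothetical):

```python
import pandas as pd

def to_df(obj) -> pd.DataFrame:
    if hasattr(obj, "_df"):             # spool-like object with a cached frame
        df = obj._df
    elif hasattr(obj, "get_contents"):  # spool-like object that builds one
        df = obj.get_contents()
    elif isinstance(obj, pd.DataFrame):
        df = obj
    else:                               # assume an iterable of records
        df = pd.DataFrame(list(obj))
    if df.empty:  # shared fix-up that the early returns used to skip
        df = pd.DataFrame(columns=["patch"])
    return df

print(to_df([]).columns.tolist())  # ['patch'], even for empty input
```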
dascore/utils/pd.py: 3 additions & 2 deletions
@@ -180,10 +180,11 @@ def get_interval_columns(df, name, arrays=False):
     if missing_cols:
         msg = f"Dataframe is missing {missing_cols} to chunk on {name}"
         raise KeyError(msg)
+    start, stop, step = df[names[0]], df[names[1]], df[names[2]]
     if not arrays:
-        return df[names[0]], df[names[1]], df[names[2]]
+        return start, stop, step
     else:
-        return df[names[0]].values, df[names[1]].values, df[names[2]].values
+        return start.values, stop.values, step.values


 def yield_slice_from_kwargs(df, kwargs) -> tuple[str, slice]:
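The pd.py tweak just binds the three interval columns once before deciding whether to return pandas Series or raw arrays. A quick usage sketch (the `time_*` column names are assumptions for illustration, not dascore's guaranteed schema):

```python
import pandas as pd

df = pd.DataFrame({
    "time_min": [0.0, 10.0],
    "time_max": [10.0, 20.0],
    "time_step": [0.001, 0.001],
})
names = [f"time_{suffix}" for suffix in ("min", "max", "step")]
# Bind once, then hand back either Series or numpy arrays.
start, stop, step = df[names[0]], df[names[1]], df[names[2]]
print(start.tolist())  # [0.0, 10.0]
print(step.values)     # [0.001 0.001]
```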
