Refactor: Dataset interface updates for future improved job I/O (#83)
* WIP

* feat: dataset unordered equivalence checker

* refactor: consistent Dataset.filter_fields implementation

No need to use allocate here

* feat: rename flag for Dataset.filter_prefix

Allows renaming the prefix at the same time as filtering on it

* chore: enforce kwarg for Dataset copy option (see the sketch after this list)

* docs: dataset docstring fix

* refactor: remove now-unnecessary methods

* chore: remove unused import

* fix: correct filter fields copy behaviour

* fix: test for python v3.7
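
The copy-option change above makes ``copy`` keyword-only across the field-manipulation helpers. A minimal call-site sketch, assuming the tuple-list constructor used elsewhere in this diff (field names are hypothetical):

    import numpy as np
    from cryosparc.dataset import Dataset

    # Hypothetical three-row dataset; uid column supplied explicitly.
    dset = Dataset([
        ("uid", np.array([1, 2, 3], dtype=np.uint64)),
        ("ctf/defocus", np.zeros(3)),
    ])

    # Previously legal: dset.filter_fields(["ctf/defocus"], True)
    # The flag must now be named, making intent explicit at the call site:
    copied = dset.filter_fields(["ctf/defocus"], copy=True)  # returns a copy
    dset.filter_fields(["ctf/defocus"])  # default copy=False mutates in place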
nfrasser authored May 15, 2024
1 parent 16fbb64 commit c405441
Showing 1 changed file with 43 additions and 28 deletions.
71 changes: 43 additions & 28 deletions cryosparc/dataset.py
@@ -24,7 +24,7 @@
 """
 
-from functools import lru_cache, reduce
+from functools import reduce
 from pathlib import PurePath
 from typing import (
     IO,
@@ -604,15 +604,6 @@ def load(cls, file: Union[str, PurePath, IO[bytes]], cstrs: bool = False):
         except Exception as err:
             raise DatasetLoadError(f"Could not load dataset from file {file}") from err
 
-    @classmethod
-    def load_cached(cls, file: Union[str, PurePath, IO[bytes]], cstrs: bool = False):
-        return cls._load_cached(file, cstrs).copy()
-
-    @classmethod
-    @lru_cache(maxsize=None)
-    def _load_cached(cls, file: Union[str, PurePath, IO[bytes]], cstrs: bool = False):
-        return cls.load(file, cstrs)
-
     @classmethod
     async def from_async_stream(cls, stream: AsyncBinaryIO):
         headersize = u32intle(await stream.read(4))
@@ -857,7 +848,7 @@ def __eq__(self, other: object):
         Check whether two datasets contain the same data in the same order.
 
         Args:
-            other (Dataset): dataset to compare
+            other (object): dataset to compare
 
         Returns:
             bool: True or False
@@ -1050,7 +1041,7 @@ def add_fields(
 
         return self._reset()
 
-    def filter_fields(self, names: Union[Collection[str], Callable[[str], bool]], copy: bool = False):
+    def filter_fields(self, names: Union[Collection[str], Callable[[str], bool]], *, copy: bool = False):
         """
         Keep only the given fields from the dataset. Provide a list of fields or
         function that returns ``True`` if a given field name should be kept.
@@ -1066,16 +1057,14 @@ def filter_fields(self, names: Union[Collection[str], Callable[[str], bool]], co
             Dataset: current dataset or copy with filtered fields
         """
         test = (lambda n: n in names) if isinstance(names, Collection) else names
-        new_fields = [f for f in self.descr() if f[0] == "uid" or test(f[0])]
-        if len(new_fields) == len(self.descr()):
-            return self
+        new_fields = [f for f in self.fields() if f == "uid" or test(f)]
+        if len(new_fields) == self._data.ncol():
+            return self.copy() if copy else self  # nothing to filter
 
-        result = self.allocate(len(self), new_fields)
-        for key, *_ in new_fields:
-            result[key] = self[key]
+        result = type(self)([(key, self[key]) for key in new_fields])
         return result if copy else self._reset(result._data)
 
-    def filter_prefixes(self, prefixes: Collection[str], copy: bool = False):
+    def filter_prefixes(self, prefixes: Collection[str], *, copy: bool = False):
         """
         Similar to ``filter_fields``, except takes list of prefixes.
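
A usage sketch of the reworked ``filter_fields`` (hypothetical field names; constructor form assumed from this diff). Note the corrected copy behaviour: with ``copy=True`` a copy now comes back even when nothing is filtered:

    import numpy as np
    from cryosparc.dataset import Dataset

    dset = Dataset([
        ("uid", np.array([1, 2, 3], dtype=np.uint64)),
        ("ctf/defocus", np.zeros(3)),
        ("location/center_x_frac", np.zeros(3)),
    ])

    # Keep by explicit name list ("uid" is always retained):
    ctf_only = dset.filter_fields(["ctf/defocus"], copy=True)

    # Or keep by predicate:
    ctf_only2 = dset.filter_fields(lambda n: n.startswith("ctf/"), copy=True)

    # A no-op filter with copy=True now returns a copy rather than self:
    clone = dset.filter_fields(dset.fields(), copy=True)
    assert clone is not dset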
@@ -1105,22 +1094,30 @@
         """
         return self.filter_fields(lambda n: any(n.startswith(p + "/") for p in prefixes), copy=copy)
 
-    def filter_prefix(self, keep_prefix: str, copy: bool = False):
+    def filter_prefix(self, keep_prefix: str, *, rename: Optional[str] = None, copy: bool = False):
         """
         Similar to ``filter_prefixes`` but for a single prefix.
 
         Args:
-            keep_prefix (str): Prefix to keep
+            keep_prefix (str): Prefix to keep.
+            rename (str, optional): If specified, rename prefix to this prefix.
+                Defaults to None.
             copy (bool, optional): If True, return a copy of the dataset rather
                 than mutate. Defaults to False.
 
         Returns:
             Dataset: current dataset or copy with filtered prefix
         """
-        return self.filter_prefixes([keep_prefix], copy=copy)
+        keep_fields = [f for f in self.fields(exclude_uid=True) if f.startswith(f"{keep_prefix}/")]
+        new_fields = keep_fields
+        if rename and rename != keep_prefix:
+            new_fields = [f"{rename}/{f.split('/', 1)[1]}" for f in keep_fields]
+
+        result = type(self)([("uid", self["uid"])] + [(nf, self[f]) for f, nf in zip(keep_fields, new_fields)])
+        return result if copy else self._reset(result._data)
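
A sketch of the new ``rename`` flag, which keeps one prefix and renames it in a single pass (hypothetical fields; assumes ``fields`` preserves column order):

    import numpy as np
    from cryosparc.dataset import Dataset

    dset = Dataset([
        ("uid", np.array([1, 2], dtype=np.uint64)),
        ("alignments3D/shift", np.zeros(2)),
        ("ctf/defocus", np.zeros(2)),
    ])

    # Keep only "alignments3D/*" fields, renamed to "alignments_class_0/*":
    cls0 = dset.filter_prefix("alignments3D", rename="alignments_class_0", copy=True)
    assert cls0.fields(exclude_uid=True) == ["alignments_class_0/shift"]

Roughly equivalent to ``filter_prefix`` followed by ``rename_prefix``, but without building the intermediate dataset.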

-    def drop_fields(self, names: Union[Collection[str], Callable[[str], bool]], copy: bool = False):
+    def drop_fields(self, names: Union[Collection[str], Callable[[str], bool]], *, copy: bool = False):
         """
         Remove the given field names from the dataset. Provide a list of fields
         or a function that takes a field name and returns True if that field
@@ -1139,7 +1136,7 @@ def drop_fields(self, names: Union[Collection[str], Callable[[str], bool]], copy
         test = (lambda n: n not in names) if isinstance(names, Collection) else (lambda n: not names(n))  # type: ignore
         return self.filter_fields(test, copy=copy)
 
-    def rename_fields(self, field_map: Union[Dict[str, str], Callable[[str], str]], copy: bool = False):
+    def rename_fields(self, field_map: Union[Dict[str, str], Callable[[str], str]], *, copy: bool = False):
         """
         Change the name of dataset fields based on the given mapping.
@@ -1160,7 +1157,7 @@ def rename_fields(self, field_map: Union[Dict[str, str], Callable[[str], str]],
         result = type(self)([(f if f == "uid" else fm(f), self[f]) for f in self])
         return result if copy else self._reset(result._data)
 
-    def rename_field(self, current_name: str, new_name: str, copy: bool = False):
+    def rename_field(self, current_name: str, new_name: str, *, copy: bool = False):
         """
         Change the name of a single dataset field.
@@ -1175,7 +1172,7 @@ def rename_field(self, current_name: str, new_name: str, copy: bool = False):
         """
         return self.rename_fields({current_name: new_name}, copy=copy)
 
-    def rename_prefix(self, old_prefix: str, new_prefix: str, copy: bool = False):
+    def rename_prefix(self, old_prefix: str, new_prefix: str, *, copy: bool = False):
         """
         Similar to rename_fields, except changes the prefix of all fields with
         the given ``old_prefix`` to ``new_prefix``.
@@ -1469,7 +1466,25 @@ def replace(self, query: Dict[str, "ArrayLike"], *others: "Dataset", assume_disj
 
         return result
 
-    def to_cstrs(self, copy: bool = False):
+    def is_equivalent(self, other: object):
+        """
+        Check whether two datasets contain the same data, regardless of field
+        order.
+
+        Args:
+            other (object): dataset to compare
+
+        Returns:
+            bool: True or False
+        """
+        return (
+            isinstance(other, Dataset)
+            and len(self) == len(other)
+            and set(self.descr()) == set(other.descr())
+            and all(n.array_equal(self[f], other[f]) for f in self)
+        )
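
A sketch contrasting the new unordered check with ordered ``__eq__``, whose docstring above requires the same field order (constructor form assumed from this diff):

    import numpy as np
    from cryosparc.dataset import Dataset

    uid = np.array([1, 2, 3], dtype=np.uint64)
    a = Dataset([("uid", uid), ("coords/x", np.ones(3)), ("coords/y", np.zeros(3))])
    b = Dataset([("uid", uid), ("coords/y", np.zeros(3)), ("coords/x", np.ones(3))])

    assert a != b              # __eq__ compares fields in order
    assert a.is_equivalent(b)  # same rows and columns, order ignored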

+    def to_cstrs(self, *, copy: bool = False):
         """
         Convert all Python string columns to C strings. Resulting dataset fields
         that previously had dtype ``np.object_`` (or ``T_OBJ`` internally) will get
@@ -1491,7 +1506,7 @@ def to_cstrs(self, copy: bool = False):
         self._reset()  # in case data got reallocated
         return dset
 
-    def to_pystrs(self, copy: bool = False):
+    def to_pystrs(self, *, copy: bool = False):
         """
         Convert all C string columns to Python strings. Resulting dataset fields
         that previously had dtype ``np.uint64`` (and ``T_STR`` internally) will
