Refactor: Dataset load memory optimizations (#93)
- New `Dataset.inspect` function for reading dataset headers
- Use chunked mmap loading for numpy arrays
- Break `load` into helper functions
nfrasser authored Aug 27, 2024
1 parent bd8a6bd commit 1cc7362
Showing 7 changed files with 351 additions and 105 deletions.
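For orientation, a hedged usage sketch of the new `Dataset.inspect` helper described in the commit message above; the file path is hypothetical, and only the header is read, so the row count and field descriptors are available without loading any column data:

from cryosparc.dataset import Dataset

header = Dataset.inspect("particles.cs")   # hypothetical path
print(header["length"])                    # number of rows in the file
print([f[0] for f in header["dtype"]])     # field names recorded in the header
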
9 changes: 9 additions & 0 deletions cryosparc/column.py
@@ -68,6 +68,15 @@ def __array_wrap__(self, obj, context=None, return_scalar=False):
# or n.median
return obj[()] if obj.shape == () else super().__array_wrap__(obj, context, return_scalar) # type: ignore

def __setitem__(self, key, value):
if isinstance(value, n.ndarray):
# parse fixed-size strings
if value.dtype.char == "S":
value = n.vectorize(hashcache(bytes.decode), otypes="O")(value)
elif value.dtype.char == "U":
value = n.vectorize(hashcache(str), otypes="O")(value)
return super().__setitem__(key, value)

def to_fixed(self) -> "Column":
"""
If this Column is composed of Python objects, convert to fixed-size
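A minimal sketch of the behaviour the new `Column.__setitem__` above enables, assuming the `Dataset` constructor accepts a field-to-values mapping (as the `__init__` changes further down suggest); the field name and values are hypothetical. Assigning a fixed-size numpy byte or unicode string array to an object-typed string column decodes each element to a Python string before storage:

import numpy as n
from cryosparc.dataset import Dataset

dset = Dataset({"micrograph/path": ["a.mrc", "b.mrc", "c.mrc"]})  # hypothetical field

# Fixed-size numpy strings ("S" or "U" dtypes) are decoded element-wise to
# Python str objects on assignment, keeping the column object-typed.
dset["micrograph/path"] = n.array([b"x.mrc", b"y.mrc", b"z.mrc"], dtype="S16")
assert isinstance(dset["micrograph/path"][0], str)
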
219 changes: 164 additions & 55 deletions cryosparc/dataset.py
@@ -61,15 +61,15 @@
decode_dataset_header,
encode_dataset_header,
fielddtype,
filter_descr,
get_data_field,
get_data_field_dtype,
makefield,
safe_makefield,
normalize_field,
)
from .errors import DatasetLoadError
from .row import R, Row, Spool
from .stream import AsyncBinaryIO, Streamable
from .util import bopen, default_rng, hashcache, random_integers, u32bytesle, u32intle
from .util import bopen, default_rng, random_integers, u32bytesle, u32intle

if TYPE_CHECKING:
from numpy.typing import ArrayLike, DTypeLike, NDArray
@@ -537,7 +537,40 @@ def common_fields(cls, *datasets: "Dataset", assert_same_fields=False) -> List[F
return [f for f in datasets[0].descr() if f in fields]

@classmethod
def load(cls, file: Union[str, PurePath, IO[bytes]], cstrs: bool = False):
def inspect(cls, file: Union[str, PurePath]) -> DatasetHeader:
"""
Given a path to a dataset file, get information included in its header.
Args:
file (str | Path): Readable file path.
Returns:
DatasetHeader: Dictionary with dataset header information such as its length and field descriptors.
"""
try:
with open(file, "rb") as f:
prefix = f.read(6)
if prefix == FORMAT_MAGIC_PREFIXES[NUMPY_FORMAT]:
f.seek(0) # will be done after context block for mmapping
elif prefix == FORMAT_MAGIC_PREFIXES[CSDAT_FORMAT]:
return cls._load_stream_header(f)
else:
raise ValueError(f"Could not determine dataset format (prefix is {prefix})")
# numpy format
return cls._load_numpy_header(file)
except Exception as err:
raise DatasetLoadError(f"Could not load dataset from file {file}") from err

@classmethod
def load(
cls,
file: Union[str, PurePath, IO[bytes]],
*,
prefixes: Optional[Sequence[str]] = None,
fields: Optional[Sequence[str]] = None,
cstrs: bool = False,
):
"""
Read a dataset from path or file handle.
@@ -549,6 +582,10 @@ def load(cls, file: Union[str, PurePath, IO[bytes]], cstrs: bool = False):
file (str | Path | IO): Readable file path or handle. Must be
seekable if loading a dataset saved in the default
``NUMPY_FORMAT``
prefixes (list[str], optional): Which field prefixes to load. If
not specified, loads all fields (or only the given `fields`, if provided).
fields (list[str], optional): Which specific fields to load. If not
specified, loads all fields (or only those matching `prefixes`, if provided).
cstrs (bool): If True, load internal string columns as C strings
instead of Python strings. Defaults to False.
@@ -563,48 +600,119 @@ def load(cls, file: Union[str, PurePath, IO[bytes]], cstrs: bool = False):
with bopen(file, "rb") as f:
prefix = f.read(6)
if prefix == FORMAT_MAGIC_PREFIXES[NUMPY_FORMAT]:
f.seek(0)
indata = n.load(f, allow_pickle=False)
dset = cls(indata)
if cstrs:
dset.to_cstrs()
return dset
elif prefix != FORMAT_MAGIC_PREFIXES[CSDAT_FORMAT]:
raise TypeError(f"Could not determine dataset format (prefix is {prefix})")

headersize = u32intle(f.read(4))
header = decode_dataset_header(f.read(headersize))

# Calling addrows separately minimizes column-based
# allocations and improves performance by ~20%
dset = cls.allocate(0, header["dtype"])
data = dset._data
data.addrows(header["length"])
loader = Stream(data)
for field in header["dtype"]:
colsize = u32intle(f.read(4))
buffer = f.read(colsize)
if field[0] in header["compressed_fields"]:
loader.decompress_col(field[0], buffer)
else:
data.getbuf(field[0])[:] = buffer

# Read in the string heap (rest of stream)
# NOTE: There will be a bug here for long column keys that are
# added when there's already an allocated string in a T_STR
# column in the saved dataset (should be rare).
heap = f.read()
data.setstrheap(heap)

# Convert C strings to Python strings
loader.cast_objs_to_strs() # dtype may be T_OBJ but actually all are T_STR
if not cstrs:
dset.to_pystrs()
return dset

f.seek(0) # will be done after context block for mmapping
elif prefix == FORMAT_MAGIC_PREFIXES[CSDAT_FORMAT]:
return cls._load_stream(
f,
prefixes=prefixes,
fields=fields,
cstrs=cstrs,
seekable=isinstance(file, (str, PurePath)),
)
else:
raise ValueError(f"Could not determine dataset format (prefix is {prefix})")
# numpy
return cls._load_numpy(file, prefixes=prefixes, fields=fields, cstrs=cstrs)
except Exception as err:
raise DatasetLoadError(f"Could not load dataset from file {file}") from err

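A hedged sketch of partial loading with the new keyword-only arguments; the path and field names are hypothetical, and the exact way `prefixes` and `fields` combine is delegated to `filter_descr`, so treating it as a union is an assumption here:

from cryosparc.dataset import Dataset

# Keep only the "location" prefix fields plus "uid"; unselected columns are
# seeked over (CSDAT streams) or never copied out of the mmap (NUMPY format).
subset = Dataset.load(
    "picked_particles.cs",   # hypothetical path
    prefixes=["location"],
    fields=["uid"],
)
print(subset.fields())
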
@classmethod
def _load_numpy_header(cls, file: Union[str, PurePath]) -> DatasetHeader:
indata = n.load(str(file), mmap_mode="r", allow_pickle=False)
fields = [normalize_field(f[0], fielddtype(f)) for f in indata.dtype.descr]
return DatasetHeader(length=len(indata), dtype=fields, compression=None, compressed_fields=[])

@classmethod
def _load_numpy(
cls,
file: Union[str, PurePath, IO[bytes]],
prefixes: Optional[Sequence[str]] = None,
fields: Optional[Sequence[str]] = None,
cstrs: bool = False,
):
import os

# disable mmap by setting CRYOSPARC_DATASET_MMAP=false
if os.getenv("CRYOSPARC_DATASET_MMAP", "true").lower() == "true" and isinstance(file, (str, PurePath)):
# Use mmap to avoid loading full record array into memory
# cast path to a string for older numpy/python
mmap_mode, f = "r", str(file)
chunk_size = 2**14 # magic number optimizes memory and performance
else:
mmap_mode, f = None, file
chunk_size = 2**60 # effectively no chunking; copy everything in a single pass

indata = n.load(f, mmap_mode=mmap_mode, allow_pickle=False)
size = len(indata)
descr = filter_descr(indata.dtype.descr, keep_prefixes=prefixes, keep_fields=fields)
dset = cls.allocate(size, descr)
offset = 0
while offset < size:
end = min(offset + chunk_size, size)
chunk = indata[offset:end]
for field in descr:
dset[field[0]][offset:end] = chunk[field[0]]
offset += chunk_size
if mmap_mode and offset < size:
# reset mmap to avoid excessive memory usage
del indata
indata = n.load(f, mmap_mode=mmap_mode, allow_pickle=False)

if cstrs:
dset.to_cstrs()
return dset

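The loop above bounds peak memory by copying the memory-mapped record array one slice at a time and re-opening the map between chunks. A standalone sketch of the same chunked-copy idea using plain numpy, independent of the Dataset internals (function and argument names are illustrative):

import numpy as n

CHUNK_ROWS = 2**14  # mirrors the chunk size chosen above

def chunked_column_copy(npy_path, out_columns):
    """Copy selected columns from a .npy record file into preallocated
    arrays without materializing the whole record array in memory."""
    src = n.load(npy_path, mmap_mode="r", allow_pickle=False)
    total = len(src)
    start = 0
    while start < total:
        end = min(start + CHUNK_ROWS, total)
        window = src[start:end]  # only this slice is paged in
        for name, dst in out_columns.items():
            dst[start:end] = window[name]
        start = end
    return total
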
@classmethod
def _load_stream_header(cls, f: IO[bytes]) -> DatasetHeader:
# NOTE: assumes prefix header bytes have already been read
headersize = u32intle(f.read(4))
return decode_dataset_header(f.read(headersize))

@classmethod
def _load_stream(
cls,
f: IO[bytes],
prefixes: Optional[Sequence[str]] = None,
fields: Optional[Sequence[str]] = None,
cstrs: bool = False,
seekable: bool = False,
):
# NOTE: assumes prefix header bytes have already been read
header = cls._load_stream_header(f)
descr = filter_descr(header["dtype"], keep_prefixes=prefixes, keep_fields=fields)
field_names = {field[0] for field in descr}

# Calling addrows separately minimizes column-based
# allocations and improves performance by ~20%
dset = cls.allocate(0, descr)
data = dset._data
data.addrows(header["length"])
loader = Stream(data)
for field in header["dtype"]:
colsize = u32intle(f.read(4))
if field[0] not in field_names:
# try to seek instead of read to reduce memory usage
f.seek(colsize, 1) if seekable else f.read(colsize)
continue # skip fields that were not selected
buffer = f.read(colsize)
if field[0] in header["compressed_fields"]:
loader.decompress_col(field[0], buffer)
else:
data.getbuf(field[0])[:] = buffer

# Read in the string heap (rest of stream)
# NOTE: There will be a bug here for long column keys that are
# added when there's already an allocated string in a T_STR
# column in the saved dataset (should be rare).
heap = f.read()
data.setstrheap(heap)

# Convert C strings to Python strings
loader.cast_objs_to_strs() # dtype may be T_OBJ but actually all are T_STR
if not cstrs:
dset.to_pystrs()
return dset

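The stream reader above depends on each column being framed as a little-endian u32 byte length followed by its payload, which is what allows seeking past unselected fields. A minimal sketch of that framing pattern in isolation (field list, stream, and function name are hypothetical):

import struct
from typing import IO, Dict, Iterable, Set

def read_selected_columns(
    f: IO[bytes], field_order: Iterable[str], wanted: Set[str], seekable: bool = True
) -> Dict[str, bytes]:
    """Consume (u32-le length, payload) frames, keeping only wanted fields."""
    out: Dict[str, bytes] = {}
    for name in field_order:
        (size,) = struct.unpack("<I", f.read(4))
        if name in wanted:
            out[name] = f.read(size)
        elif seekable:
            f.seek(size, 1)  # skip the payload without reading it into memory
        else:
            f.read(size)     # non-seekable streams must still consume the bytes
    return out
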
@classmethod
async def from_async_stream(cls, stream: AsyncBinaryIO):
headersize = u32intle(await stream.read(4))
@@ -668,6 +776,9 @@ def stream(self, compression: Literal["lz4", None] = None) -> Generator[bytes, N
``format=CSDAT_FORMAT``. Call ``Dataset.load`` on the resulting
file/buffer to retrieve the original data.
Args:
compression (Literal["lz4", None], optional): Compression to apply to column data in the output stream. Defaults to None (no compression).
Yields:
bytes: Dataset file chunks
"""
@@ -679,7 +790,10 @@ def stream(self, compression: Literal["lz4", None] = None) -> Generator[bytes, N
compressed_fields = [f for f in self if compression and f not in NEVER_COMPRESS_FIELDS]
header = encode_dataset_header(
DatasetHeader(
length=len(self), dtype=self.descr(), compression=compression, compressed_fields=compressed_fields
length=len(self),
dtype=self.descr(),
compression=compression,
compressed_fields=compressed_fields,
)
)
yield u32bytesle(len(header))
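A hedged round-trip sketch for the streaming path; the output path and field are hypothetical, and LZ4 support is assumed to be available. The docstring above confirms that `Dataset.load` can read back the streamed result:

from cryosparc.dataset import Dataset

dset = Dataset({"micrograph/path": ["a.mrc", "b.mrc"]})  # hypothetical field

with open("subset.csl", "wb") as out:  # hypothetical output path
    for chunk in dset.stream(compression="lz4"):
        out.write(chunk)

# The streamed file begins with the CSDAT magic prefix, so load() detects it.
roundtrip = Dataset.load("subset.csl")
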
@@ -734,16 +848,16 @@ def __init__(
elif isinstance(allocate, n.ndarray): # record array
for field in allocate.dtype.descr:
assert field[0], f"Cannot initialize with record array of dtype {allocate.dtype}"
field = ("uid", "u8") if field[0] == "uid" else field
populate.append((field, allocate[field[0]]))
a = allocate[field[0]]
populate.append((normalize_field(field[0], fielddtype(field)), a))
elif isinstance(allocate, Mapping):
for f, v in allocate.items():
a = n.asarray(v)
populate.append((safe_makefield(f, arraydtype(a)), a))
populate.append((normalize_field(f, arraydtype(a)), a))
else:
for f, v in allocate:
a = n.asarray(v)
populate.append((safe_makefield(f, arraydtype(a)), a))
populate.append((normalize_field(f, arraydtype(a)), a))

# Check that all entries are the same length
nrows = 0
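For reference, a small sketch of the record-array constructor path touched above, with hypothetical field names; each record field is run through `normalize_field` and becomes a dataset column:

import numpy as n
from cryosparc.dataset import Dataset

rec = n.zeros(4, dtype=[("uid", "u8"), ("ctf/defocus_u", "f4")])  # hypothetical fields
rec["uid"] = n.arange(1, 5)
dset = Dataset(rec)  # record fields become dataset columns
print(dset.fields())
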
@@ -818,11 +932,6 @@ def __setitem__(self, key: str, val: "ArrayLike"):
val (ArrayLike): numpy array or value to assign
"""
assert self._data.has(key), f"Cannot set non-existing dataset key {key}; use add_fields() first"
if isinstance(val, n.ndarray):
if val.dtype.char == "S":
val = n.vectorize(hashcache(bytes.decode), otypes="O")(val)
elif val.dtype.char == "U":
val = n.vectorize(hashcache(str), otypes="O")(val)
self[key][:] = val

def __delitem__(self, key: str):
@@ -1021,7 +1130,7 @@ def add_fields(
if dtypes:
dt = dtypes.split(",") if isinstance(dtypes, str) else dtypes
assert len(fields) == len(dt), "Incorrect dtype spec"
desc = [safe_makefield(str(f), dt) for f, dt in zip(fields, dt)]
desc = [normalize_field(str(f), dt) for f, dt in zip(fields, dt)]
else:
desc = fields # type: ignore

@@ -1210,7 +1319,7 @@ def copy_fields(self, old_fields: List[str], new_fields: List[str]):
assert len(old_fields) == len(new_fields), "Number of old and new fields must match"
current_fields = self.fields()
missing_fields = [
makefield(new, get_data_field_dtype(self._data, old))
normalize_field(new, get_data_field_dtype(self._data, old))
for old, new in zip(old_fields, new_fields)
if new not in current_fields
]