Skip to content

Commit

Permalink
fix: coerce the empty string to '/' when accessing hdf5 data
Browse files: browse the repository at this point in the history
  • Loading branch information
d-v-b committed Sep 13, 2024
1 parent 7a8eda8 commit 14a4810
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 27 deletions.
55 changes: 34 additions & 21 deletions src/fibsem_tools/io/h5.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import warnings
from typing import Any, Union
from __future__ import annotations
from typing import TYPE_CHECKING

import h5py
if TYPE_CHECKING:
from typing import Any

import warnings
import h5py
from fibsem_tools.type import PathLike

H5_ACCESS_MODES = ("r", "r+", "w", "w-", "x", "a")

# file, group and dataset creation take both of these
H5_GROUP_KWDS = ("name", "track_order")

H5_DATASET_KWDS = (
"name",
"shape",
"dtype",
"data",
Expand All @@ -21,15 +26,11 @@
"maxshape",
"fillvalue",
"track_times",
"track_order",
"external",
"allow_unknown_filter",
)

H5_GROUP_KWDS = ("name", "track_order")
) + H5_GROUP_KWDS

H5_FILE_KWDS = (
"name",
"mode",
"driver",
"libver",
Expand All @@ -38,11 +39,10 @@
"rdcc_nslots",
"rdcc_nbytes",
"rdcc_w0",
"track_order",
"fs_strategy",
"fs_persist",
"fs_threshold",
)
) + H5_GROUP_KWDS


def partition_h5_kwargs(**kwargs: Any) -> tuple[dict[str, Any], dict[str, Any]]:
Expand All @@ -60,10 +60,19 @@ def partition_h5_kwargs(**kwargs: Any) -> tuple[dict[str, Any], dict[str, Any]]:

def access(
store: PathLike, path: PathLike, mode: str, **kwargs: Any
) -> Union[h5py.Dataset, h5py.Group]:
) -> h5py.Dataset | h5py.Group:
"""
Docstring
Get or create an hdf5 dataset or group. Be advised that this function opens a file handle to the
hdf5 file. The caller is responsible for closing that file handle, e.g.
via access('path.h5').file.close()
"""

# hdf5 names the root group "/", so we convert any path equal to the empty string to "/" instead
if path == "":
path_normalized = "/"
else:
path_normalized = str(path)

if mode not in H5_ACCESS_MODES:
msg = f"Invalid access mode. Got {mode}, expected one of {H5_ACCESS_MODES}."
raise ValueError(msg)
Expand All @@ -73,21 +82,25 @@ def access(

h5f = h5py.File(store, mode=mode, **file_kwargs)

if mode in ("r", "r+", "a") and (result := h5f.get(path)) is not None:
if mode in ("r", "r+", "a") and (result := h5f.get(path_normalized)) is not None:
# access a pre-existing dataset or group
return result
else:
if len(dataset_kwargs) > 0:
# create a dataset
if "name" in dataset_kwargs:
warnings.warn(
"""
'Name' was provided to this function as a keyword argument. This
value will be replaced with the second argument to this function.
"""
msg = (
"'Name' was provided to this function as a keyword argument. "
"This value will be ignored, and instead the second argument to this function "
"will be used as the name of the dataset or group being created."
)
dataset_kwargs["name"] = path
warnings.warn(msg)

dataset_kwargs["name"] = path_normalized
result = h5f.create_dataset(**dataset_kwargs)
else:
result = h5f.require_group(path)
# create a group
result = h5f.require_group(path_normalized)

result.attrs.update(**attrs)

Expand Down
18 changes: 12 additions & 6 deletions tests/io/test_h5.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import os
from __future__ import annotations
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from _pytest.compat import LEGACY_PATH

import os
import pytest
import h5py
import numpy as np

from fibsem_tools.io.h5 import access


def test_access_array(tmpdir):
@pytest.mark.parametrize("key", ("s0", "s2"))
def test_access_array(tmpdir: LEGACY_PATH, key: str) -> None:
path = os.path.join(str(tmpdir), "foo.h5")
key = "s0"
data = np.random.randint(0, 255, size=(10, 10, 10), dtype="uint8")
attrs = {"resolution": "1000"}

Expand All @@ -27,9 +33,9 @@ def test_access_array(tmpdir):
arr3.file.close()


def test_access_group(tmpdir):
key = "s0"
store = os.path.join(str(tmpdir), key)
@pytest.mark.parametrize("key", ("a", "/", ""))
def test_access_group(tmpdir: LEGACY_PATH, key: str) -> None:
store = os.path.join(str(tmpdir), "test.h5")
attrs = {"resolution": "1000"}

grp = access(store, key, attrs=attrs, mode="w")
Expand Down

0 comments on commit 14a4810

Please sign in to comment.