Skip to content

Commit

Permalink
Simplify the parse_slice function (#78)
Browse files Browse the repository at this point in the history
* Simplify the parse_slice function

* Rename test_utils into utils

* Add tests for parse_slice

* Update h5grove/utils.py

Co-authored-by: Thomas VINCENT <[email protected]>

---------

Co-authored-by: Thomas VINCENT <[email protected]>
  • Loading branch information
loichuder and t20100 authored Feb 22, 2023
1 parent 496edfe commit 6650a28
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 84 deletions.
40 changes: 24 additions & 16 deletions h5grove/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,23 +79,26 @@ def get_entity_from_file(
return h5file[path]


def parse_slice(slice_str: str) -> Tuple[Union[slice, int], ...]:
    """
    Parse a string containing a slice in NumPy format.

    The string is split on commas and every member is converted by
    ``parse_slice_member``. No dataset is involved here, so the number of
    members is not checked against any dimensionality; that check is left
    to the caller.

    Examples:
        '5' => (5,)
        '1, 2:5' => (1, slice(2, 5))
        '0:10:5, 2, 3:' => (slice(0, 10, 5), 2, slice(3, None, None))

    :param slice_str: String containing the slice
    :returns: Tuple with one ``int`` or ``slice`` per comma-separated member
    """
    # str.split always yields at least one element, so a string without any
    # comma goes through the same path and still produces a 1-tuple.
    return tuple(parse_slice_member(member) for member in slice_str.split(","))


def parse_slice_member(slice_member: str, max_dim: int) -> Union[slice, int]:
def parse_slice_member(slice_member: str) -> Union[slice, int]:
if ":" not in slice_member:
return int(slice_member)

Expand All @@ -104,16 +107,16 @@ def parse_slice_member(slice_member: str, max_dim: int) -> Union[slice, int]:
start, stop = slice_params

return slice(
int(start) if start != "" else 0, int(stop) if stop != "" else max_dim
int(start) if start != "" else 0, int(stop) if stop != "" else None
)

if len(slice_params) == 3:
start, stop, step = slice_params

return slice(
int(start) if start != "" else 0,
int(stop) if stop != "" else max_dim,
int(step) if step != "" else 1,
int(start) if start != "" else None,
int(stop) if stop != "" else None,
int(step) if step != "" else None,
)

raise TypeError(f"{slice_member} is not a valid slice")
Expand Down Expand Up @@ -252,7 +255,12 @@ def get_dataset_slice(dataset: h5py.Dataset, selection: Selection):
return dataset[()]

if isinstance(selection, str):
return dataset[parse_slice(dataset, selection)]
parsed_slice = parse_slice(selection)
if len(parsed_slice) > dataset.ndim:
raise ValueError(
f"{selection} has too many members to slice a {dataset.ndim}D dataset"
)
return dataset[parsed_slice]

return dataset[selection]

Expand Down
2 changes: 1 addition & 1 deletion test/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from conftest import BaseServer
from h5grove.models import LinkResolution
from test_utils import decode_response, decode_array_response
from utils import decode_response, decode_array_response


class BaseTestEndpoints:
Expand Down
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from urllib.error import HTTPError

import pytest
from test_utils import Response, assert_error_response
from utils import Response, assert_error_response
from h5grove.encoders import orjson_encode


Expand Down
2 changes: 1 addition & 1 deletion test/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from conftest import BaseServer
import base_test
from test_utils import Response
from utils import Response

from h5grove.fastapi_utils import router, settings

Expand Down
2 changes: 1 addition & 1 deletion test/test_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from conftest import BaseServer
from test_utils import Response
from utils import Response
import base_test

from h5grove.flask_utils import BLUEPRINT
Expand Down
2 changes: 1 addition & 1 deletion test/test_tornado.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tornado.web

from conftest import BaseServer
from test_utils import Response, assert_error_response
from utils import Response, assert_error_response
import base_test

from h5grove.tornado_utils import get_handlers
Expand Down
68 changes: 5 additions & 63 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,7 @@
import io
import json
import numpy as np
from typing import List, NamedTuple, Tuple
from h5grove.utils import parse_slice

from h5grove.utils import hdf_path_join


class Response(NamedTuple):
    """Response record used as the return type of :meth:`get`"""

    status: int
    headers: List[Tuple[str, str]]
    content: bytes

    def find_header_value(self, key: str):
        """Look up a header value by name, ignoring case.

        When the same name appears several times, the last occurrence wins.
        Raises ``KeyError`` if no header matches.
        """
        values_by_name = {}
        for header in self.headers:
            values_by_name[header[0].lower()] = header[1]
        return values_by_name[key.lower()]


def test_root_path_join():
    """Joining a child name onto the root path gives '/child'."""
    joined = hdf_path_join("/", "child")
    assert joined == "/child"


def test_group_path_join():
    """Joining a name onto a nested group path appends it after a slash."""
    joined = hdf_path_join("/group1/group2", "data")
    assert joined == "/group1/group2/data"


def test_group_path_join_trailing():
    """A trailing slash on the group path does not produce a double slash."""
    joined = hdf_path_join("/group1/group2/", "data")
    assert joined == "/group1/group2/data"


def decode_response(response: Response, format: str = "json"):
    """Decode the content of a response according to the requested format.

    "json" content is parsed with ``json.loads`` and "npy" content is loaded
    with ``np.load``. The content-type header is asserted to match the
    format. Any other format raises ``ValueError``.
    """
    # The header is looked up before the format is examined, so a response
    # without a content-type header fails first.
    mime = response.find_header_value("content-type")

    if format not in ("json", "npy"):
        raise ValueError(f"Unsupported format: {format}")

    if format == "npy":
        assert mime == "application/octet-stream"
        return np.load(io.BytesIO(response.content))

    assert "application/json" in mime
    return json.loads(response.content)


def decode_array_response(
    response: Response,
    format: str,
    dtype: str,
    shape: Tuple[int, ...],
) -> np.ndarray:
    """Decode the content of a data array response into a NumPy array.

    :param response: Response whose content holds the encoded array
    :param format: "bin" for raw bytes; any other value is handed to
        ``decode_response``
    :param dtype: dtype used to interpret the raw bytes (only used for "bin")
    :param shape: Shape given to the decoded array (only used for "bin")
    :returns: The decoded array
    :raises AssertionError: If format is "bin" and the content type is not
        "application/octet-stream"
    """
    content_type = response.find_header_value("content-type")

    if format == "bin":
        assert content_type == "application/octet-stream"
        return np.frombuffer(response.content, dtype=dtype).reshape(shape)

    # np.asarray copies only when needed. np.array(..., copy=False) raises
    # ValueError on NumPy >= 2.0 whenever a copy cannot be avoided, which is
    # always the case for decoded JSON lists.
    return np.asarray(decode_response(response, format))


def assert_error_response(response: Response, error_code: int):
    """Assert that a response carries the given status and a JSON error body.

    The body is decoded with ``decode_response`` and must be a dict whose
    "message" entry is a string.
    """
    assert response.status == error_code
    payload = decode_response(response)
    assert isinstance(payload, dict)
    assert isinstance(payload["message"], str)
def test_parse_slice():
    """Check parse_slice on a single index, an index plus slice, and three members."""
    expectations = [
        ("5", (5,)),
        ("1, 2:5", (1, slice(2, 5))),
        ("0:10:5, 2, 3:", (slice(0, 10, 5), 2, slice(3, None))),
    ]
    for text, expected in expectations:
        assert parse_slice(text) == expected
65 changes: 65 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import io
import json
import numpy as np
from typing import List, NamedTuple, Tuple

from h5grove.utils import hdf_path_join


class Response(NamedTuple):
    """Response record used as the return type of :meth:`get`"""

    status: int
    headers: List[Tuple[str, str]]
    content: bytes

    def find_header_value(self, key: str):
        """Look up a header value by name, ignoring case.

        When the same name appears several times, the last occurrence wins.
        Raises ``KeyError`` if no header matches.
        """
        values_by_name = {}
        for header in self.headers:
            values_by_name[header[0].lower()] = header[1]
        return values_by_name[key.lower()]


def test_root_path_join():
    """Joining a child name onto the root path gives '/child'."""
    # NOTE(review): this module is named utils.py; default pytest discovery
    # only collects test_*.py / *_test.py, so this test may not run - confirm.
    joined = hdf_path_join("/", "child")
    assert joined == "/child"


def test_group_path_join():
    """Joining a name onto a nested group path appends it after a slash."""
    # NOTE(review): this module is named utils.py; default pytest discovery
    # only collects test_*.py / *_test.py, so this test may not run - confirm.
    joined = hdf_path_join("/group1/group2", "data")
    assert joined == "/group1/group2/data"


def test_group_path_join_trailing():
    """A trailing slash on the group path does not produce a double slash."""
    # NOTE(review): this module is named utils.py; default pytest discovery
    # only collects test_*.py / *_test.py, so this test may not run - confirm.
    joined = hdf_path_join("/group1/group2/", "data")
    assert joined == "/group1/group2/data"


def decode_response(response: Response, format: str = "json"):
    """Decode the content of a response according to the requested format.

    "json" content is parsed with ``json.loads`` and "npy" content is loaded
    with ``np.load``. The content-type header is asserted to match the
    format. Any other format raises ``ValueError``.
    """
    # The header is looked up before the format is examined, so a response
    # without a content-type header fails first.
    mime = response.find_header_value("content-type")

    if format not in ("json", "npy"):
        raise ValueError(f"Unsupported format: {format}")

    if format == "npy":
        assert mime == "application/octet-stream"
        return np.load(io.BytesIO(response.content))

    assert "application/json" in mime
    return json.loads(response.content)


def decode_array_response(
    response: Response,
    format: str,
    dtype: str,
    shape: Tuple[int, ...],
) -> np.ndarray:
    """Decode the content of a data array response into a NumPy array.

    :param response: Response whose content holds the encoded array
    :param format: "bin" for raw bytes; any other value is handed to
        ``decode_response``
    :param dtype: dtype used to interpret the raw bytes (only used for "bin")
    :param shape: Shape given to the decoded array (only used for "bin")
    :returns: The decoded array
    :raises AssertionError: If format is "bin" and the content type is not
        "application/octet-stream"
    """
    content_type = response.find_header_value("content-type")

    if format == "bin":
        assert content_type == "application/octet-stream"
        return np.frombuffer(response.content, dtype=dtype).reshape(shape)

    # np.asarray copies only when needed. np.array(..., copy=False) raises
    # ValueError on NumPy >= 2.0 whenever a copy cannot be avoided, which is
    # always the case for decoded JSON lists.
    return np.asarray(decode_response(response, format))


def assert_error_response(response: Response, error_code: int):
    """Assert that a response carries the given status and a JSON error body.

    The body is decoded with ``decode_response`` and must be a dict whose
    "message" entry is a string.
    """
    assert response.status == error_code
    payload = decode_response(response)
    assert isinstance(payload, dict)
    assert isinstance(payload["message"], str)

0 comments on commit 6650a28

Please sign in to comment.