Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX(psoct): avoid loading all slices when files are "old matlab format" #27

Merged
merged 11 commits into from
Nov 22, 2024
202 changes: 136 additions & 66 deletions linc_convert/modalities/psoct/multi_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import json
import math
import os
from contextlib import contextmanager
from functools import wraps
from itertools import product
from typing import Any, Callable, Optional
from typing import Callable, Mapping, Optional
from warnings import warn

import cyclopts
Expand All @@ -38,54 +37,126 @@


def _automap(func: Callable) -> Callable:
"""Decorator to automatically map the array in the mat file.""" # noqa: D401
"""Automatically maps the array in the mat file."""

@wraps(func)
def wrapper(inp: str, out: str = None, **kwargs: dict) -> Any: # noqa: ANN401
def wrapper(inp: list[str], out: str = None, **kwargs: dict) -> callable:
if out is None:
out = os.path.splitext(inp[0])[0]
out += ".nii.zarr" if kwargs.get("nii", False) else ".ome.zarr"
kwargs["nii"] = kwargs.get("nii", False) or out.endswith(".nii.zarr")
with _mapmat(inp, kwargs.get("key", None)) as dat:
return func(dat, out, **kwargs)
dat = _mapmat(inp, kwargs.get("key", None))
return func(dat, out, **kwargs)

return wrapper


@contextmanager
def _mapmat(fnames: list[str], key: str = None) -> None:
"""Load or memory-map an array stored in a .mat file."""
loaded_data = []

for fname in fnames:
try:
# "New" .mat file
f = h5py.File(fname, "r")
except Exception:
# "Old" .mat file
f = loadmat(fname)

class _ArrayWrapper:
def _get_key(self, f: Mapping) -> str:
key = self.key
if key is None:
if not len(f.keys()):
raise Exception(f"{fname} is empty")
key = list(f.keys())[0]
raise Exception(f"{self.file} is empty")
for key in f.keys():
if key[:1] != "_":
break
if len(f.keys()) > 1:
warn(
f"More than one key in .mat file {fname}, "
f"More than one key in .mat file {self.file}, "
f'arbitrarily loading "{key}"'
)

if key not in f.keys():
raise Exception(f"Key {key} not found in file {fname}")
raise Exception(f"Key {key} not found in file {self.file}")

return key


class _H5ArrayWrapper(_ArrayWrapper):
    """Lazy view of a dataset in a "new-style" (HDF5, MATLAB >= v7.3) .mat file.

    Indexing reads only the requested region from disk through the h5py
    dataset handle; ``load()`` pulls the whole array into memory and closes
    the file.
    """

    def __init__(self, file: h5py.File, key: str | None) -> None:
        self.file = file
        self.key = key
        # h5py dataset handle; slicing it reads lazily from disk.
        self.array = file.get(self._get_key(self.file))

    def __del__(self) -> None:
        # getattr guard: __init__ may have raised before `file` was set,
        # and load() replaces it with None.
        file = getattr(self, "file", None)
        if hasattr(file, "close"):
            file.close()

    def load(self) -> np.ndarray:
        """Read the full dataset into memory and close the backing file."""
        self.array = self.array[...]
        if hasattr(self.file, "close"):
            self.file.close()
        self.file = None
        return self.array

    @property
    def shape(self) -> tuple[int, ...]:
        return self.array.shape

    @property
    def dtype(self) -> np.dtype:
        return self.array.dtype

    def __len__(self) -> int:
        return len(self.array)

    def __getitem__(self, index: object) -> np.ndarray:
        return self.array[index]


class _MatArrayWrapper(_ArrayWrapper):
    """Deferred loader for an "old-style" (pre-v7.3) .mat file.

    ``loadmat`` reads the entire file at once, so reading is deferred until
    the data (or its shape/dtype/length) is actually needed — this avoids
    loading every slice up front.
    """

    def __init__(self, file: str, key: str | None) -> None:
        # Unlike _H5ArrayWrapper, `file` is a path; nothing is read yet.
        self.file = file
        self.key = key
        self.array = None

    def __del__(self) -> None:
        # `file` is a str path (or None after load()), neither of which has
        # close(); the getattr also guards a partially-constructed instance.
        file = getattr(self, "file", None)
        if hasattr(file, "close"):
            file.close()

    def load(self) -> np.ndarray:
        """Read the whole array into memory (old .mat is all-or-nothing)."""
        f = loadmat(self.file)
        self.array = f.get(self._get_key(f))
        self.file = None
        return self.array

    @property
    def shape(self) -> tuple[int, ...]:
        if self.array is None:
            self.load()
        return self.array.shape

    @property
    def dtype(self) -> np.dtype:
        if self.array is None:
            self.load()
        return self.array.dtype

    def __len__(self) -> int:
        if self.array is None:
            self.load()
        return len(self.array)

    def __getitem__(self, index: object) -> np.ndarray:
        if self.array is None:
            self.load()
        return self.array[index]


def _mapmat(fnames: list[str], key: str = None) -> list[_ArrayWrapper]:
    """Load or memory-map arrays stored in .mat files.

    Returns one lazy wrapper per file: HDF5-backed files ("new" MATLAB
    format) are memory-mapped, while "old" format files defer the full
    ``loadmat`` read until the data is first accessed.
    """

    def make_wrapper(fname: str) -> _ArrayWrapper:
        try:
            # "New" .mat file (MATLAB >= v7.3) is an HDF5 container.
            f = h5py.File(fname, "r")
            return _H5ArrayWrapper(f, key)
        except Exception:
            # "Old" .mat file: handled by scipy's loadmat, deferred.
            return _MatArrayWrapper(fname, key)

    return [make_wrapper(fname) for fname in fnames]


@multi_slice.default
Expand Down Expand Up @@ -163,10 +234,10 @@ def convert(
omz = zarr.storage.DirectoryStore(out)
omz = zarr.group(store=omz, overwrite=True)

if not hasattr(inp[0], "dtype"):
raise Exception("Input is not numpy array. This is likely unexpected")
if len(inp[0].shape) != 2:
raise Exception("Input array is not 2d")
# if not hasattr(inp[0], "dtype"):
# raise Exception("Input is not an array. This is likely unexpected")
if len(inp[0].shape) < 2:
raise Exception("Input array is not 2d:", inp[0].shape)
# Prepare chunking options
opt = {
"dimension_separator": r"/",
Expand All @@ -177,10 +248,10 @@ def convert(
}
inp: list = inp
inp_shape = (*inp[0].shape, len(inp))
inp_chunk = [min(x, max_load) for x in inp_shape]
nk = ceildiv(inp_shape[0], inp_chunk[0])
nj = ceildiv(inp_shape[1], inp_chunk[1])
ni = ceildiv(inp_shape[2], inp_chunk[2])
inp_chunk = [min(x, max_load) for x in inp_shape[-3:]]
nk = ceildiv(inp_shape[-3], inp_chunk[0])
nj = ceildiv(inp_shape[-2], inp_chunk[1])
ni = len(inp)

nblevels = min(
[int(math.ceil(math.log2(x))) for i, x in enumerate(inp_shape) if i != no_pool]
Expand All @@ -193,32 +264,31 @@ def convert(
omz.create_dataset(str(0), shape=inp_shape, **opt)

# iterate across input chunks
for i, j, k in product(range(ni), range(nj), range(nk)):
loaded_chunk = np.stack(
[
inp[index][
k * inp_chunk[0] : (k + 1) * inp_chunk[0],
j * inp_chunk[1] : (j + 1) * inp_chunk[1],
]
for index in range(i * inp_chunk[2], (i + 1) * inp_chunk[2])
],
axis=-1,
)

print(
f"[{i + 1:03d}, {j + 1:03d}, {k + 1:03d}]",
"/",
f"[{ni:03d}, {nj:03d}, {nk:03d}]",
# f"({1 + level}/{nblevels})",
end="\r",
)

# save current chunk
omz["0"][
k * inp_chunk[0] : k * inp_chunk[0] + loaded_chunk.shape[0],
j * inp_chunk[1] : j * inp_chunk[1] + loaded_chunk.shape[1],
i * inp_chunk[2] : i * inp_chunk[2] + loaded_chunk.shape[2],
] = loaded_chunk
for i in range(ni):
for j, k in product(range(nj), range(nk)):
loaded_chunk = inp[i][
...,
k * inp_chunk[0] : (k + 1) * inp_chunk[0],
j * inp_chunk[1] : (j + 1) * inp_chunk[1],
]

print(
f"[{i + 1:03d}, {j + 1:03d}, {k + 1:03d}]",
"/",
f"[{ni:03d}, {nj:03d}, {nk:03d}]",
# f"({1 + level}/{nblevels})",
end="\r",
)

# save current chunk
omz["0"][
...,
k * inp_chunk[0] : k * inp_chunk[0] + loaded_chunk.shape[0],
j * inp_chunk[1] : j * inp_chunk[1] + loaded_chunk.shape[1],
i,
] = loaded_chunk

inp[i] = None # no ref count -> delete array

generate_pyramid(omz, nblevels - 1, mode="mean")

Expand All @@ -234,7 +304,7 @@ def convert(
no_pool=no_pool,
space_unit=ome_unit,
space_scale=vx,
multiscales_type=("2x2x2" if no_pool is None else "2x2") + "mean window",
multiscales_type=(("2x2x2" if no_pool is None else "2x2") + "mean window"),
)

if not nii:
Expand Down