Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mmapped data loader #23

Merged
merged 3 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ imageio
optax
tqdm
pandas
matplotlib
matplotlib
mrcfile
2 changes: 1 addition & 1 deletion src/kompressor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@
# SOFTWARE.


from kompressor import image, volume, mapping, utils
from kompressor import image, volume, mapping, utils, dataloaders

VERSION = 'v1.0a'
1 change: 1 addition & 0 deletions src/kompressor/dataloaders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import mrc_dataloader
56 changes: 56 additions & 0 deletions src/kompressor/dataloaders/mrc_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import mrcfile
from glob import glob
from pathlib import Path
from typing import List, Tuple

import numpy as np

# Type aliases describing one entry of the frame index built in this module.
# NOTE(review): get_data_paths_and_frames stores pathlib.Path objects in the
# path slot, not str — confirm which type is intended.
Paths = str
# Position of a frame along the last axis of a file's data.
Frame = int
# One (path, frame) pair.
DataPathFrame = Tuple[Paths, Frame]


def get_data_paths_and_frames(files: List[str]) -> List[DataPathFrame]:
    """Build one ``(path, frame)`` entry for every frame of every MRC file.

    Args:
        files: Paths of the MRC files to index.

    Returns:
        A list of ``(path, frame)`` tuples, where ``path`` is a
        ``pathlib.Path`` and ``frame`` ranges over the last axis of that
        file's data.

    Raises:
        AssertionError: If an entry of ``files`` is not an existing file.

    Example:
        If we have an MRC file of shape (1, 2, 3) at /tmp/0.mrc then::

            data = get_data_paths_and_frames(["/tmp/0.mrc"])

        data will be::

            [(Path("/tmp/0.mrc"), 0), (Path("/tmp/0.mrc"), 1), (Path("/tmp/0.mrc"), 2)]
    """
    data_paths = []
    for file in files:
        file = Path(file)
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; exception type and message are unchanged.
        if not file.is_file():
            raise AssertionError(f"{file} is not a file.")
        # Only the shape is needed, so release the memory map right away
        # instead of leaking one open handle per file.
        with mrcfile.mmap(file) as mrc:
            frame_count = mrc.data.shape[-1]
        data_paths.extend((file, frame) for frame in range(frame_count))
    return data_paths


def decode_mrc_data_path(data_paths: DataPathFrame) -> np.ndarray:
    """Decode one ``(path, frame)`` entry into the slice of data it specifies.

    Args:
        data_paths: A single ``(path, frame)`` tuple; ``frame`` indexes the
            last axis of the MRC file's data.

    Returns:
        A NumPy array holding a copy of the selected frame. A 2-D frame gains
        a trailing axis of length 1 so the result is Height, Width, Channel;
        frames of any other rank are returned unchanged.

    Examples:
        Given the data path ("/tmp/0.mrc", 0) whose frame has shape (5, 5),
        this returns a NumPy array of shape (5, 5, 1).
    """
    path, frame = data_paths[0], data_paths[1]
    # np.array copies the slice out of the memory map, so the map can be
    # closed before returning instead of being left open.
    with mrcfile.mmap(path) as mrc:
        data = np.array(mrc.data[..., frame])
    if data.ndim == 2:
        return np.expand_dims(data, axis=-1)
    return data
Empty file added tests/dataloaders/__init__.py
Empty file.
63 changes: 63 additions & 0 deletions tests/dataloaders/test_mrc_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os
import random
import shutil
import unittest
import uuid
import mrcfile
import numpy as np

from pathlib import Path

# Test imports
import kompressor as kom

# Dimensions of the synthetic volume written for each test.
X_AXIS, Y_AXIS, Z_AXIS = 5, 4, 3

# A dedicated, fixed-seed generator keeps the uuid-derived folder name
# identical on every run.
random_generator = random.Random(0)
unique_id = uuid.UUID(int=random_generator.getrandbits(128), version=4)
TEST_FOLDER = f"/tmp/kompressor_data_loader_test{unique_id}"


def delete_mrc_folder():
    """Remove the temporary test folder and its contents, if it exists."""
    if not os.path.exists(TEST_FOLDER):
        return
    shutil.rmtree(TEST_FOLDER)


class MRCDataloaderTests(unittest.TestCase):
    """Exercise the MRC dataloader helpers against a temporary MRC file."""

    def setUp(self):
        # Start every test from a fresh folder holding one random float32
        # volume of shape (X_AXIS, Y_AXIS, Z_AXIS).
        delete_mrc_folder()
        os.mkdir(TEST_FOLDER)
        for i in range(1):
            with mrcfile.new(f"{TEST_FOLDER}/{i}.mrc") as mrc:
                mrc.set_data(
                    np.random.random(size=(X_AXIS, Y_AXIS, Z_AXIS)).astype(np.float32)
                )

    def tearDown(self):
        delete_mrc_folder()

    def test_mrc_data_loader(self):
        """Each frame of the file yields a (path, frame) entry, in order."""
        files = [TEST_FOLDER + "/0.mrc"]
        data_paths_and_frames = (
            kom.dataloaders.mrc_dataloader.get_data_paths_and_frames(files)
        )
        for frame in range(Z_AXIS):
            self.assertEqual(data_paths_and_frames[frame][0], Path(files[0]))
            self.assertEqual(data_paths_and_frames[frame][1], frame)

        with self.assertRaises(Exception) as context:
            kom.dataloaders.mrc_dataloader.get_data_paths_and_frames(["Non_File"])
        # Checked after the ``with`` block: ``context.exception`` is only
        # populated once the context manager has exited.
        self.assertIn("Non_File is not a file", str(context.exception))

    def test_decode_mrc_data_path(self):
        """Every decoded frame is float32 with a trailing channel axis."""
        files = [TEST_FOLDER + "/0.mrc"]
        data_paths_and_frames = (
            kom.dataloaders.mrc_dataloader.get_data_paths_and_frames(files)
        )
        for data_paths in data_paths_and_frames:
            data = kom.dataloaders.mrc_dataloader.decode_mrc_data_path(data_paths)
            self.assertEqual(data.shape, (X_AXIS, Y_AXIS, 1))
            self.assertEqual(data.dtype, np.float32)
Loading