Skip to content

Commit

Permalink
Merge pull request #34 from dominikl/cleanup_2
Browse files Browse the repository at this point in the history
Remove duplicated code
  • Loading branch information
joshmoore authored Oct 22, 2020
2 parents 49d4308 + dbec40a commit bef531c
Showing 1 changed file with 63 additions and 83 deletions.
146 changes: 63 additions & 83 deletions src/omero_zarr/raw_pixels.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import os
import time
from typing import Any, Dict
from typing import Any, Dict, Optional

import cv2
import numpy
Expand All @@ -12,88 +12,51 @@


def image_to_zarr(image: omero.gateway.ImageWrapper, args: argparse.Namespace) -> None:

cache_numpy = args.cache_numpy
target_dir = args.output
cache_dir = target_dir if args.cache_numpy else None

size_c = image.getSizeC()
size_z = image.getSizeZ()
size_x = image.getSizeX()
size_y = image.getSizeY()
size_t = image.getSizeT()

# dir for caching .npy planes
if cache_numpy:
os.makedirs(os.path.join(target_dir, str(image.id)), mode=511, exist_ok=True)
name = os.path.join(target_dir, "%s.zarr" % image.id)
za = None
pixels = image.getPrimaryPixels()

zct_list = []
for t in range(size_t):
for c in range(size_c):
for z in range(size_z):
# We only want to load from server if not cached locally
filename = os.path.join(
target_dir, str(image.id), f"{z:03d}-{c:03d}-{t:03d}.npy",
)
if not os.path.exists(filename):
zct_list.append((z, c, t))

def planeGen() -> np.ndarray:
planes = pixels.getPlanes(zct_list)
yield from planes

planes = planeGen()

for t in range(size_t):
for c in range(size_c):
for z in range(size_z):
filename = os.path.join(
target_dir, str(image.id), f"{z:03d}-{c:03d}-{t:03d}.npy",
)
if os.path.exists(filename):
print(f"plane (from disk) c:{c}, t:{t}, z:{z}")
plane = numpy.load(filename)
else:
print(f"loading plane c:{c}, t:{t}, z:{z}")
plane = next(planes)
if cache_numpy:
print(f"cached at {filename}")
numpy.save(filename, plane)
if za is None:
# store = zarr.NestedDirectoryStore(name)
# root = zarr.group(store=store, overwrite=True)
root = open_group(name, mode="w")
za = root.create(
"0",
shape=(size_t, size_c, size_z, size_y, size_x),
chunks=(1, 1, 1, size_y, size_x),
dtype=plane.dtype,
)
za[t, c, z, :, :] = plane
add_group_metadata(root, image)
print("Created", name)
print(f"Exporting to {name}")
root = open_group(name, mode="w")
n_levels = add_image(image, root, cache_dir=cache_dir)
add_group_metadata(root, image, n_levels)
print("Finished.")


def add_image(
image: omero.gateway.ImageWrapper, parent: Group, field_index: str = "0"
) -> None:
"""Adds the image pixel data as array to the given parent zarr group."""
image: omero.gateway.ImageWrapper, parent: Group, cache_dir: Optional[str] = None
) -> int:
""" Adds the image pixel data as array to the given parent zarr group.
Optionally caches the pixel data in the given cache_dir directory.
Returns the number of resolution levels generated for the image.
"""
if cache_dir is not None:
cache = True
os.makedirs(os.path.join(cache_dir, str(image.id)), mode=511, exist_ok=True)
else:
cache = False
cache_dir = ""

size_c = image.getSizeC()
size_z = image.getSizeZ()
size_x = image.getSizeX()
size_y = image.getSizeY()
size_t = image.getSizeT()
d_type = image.getPixelsType()

field_group = parent.require_group(field_index)

zct_list = []
for t in range(size_t):
for c in range(size_c):
for z in range(size_z):
zct_list.append((z, c, t))
if cache:
# We only want to load from server if not cached locally
filename = os.path.join(
cache_dir, str(image.id), f"{z:03d}-{c:03d}-{t:03d}.npy",
)
if not os.path.exists(filename):
zct_list.append((z, c, t))
else:
zct_list.append((z, c, t))

pixels = image.getPrimaryPixels()

Expand All @@ -111,20 +74,28 @@ def planeGen() -> np.ndarray:
longest = longest // 2
level_count += 1

add_group_metadata(field_group, image, level_count)

field_groups = []
for t in range(size_t):
for c in range(size_c):
for z in range(size_z):
plane = next(planes)
if cache:
filename = os.path.join(
cache_dir, str(image.id), f"{z:03d}-{c:03d}-{t:03d}.npy",
)
if os.path.exists(filename):
plane = numpy.load(filename)
else:
plane = next(planes)
numpy.save(filename, plane)
else:
plane = next(planes)
for level in range(level_count):
size_y = plane.shape[0]
size_x = plane.shape[1]
# If on first plane, create a new group for this resolution level
if t == 0 and c == 0 and z == 0:
field_groups.append(
field_group.create(
parent.create(
str(level),
shape=(size_t, size_c, size_z, size_y, size_x),
chunks=(1, 1, 1, size_y, size_x),
Expand All @@ -142,17 +113,7 @@ def planeGen() -> np.ndarray:
dsize=(size_x // 2, size_y // 2),
interpolation=cv2.INTER_NEAREST,
)


def print_status(t0: float, t: float, count: int, total: int) -> None:
""" Prints percent done and ETA """
percent_done = count * 100 / total
rate = count / (t - t0)
eta = (total - count) / rate
status = "{:.2f}% done, ETA: {}".format(
percent_done, time.strftime("%H:%M:%S", time.gmtime(eta))
)
print(status, end="\r", flush=True)
return level_count


def plate_to_zarr(plate: omero.gateway._PlateWrapper, args: argparse.Namespace) -> None:
Expand All @@ -167,6 +128,7 @@ def plate_to_zarr(plate: omero.gateway._PlateWrapper, args: argparse.Namespace)
total = n_rows * n_cols * (n_fields[1] - n_fields[0] + 1)

target_dir = args.output
cache_dir = target_dir if args.cache_numpy else None
name = os.path.join(target_dir, "%s.zarr" % plate.id)
print(f"Exporting to {name}")
root = open_group(name, mode="w")
Expand All @@ -190,11 +152,29 @@ def plate_to_zarr(plate: omero.gateway._PlateWrapper, args: argparse.Namespace)
ac_group = root.require_group(ac_name)
row_group = ac_group.require_group(row)
col_group = row_group.require_group(col)
add_image(img, col_group, field_name)
print_status(t0, time.time(), count, total)
field_group = col_group.require_group(field_name)
n_levels = add_image(img, field_group, cache_dir=cache_dir)
add_group_metadata(field_group, img, n_levels)
print_status(int(t0), int(time.time()), count, total)
print("Finished.")


def print_status(t0: int, t: int, count: int, total: int) -> None:
""" Prints percent done and ETA.
t0: start timestamp in seconds
t: current timestamp in seconds
count: number of tasks done
total: total number of tasks
"""
percent_done = float(count) * 100 / total
rate = float(count) / (t - t0)
eta = float(total - count) / rate
status = "{:.2f}% done, ETA: {}".format(
percent_done, time.strftime("%H:%M:%S", time.gmtime(eta))
)
print(status, end="\r", flush=True)


def add_group_metadata(
zarr_root: Group, image: omero.gateway.ImageWrapper, resolutions: int = 1
) -> None:
Expand Down

0 comments on commit bef531c

Please sign in to comment.