use parameters to design group names in HDF5
sophiamaedler committed May 27, 2024
1 parent ce9b509 commit 886083c
Showing 1 changed file with 33 additions and 28 deletions.
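The commit replaces hard-coded HDF5 group names ("channels"/"labels" in the base Segmentation class, "input_images"/"segmentation" in TimecourseSegmentation) with class-level constants, so every read and write site resolves the name through `self` and a subclass can rename its datasets without reimplementing the I/O methods. Below is a minimal sketch of the pattern, with simplified method signatures; the chunking, dtype, and logging arguments of the real code are omitted.

```python
import h5py


class Segmentation:
    # group names are class attributes instead of scattered string literals
    DEFAULT_CHANNELS_NAME = "channels"
    DEFAULT_MASK_NAME = "labels"

    def save_segmentation(self, path, channels, labels):
        with h5py.File(path, "a") as hf:
            for name, data in [
                (self.DEFAULT_MASK_NAME, labels),
                (self.DEFAULT_CHANNELS_NAME, channels),
            ]:
                # delete an existing dataset before overwriting it,
                # mirroring the check-and-delete logic in the diff
                if name in hf.keys():
                    del hf[name]
                hf.create_dataset(name, data=data)


class TimecourseSegmentation(Segmentation):
    # a subclass overrides only the constants to address differently
    # named groups in the same file layout
    DEFAULT_CHANNELS_NAME = "input_images"
    DEFAULT_MASK_NAME = "segmentation"
```

With the names centralized, shared code such as `call_as_shard` can simply call `hf.get(self.DEFAULT_CHANNELS_NAME)` and resolve to "channels" or "input_images" depending on which class is running.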
src/sparcscore/pipeline/segmentation.py (61 changes: 33 additions & 28 deletions)
@@ -69,6 +69,8 @@ def process(self):
DEFAULT_FILTER_FILE = "classes.csv"
DEFAULT_FILTER_ADDTIONAL_FILE = "needs_additional_filtering.txt"
PRINT_MAPS_ON_DEBUG = True
+ DEFAULT_CHANNELS_NAME = "channels"
+ DEFAULT_MASK_NAME = "labels"

DEFAULT_INPUT_IMAGE_NAME = "input_image.ome.zarr"

@@ -158,7 +160,7 @@ def call_as_shard(self):
self.log(f"Beginning Segmentation of Shard with the slicing {self.window}")

with h5py.File(self.input_path, "r") as hf:
- hdf_input = hf.get("channels")
+ hdf_input = hf.get(self.DEFAULT_CHANNELS_NAME)

# calculate shape of the required data container
c, _, _ = hdf_input.shape
@@ -225,28 +227,28 @@ def save_segmentation(self, channels, labels, classes):
hf = h5py.File(map_path, "a")

#check if data container already exists and if so delete
- if "labels" in hf.keys():
-     del hf["labels"]
+ if self.DEFAULT_MASK_NAME in hf.keys():
+     del hf[self.DEFAULT_MASK_NAME]
self.log(
"labels dataset already existed in hdf5, dataset was deleted and will be overwritten."
)

hf.create_dataset(
- "labels",
+ self.DEFAULT_MASK_NAME,
data=labels,
chunks=(1, self.config["chunk_size"], self.config["chunk_size"]),
)

#check if data container already exists and if so delete
- if "channels" in hf.keys():
-     del hf["channels"]
+ if self.DEFAULT_CHANNELS_NAME in hf.keys():
+     del hf[self.DEFAULT_CHANNELS_NAME]
self.log(
"channels dataset already existed in hdf5, dataset was deleted and will be overwritten."
)

#also save channels
hf.create_dataset(
- "channels",
+ self.DEFAULT_CHANNELS_NAME,
data=channels,
chunks=(1, self.config["chunk_size"], self.config["chunk_size"]),
)
@@ -275,7 +277,7 @@ def save_segmentation_zarr(self, labels = None):

#check if segmentation names already exist if so delete
for seg_names in segmentation_names:
- path = os.path.join(self.project_location, self.DEFAULT_INPUT_IMAGE_NAME, "labels", seg_names)
+ path = os.path.join(self.project_location, self.DEFAULT_INPUT_IMAGE_NAME, self.DEFAULT_MASK_NAME, seg_names)
if os.path.isdir(path):
shutil.rmtree(path)
self.log(f"removed existing {seg_names} segmentation from ome.zarr")
@@ -285,7 +287,7 @@ def save_segmentation_zarr(self, labels = None):
path_labels = os.path.join(self.directory, self.DEFAULT_OUTPUT_FILE)

with h5py.File(path_labels, "r") as hf:
- labels = hf["labels"][:]
+ labels = hf[self.DEFAULT_MASK_NAME][:]

segmentations = [np.expand_dims(seg, axis = 0) for seg in labels]

@@ -469,7 +471,7 @@ def save_input_image(self, input_image):
with h5py.File(output, "w") as hf:

hdf_channels = hf.create_dataset(
- "channels",
+ self.DEFAULT_CHANNELS_NAME,
data = input_image,
chunks=(1, self.config["chunk_size"], self.config["chunk_size"]),
dtype="uint16",
@@ -497,14 +499,14 @@ def save_segmentation(self, channels, labels, classes):
hf = h5py.File(map_path, "w")

#check if data container already exists and if so delete
- if "labels" in hf.keys():
-     del hf["labels"]
+ if self.DEFAULT_MASK_NAME in hf.keys():
+     del hf[self.DEFAULT_MASK_NAME]
self.log(
"labels dataset already existed in hdf5, dataset was deleted and will be overwritten."
)

hf.create_dataset(
- "labels",
+ self.DEFAULT_MASK_NAME,
data=labels,
chunks=(1, self.config["chunk_size"], self.config["chunk_size"]),
)
@@ -669,20 +671,20 @@ def resolve_sharding(self, sharding_plan):

with h5py.File(output, "a") as hf:
#check if data container already exists and if so delete
- if "labels" in hf.keys():
-     del hf["labels"]
+ if self.DEFAULT_MASK_NAME in hf.keys():
+     del hf[self.DEFAULT_MASK_NAME]
self.log(
"labels dataset already existed in hdf5, dataset was deleted and will be overwritten."
)

hdf_labels = hf.create_dataset(
- "labels",
+ self.DEFAULT_MASK_NAME,
label_size,
chunks=(1, self.config["chunk_size"], self.config["chunk_size"]),
dtype="int32",
)

- hdf_channels = hf.get("channels")
+ hdf_channels = hf.get(self.DEFAULT_CHANNELS_NAME)

class_id_shift = 0

@@ -715,7 +717,7 @@ def resolve_sharding(self, sharding_plan):
]

local_hf = h5py.File(local_output, "r")
- local_hdf_labels = local_hf.get("labels")
+ local_hdf_labels = local_hf.get(self.DEFAULT_MASK_NAME)

shifted_map, edge_labels = shift_labels(
local_hdf_labels, class_id_shift, return_shifted_labels=True
@@ -803,9 +805,9 @@ def resolve_sharding(self, sharding_plan):

#reading labels
path_labels = os.path.join(self.directory, self.DEFAULT_OUTPUT_FILE)

with h5py.File(path_labels, "r") as hf:
- labels = hf["labels"][:]
+ labels = hf[self.DEFAULT_MASK_NAME][:]

self.save_segmentation_zarr(labels = labels)
self.log("finished saving segmentation results to ome.zarr from sharded segmentation.")
@@ -1031,6 +1033,9 @@ class TimecourseSegmentation(Segmentation):
DEFAULT_OUTPUT_FILE = "input_segmentation.h5"
DEFAULT_INPUT_IMAGE_NAME = "input_segmentation.h5"
PRINT_MAPS_ON_DEBUG = True
+ DEFAULT_CHANNELS_NAME = "input_images"
+ DEFAULT_MASK_NAME = "segmentation"

channel_colors = [
"#e60049",
"#0bb4ff",
@@ -1078,7 +1083,7 @@ def call_as_shard(self):
"""
with h5py.File(self.input_path, "r") as hf:
- hdf_input = hf.get("input_images")
+ hdf_input = hf.get(self.DEFAULT_CHANNELS_NAME)

if type(self.index) == int:
self.index = [self.index]
@@ -1139,21 +1144,21 @@ def _transfer_tempmmap_to_hdf5(self):
# create hdf5 datasets with temp_arrays as input
with h5py.File(input_path, "a") as hf:
# check if dataset already exists if so delete and overwrite
- if "segmentation" in hf.keys():
-     del hf["segmentation"]
+ if self.DEFAULT_MASK_NAME in hf.keys():
+     del hf[self.DEFAULT_MASK_NAME]
self.log(
"segmentation dataset already existed in hdf5, deleted and overwritten."
)
hf.create_dataset(
- "segmentation",
+ self.DEFAULT_MASK_NAME,
shape=_tmp_seg.shape,
chunks=(1, 2, self.shape_input_images[2], self.shape_input_images[3]),
dtype="uint32",
)

# this loop structure ensures that not all results are loaded into memory at any one time
for i in range(_tmp_seg.shape[0]):
- hf["segmentation"][i] = _tmp_seg[i]
+ hf[self.DEFAULT_MASK_NAME][i] = _tmp_seg[i]

dt = h5py.special_dtype(vlen=np.dtype("uint32"))

@@ -1204,7 +1209,7 @@ def adjust_segmentation_indexes(self):
path = os.path.join(self.directory, self.DEFAULT_INPUT_IMAGE_NAME)

with h5py.File(path, "a") as hf:
- hdf_labels = hf.get("segmentation")
+ hdf_labels = hf.get(self.DEFAULT_MASK_NAME)
hdf_classes = hf.get("classes")

class_id_shift = 0
@@ -1279,7 +1284,7 @@ def process(self):
input_path = os.path.join(self.directory, self.DEFAULT_OUTPUT_FILE)

with h5py.File(input_path, "r") as hf:
- input_images = hf.get("input_images")
+ input_images = hf.get(self.DEFAULT_CHANNELS_NAME)
indexes = list(range(0, input_images.shape[0]))

# initialize segmentation dataset
@@ -1340,7 +1345,7 @@ def process(self):
input_path = os.path.join(self.directory, self.DEFAULT_OUTPUT_FILE)

with h5py.File(input_path, "r") as hf:
- input_images = hf.get("input_images")
+ input_images = hf.get(self.DEFAULT_CHANNELS_NAME)
indexes = list(range(0, input_images.shape[0]))

# initialize segmentation dataset
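For downstream readers of the files the pipeline writes, the constants also double as documentation of the on-disk layout. A hypothetical snippet (the file name comes from DEFAULT_OUTPUT_FILE in the diff; the access pattern itself is illustrative, not code from the commit):

```python
import h5py

from sparcscore.pipeline.segmentation import TimecourseSegmentation

# look the mask group up through the class constant instead of a
# string literal; for TimecourseSegmentation it resolves to "segmentation"
with h5py.File("input_segmentation.h5", "r") as hf:
    masks = hf[TimecourseSegmentation.DEFAULT_MASK_NAME][:]
```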
