diff --git a/src/sparcscore/pipeline/extraction.py b/src/sparcscore/pipeline/extraction.py
index 6191649..ede8a28 100644
--- a/src/sparcscore/pipeline/extraction.py
+++ b/src/sparcscore/pipeline/extraction.py
@@ -57,7 +57,16 @@ def __init__(self,
         base_directory = self.directory.replace("/extraction", "")
 
         self.input_segmentation_path = os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, self.DEFAULT_SEGMENTATION_FILE)
-        self.filtered_classes_path = os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, "classes.csv")
+
+        #get path to filtered classes
+        if os.path.isfile(os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, "needs_filtering.txt")):
+            try:
+                self.filtered_classes_path = os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, "filtered/filtered_classes.csv")
+            except:
+                raise ValueError("Need to run segmentation_filtering method ")
+        else:
+            self.filtered_classes_path = os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, "classes.csv")
+
         self.output_path = os.path.join(self.directory, self.DEFAULT_DATA_DIR, self.DEFAULT_DATA_FILE)
 
         #extract required information for generating datasets
@@ -163,7 +172,11 @@ def parse_remapping(self):
     def get_classes(self, filtered_classes_path):
         self.log(f"Loading filtered classes from {filtered_classes_path}")
         cr = csv.reader(open(filtered_classes_path,'r'), )
-        filtered_classes = [int(float(el[0])) for el in list(cr)]
+
+        if "filtered_classes.csv" in filtered_classes_path:
+            filtered_classes = [el[0] for el in list(cr)] #do not do int transform here as we expect a str of format "nucleus_id:cytosol_id"
+        else:
+            filtered_classes = [int(float(el[0])) for el in list(cr)]
 
         self.log("Loaded {} filtered classes".format(len(filtered_classes)))
         filtered_classes = np.unique(filtered_classes) #make sure they are all unique
@@ -171,7 +184,7 @@ def get_classes(self, filtered_classes_path):
         self.log("After removing duplicates {} filtered classes remain.".format(len(filtered_classes)))
 
         class_list = list(filtered_classes)
-        if 0 in class_list: class_list.remove(0)
+        if 0 in class_list: class_list.remove(0) #remove background if still listed
         self.num_classes = len(class_list)
 
         return(class_list)
@@ -290,6 +303,14 @@ def _extract_classes(self, input_segmentation_path, px_center, arg):
 
         index, save_index, cell_id, image_index, label_info = self._get_label_info(arg) #label_info not used in base case but relevant for flexibility for other classes
 
+        if type(cell_id) == str:
+            nucleus_id, cytosol_id = cell_id.split(":")
+            nucleus_id = int(float(nucleus_id)) #convert to int for further processing
+            cytosol_id = int(float(cytosol_id)) #convert to int for further processing
+        else:
+            nucleus_id = cell_id
+            cytosol_id = cell_id
+
         #generate some progress output every 10000 cells
         #relevant for benchmarking of time
         if save_index % 10000 == 0:
@@ -321,7 +342,7 @@ def _extract_classes(self, input_segmentation_path, px_center, arg):
         else:
             nuclei_mask = hdf_labels[image_index, 0, window_y, window_x]
 
-        nuclei_mask = np.where(nuclei_mask == cell_id, 1, 0)
+        nuclei_mask = np.where(nuclei_mask == nucleus_id, 1, 0)
 
         nuclei_mask_extended = gaussian(nuclei_mask, preserve_range=True, sigma=5)
         nuclei_mask = gaussian(nuclei_mask, preserve_range=True, sigma=1)
@@ -344,7 +365,7 @@ def _extract_classes(self, input_segmentation_path, px_center, arg):
         else:
             cell_mask = hdf_labels[image_index, 1,window_y,window_x]
 
-        cell_mask = np.where(cell_mask == cell_id, 1, 0).astype(int)
+        cell_mask = np.where(cell_mask == cytosol_id, 1, 0).astype(int)
         cell_mask = binary_fill_holes(cell_mask)
 
         cell_mask_extended = dilation(cell_mask, footprint=disk(6))
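The extraction step now accepts two class-list formats. A minimal standalone sketch of how a "nucleus_id:cytosol_id" entry from filtered_classes.csv is parsed, mirroring the logic added to _extract_classes above (the example values are made up):

# classes.csv           -> one numeric label per line, e.g. "17.0"
# filtered_classes.csv  -> one "nucleus_id:cytosol_id" pair per line, e.g. "17:42"
cell_id = "17:42"  # example value
if isinstance(cell_id, str):
    nucleus_id, cytosol_id = (int(float(v)) for v in cell_id.split(":"))
else:
    nucleus_id = cytosol_id = cell_id  # legacy format: one id addresses both masks
print(nucleus_id, cytosol_id)  # 17 42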
diff --git a/src/sparcscore/pipeline/filter_segmentation.py b/src/sparcscore/pipeline/filter_segmentation.py
new file mode 100644
index 0000000..5d497d9
--- /dev/null
+++ b/src/sparcscore/pipeline/filter_segmentation.py
@@ -0,0 +1,346 @@
+import os
+import numpy as np
+import csv
+import h5py
+from multiprocessing import Pool
+import shutil
+import pandas as pd
+
+import traceback
+
+from sparcscore.processing.segmentation import sc_any
+from sparcscore.pipeline.base import ProcessingStep
+
+# to show progress
+from tqdm.auto import tqdm
+
+#to perform garbage collection
+import gc
+import sys
+
+class SegmentationFilter(ProcessingStep):
+    """SegmentationFilter helper class used for creating workflows to filter generated segmentation masks before extraction.
+
+    """
+    DEFAULT_OUTPUT_FILE = "segmentation.h5"
+    DEFAULT_FILTER_FILE = "filtered_classes.csv"
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        #self.input_path = input_segmentation
+        self.identifier = None
+        self.window = None
+        self.input_path = None
+
+    def read_input_masks(self, input_path):
+
+        with h5py.File(input_path, "r") as hf:
+            hdf_input = hf.get("labels")
+
+            #use a memory mapped numpy array to save the input image to better utilize memory consumption
+            from alphabase.io import tempmmap
+            TEMP_DIR_NAME = tempmmap.redefine_temp_location(self.config["cache"])
+            self.TEMP_DIR_NAME = TEMP_DIR_NAME #save for later to be able to remove cached folders
+
+            input_masks = tempmmap.array(shape = hdf_input.shape, dtype = np.uint16)
+            input_masks = hdf_input[:2, :, :]
+
+        return(input_masks)
+
+    def save_classes(self, classes):
+
+        #define path where classes should be saved
+        filtered_path = os.path.join(self.directory, self.DEFAULT_FILTER_FILE)
+
+        to_write = "\n".join([f"{str(x)}:{str(y)}" for x, y in classes.items()])
+
+        with open(filtered_path, "w") as myfile:
+            myfile.write(to_write)
+
+        self.log(f"Saved nucleus_id:cytosol_id matchings of all cells that passed filtering to {filtered_path}.")
+
+    def initialize_as_tile(self, identifier, window, input_path, zarr_status = True):
+        """Initialize Filtering Step with further parameters needed for filtering segmentation results.
+
+        Important:
+            This function is intended for internal use by the :class:`TiledSegmentationFilter` helper class. In most cases it is not relevant to the creation of custom filtering workflows.
+
+        Args:
+            identifier (int): Unique index of the tile.
+            window (list(tuple)): Defines the window which is assigned to the tile. The window will be applied to the input. The first element refers to the first dimension of the image and so on. For example use ``[(0,1000),(0,2000)]`` to crop the image to `1000 px height` and `2000 px width` from the top left corner.
+            input_path (str): Location of the input hdf5 file. During tiled filtering the :class:`TiledSegmentationFilter` helper class saves the input masks in form of a hdf5 file. This makes the input available for parallel reading by the filtering processes.
+        """
+        self.identifier = identifier
+        self.window = window
+        self.input_path = input_path
+        self.save_zarr = zarr_status
+
+    def call_as_tile(self):
+        """Wrapper function for calling a tiled filtering step.
+
+        Important:
+            This function is intended for internal use by the :class:`TiledSegmentationFilter` helper class. In most cases it is not relevant to the creation of custom filtering workflows.
+        """
+
+        with h5py.File(self.input_path, "r") as hf:
+            hdf_input = hf.get("labels")
+
+            #use a memory mapped numpy array to save the input image to better utilize memory consumption
+            from alphabase.io import tempmmap
+            TEMP_DIR_NAME = tempmmap.redefine_temp_location(self.config["cache"])
+
+            #calculate shape of required datacontainer
+            c, _, _ = hdf_input.shape
+            x1 = self.window[0].start
+            x2 = self.window[0].stop
+            y1 = self.window[1].start
+            y2 = self.window[1].stop
+
+            x = x2 - x1
+            y = y2 - y1
+
+            #initialize directory and load data
+            input_image = tempmmap.array(shape = (2, x, y), dtype = np.uint16)
+            input_image = hdf_input[:2, self.window[0], self.window[1]]
+
+        #perform check to see if any input pixels are not 0, if so perform filtering, else return array of zeros.
+        if sc_any(input_image):
+            try:
+                self.log(f"Beginning filtering on tile in position [{self.window[0]}, {self.window[1]}]")
+                super().__call__(input_image)
+            except Exception:
+                self.log(traceback.format_exc())
+        else:
+            print(f"Tile in position [{self.window[0]}, {self.window[1]}] only contained zeroes.")
+            try:
+                super().__call_empty__(input_image)
+            except Exception:
+                self.log(traceback.format_exc())
+
+        #cleanup generated temp dir and variables
+        del input_image
+        gc.collect()
+
+        #write out window location
+        self.log(f"Writing out window location to file at {self.directory}/window.csv")
+        with open(f"{self.directory}/window.csv", "w") as f:
+            f.write(f"{self.window}\n")
+
+        self.log(f"Filtering of tile with the slicing {self.window} finished.")
+
+        #delete generated temp directory to cleanup space
+        shutil.rmtree(TEMP_DIR_NAME, ignore_errors=True)
+
+    def get_output(self):
+        return os.path.join(self.directory, self.DEFAULT_OUTPUT_FILE)
+
+class TiledSegmentationFilter(SegmentationFilter):
+    """"""
+
+    DEFAULT_TILES_FOLDER = "tiles"
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if not hasattr(self, "method"):
+            raise AttributeError(
+                "No SegmentationFilter method defined, please set attribute ``method``"
+            )
+
+    def initialize_tile_list(self, tileing_plan, input_path):
+        _tile_list = []
+
+        self.input_path = input_path
+
+        for i, window in enumerate(tileing_plan):
+            local_tile_directory = os.path.join(self.tile_directory, str(i))
+            current_tile = self.method(
+                self.config,
+                local_tile_directory,
+                project_location = self.project_location,
+                debug=self.debug,
+                overwrite=self.overwrite,
+                intermediate_output=self.intermediate_output,
+            )
+            current_tile.initialize_as_tile(i, window, self.input_path, zarr_status = False)
+            _tile_list.append(current_tile)
+
+        return _tile_list
+
+    def calculate_tileing_plan(self, mask_size):
+        #save tileing plan to file
+        tileing_plan_path = f"{self.directory}/tileing_plan.csv"
+
+        if os.path.isfile(tileing_plan_path):
+            self.log(f"tileing plan already found in directory {tileing_plan_path}.")
+            if self.overwrite:
+                self.log("Overwriting existing tileing plan.")
+                os.remove(tileing_plan_path)
+            else:
+                self.log("Reading existing tileing plan from file.")
+                with open(tileing_plan_path, "r") as f:
+                    _tileing_plan = [eval(line) for line in f.readlines()]
+                    return(_tileing_plan)
+
+        _tileing_plan = []
+        side_size = np.floor(np.sqrt(int(self.config["tile_size"])))
+        tiles_side = np.round(mask_size / side_size).astype(int)
+        tile_size = mask_size // tiles_side
+
+        self.log(f"input image {mask_size[0]} px by {mask_size[1]} px")
+        self.log(f"target_tile_size: {self.config['tile_size']}")
+        self.log(f"tileing plan:")
+        self.log(f"{tiles_side[0]} rows by {tiles_side[1]} columns")
+        self.log(f"{tile_size[0]} px by {tile_size[1]} px")
+
+        for y in range(tiles_side[0]):
+            for x in range(tiles_side[1]):
+                last_row = y == tiles_side[0] - 1
+                last_column = x == tiles_side[1] - 1
+
+                lower_y = y * tile_size[0]
+                lower_x = x * tile_size[1]
+
+                upper_y = (y + 1) * tile_size[0]
+                upper_x = (x + 1) * tile_size[1]
+
+                #add px overlap to each tile
+                lower_y = lower_y - self.config["overlap_px"]
+                lower_x = lower_x - self.config["overlap_px"]
+                upper_y = upper_y + self.config["overlap_px"]
+                upper_x = upper_x + self.config["overlap_px"]
+
+                #make sure that each limit stays within the slides
+                if lower_y < 0:
+                    lower_y = 0
+                if lower_x < 0:
+                    lower_x = 0
+
+                if last_row:
+                    upper_y = mask_size[0]
+
+                if last_column:
+                    upper_x = mask_size[1]
+
+                tile = (slice(lower_y, upper_y), slice(lower_x, upper_x))
+                _tileing_plan.append(tile)
+
+        #write out newly generated tileing plan
+        with open(tileing_plan_path, "w") as f:
+            for tile in _tileing_plan:
+                f.write(f"{tile}\n")
+        self.log(f"Tileing plan written to file at {tileing_plan_path}")
+
+        return _tileing_plan
+
+    def resolve_tileing(self, tileing_plan):
+        """
+        The function iterates over the tileing plan and joins the generated per-tile lists into one consolidated list of nucleus_id:cytosol_id matchings.
+        """
+
+        self.log("resolve tileing plan and joining generated lists together")
+
+        #initialize empty list to save results to
+        filtered_classes_combined = []
+
+        for i, window in enumerate(tileing_plan):
+
+            local_tile_directory = os.path.join(self.tile_directory, str(i))
+            local_output = os.path.join(local_tile_directory, self.DEFAULT_OUTPUT_FILE)
+            local_classes = os.path.join(local_tile_directory, "filtered_classes.csv")
+
+            #check to make sure windows match
+            with open(f"{local_tile_directory}/window.csv", "r") as f:
+                window_local = eval(f.read())
+            if window_local != window:
+                self.log("Tileing plans do not match. Aborting run.")
+                self.log(f"Tileing plan found locally: {window_local}")
+                self.log(f"Tileing plan found in tileing plan: {window}")
+                sys.exit("tileing plans do not match!")
+
+            cr = csv.reader(open(local_classes, "r"))
+            filtered_classes = [el[0] for el in list(cr)]
+
+            filtered_classes_combined += filtered_classes
+            self.log(f"Finished stitching tile {i}")
+
+        #remove duplicates from list (this only removes perfect duplicates)
+        filtered_classes_combined = list(set(filtered_classes_combined))
+
+        #perform sanity check that no cytosol_id is listed twice
+        filtered_classes_combined = {int(k): int(v) for k, v in (s.split(":") for s in filtered_classes_combined)}
+        if len(filtered_classes_combined.values()) != len(set(filtered_classes_combined.values())):
+            print(pd.Series(filtered_classes_combined.values()).value_counts())
+            print(filtered_classes_combined)
+            sys.exit("Duplicate values found. Some issues with filtering. Please contact the developers.")
+
+        # save newly generated class list to file
+        filtered_path = os.path.join(self.directory, self.DEFAULT_FILTER_FILE)
+        to_write = "\n".join([f"{str(x)}:{str(y)}" for x, y in filtered_classes_combined.items()])
+        with open(filtered_path, "w") as myfile:
+            myfile.write(to_write)
+
+        # Add section here that cleans up the results from the tiles and deletes them to save memory
+        self.log("Deleting intermediate tile results to free up storage space")
+        shutil.rmtree(self.tile_directory, ignore_errors=True)
+
+        gc.collect()
+
+    def process(self, input_path):
+
+        self.tile_directory = os.path.join(self.directory, self.DEFAULT_TILES_FOLDER)
+
+        if not os.path.isdir(self.tile_directory):
+            os.makedirs(self.tile_directory)
+            self.log("Created new tile directory " + self.tile_directory)
+
+        # calculate tileing plan
+        with h5py.File(input_path, "r") as hf:
+            self.mask_size = hf["labels"].shape[1:]
+
+        if self.config["tile_size"] >= np.prod(self.mask_size):
+            target_size = self.config["tile_size"]
+            self.log(f"target size {target_size} is equal to or larger than the input mask {np.prod(self.mask_size)}. Tileing will not be used.")
+
+            tileing_plan = [
+                (slice(0, self.mask_size[0]), slice(0, self.mask_size[1]))
+            ]
+
+        else:
+            target_size = self.config["tile_size"]
+            self.log(f"target size {target_size} is smaller than the input mask {np.prod(self.mask_size)}. Tileing will be used.")
+            tileing_plan = self.calculate_tileing_plan(self.mask_size)
+
+        #save tileing plan to file to be able to reload later
+        self.log(f"Saving Tileing plan to file: {self.directory}/tileing_plan.csv")
+        with open(f"{self.directory}/tileing_plan.csv", "w") as f:
+            for tile in tileing_plan:
+                f.write(f"{tile}\n")
+
+        tile_list = self.initialize_tile_list(tileing_plan, input_path)
+
+        self.log(
+            f"tileing plan with {len(tileing_plan)} elements generated, tileing with {self.config['threads']} threads begins"
+        )
+
+        with Pool(processes=self.config['threads']) as pool:
+            results = list(
+                tqdm(
+                    pool.imap(self.method.call_as_tile, tile_list),
+                    total=len(tile_list),
+                )
+            )
+            pool.close()
+            pool.join()
+            print("All Filtering Steps are done.", flush=True)
+
+        #free up memory
+        del tile_list
+        gc.collect()
+
+        self.log("Finished tiled filtering.")
+        self.resolve_tileing(tileing_plan)
+
+        #make sure to cleanup temp directories
+        self.log("=== finished filtering === ")
\ No newline at end of file
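A sketch of the configuration the tiled filtering step expects, written here as a Python dict for illustration (in a project it lives in config.yml under the workflow class name, see project.py below). The key names are taken from the code above; the values are made-up examples.

config = {
    "multithreaded_filtering_match_nucleus_to_cytosol": {
        "cache": "/tmp",             # directory for memory-mapped temporary arrays
        "tile_size": 4_000_000,      # target number of pixels per tile
        "overlap_px": 100,           # overlap added around each tile in px
        "threads": 8,                # worker processes used for tiled filtering
        "filtering_threshold": 0.5,  # min. fraction of nucleus pixels covered by a single cytosol id
        "downsampling_factor": 2,    # optional: downsample masks before matching
    }
}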
diff --git a/src/sparcscore/pipeline/filtering_workflows.py b/src/sparcscore/pipeline/filtering_workflows.py
new file mode 100644
index 0000000..dae1e95
--- /dev/null
+++ b/src/sparcscore/pipeline/filtering_workflows.py
@@ -0,0 +1,126 @@
+from sparcscore.pipeline.filter_segmentation import (
+    SegmentationFilter,
+    TiledSegmentationFilter
+)
+
+import numpy as np
+from tqdm.auto import tqdm
+import shutil
+from collections import defaultdict
+
+from sparcscore.processing.preprocessing import downsample_img_pxs
+
+class BaseFiltering(SegmentationFilter):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def get_unique_ids(self, mask):
+        return(np.unique(mask)[1:])
+
+    def return_empty_mask(self, input_image):
+        #write out an empty entry
+        self.save_classes(classes = {})
+
+class filtering_match_nucleus_to_cytosol(BaseFiltering):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def match_nucleus_id_to_cytosol(self, nucleus_mask, cytosol_mask, return_ids_to_discard = False):
+        all_nucleus_ids = self.get_unique_ids(nucleus_mask)
+        all_cytosol_ids = self.get_unique_ids(cytosol_mask)
+
+        nucleus_cytosol_pairs = {}
+        nuclei_ids_to_discard = []
+
+        for nucleus_id in tqdm(all_nucleus_ids):
+            # get the nucleus and set the background to 0 and the nucleus to 1
+            nucleus = (nucleus_mask == nucleus_id)
+
+            # now get the coordinates of the nucleus
+            nucleus_pixels = np.nonzero(nucleus)
+
+            # check if those indices are not background in the cytosol mask
+            potential_cytosol = cytosol_mask[nucleus_pixels]
+
+            #if there is a cytosol_id in the area of the nucleus proceed, else continue with a new nucleus
+            if np.all(potential_cytosol != 0):
+
+                unique_cytosol, counts = np.unique(
+                    potential_cytosol, return_counts=True
+                )
+                all_counts = np.sum(counts)
+                cytosol_proportions = counts / all_counts
+
+                if np.any(cytosol_proportions >= self.config["filtering_threshold"]):
+                    # get the cytosol_id that passes the filtering threshold
+                    cytosol_id = unique_cytosol[
+                        np.argmax(cytosol_proportions >= self.config["filtering_threshold"])
+                    ]
+                    nucleus_cytosol_pairs[nucleus_id] = cytosol_id
+                else:
+                    #no cytosol found with sufficient quality to call so discard nucleus
+                    nuclei_ids_to_discard.append(nucleus_id)
+
+            else:
+                #discard nucleus as no matching cytosol found
+                nuclei_ids_to_discard.append(nucleus_id)
+
+        #check to ensure that only one nucleus_id is assigned to each cytosol_id
+        cytosol_count = defaultdict(int)
+
+        # Count the occurrences of each cytosol value
+        for cytosol in nucleus_cytosol_pairs.values():
+            cytosol_count[cytosol] += 1
+
+        # Find cytosol values assigned to more than one nucleus and remove from dictionary
+        multi_nucleated_nulceus_ids = []
+
+        for nucleus, cytosol in nucleus_cytosol_pairs.items():
+            if cytosol_count[cytosol] > 1:
+                multi_nucleated_nulceus_ids.append(nucleus)
+
+        #update list of all discarded nuclei (use extend so the individual ids are added instead of a nested list)
+        nuclei_ids_to_discard.extend(multi_nucleated_nulceus_ids)
+
+        #remove entries from dictionary
+        # this needs to be put into a separate loop because otherwise the dictionary size changes during the loop and this throws an error
+        for nucleus in multi_nucleated_nulceus_ids:
+            del nucleus_cytosol_pairs[nucleus]
+
+        #get all cytosol_ids that need to be discarded
+        used_cytosol_ids = set(nucleus_cytosol_pairs.values())
+        not_used_cytosol_ids = set(all_cytosol_ids) - used_cytosol_ids
+        not_used_cytosol_ids = list(not_used_cytosol_ids)
+
+        if return_ids_to_discard:
+            return(nucleus_cytosol_pairs, nuclei_ids_to_discard, not_used_cytosol_ids)
+        else:
+            return(nucleus_cytosol_pairs)
+
+    def process(self, input_masks):
+
+        if type(input_masks) == str:
+            input_masks = self.read_input_masks(input_masks)
+
+        #allow for optional downsampling to improve computation time
+        if "downsampling_factor" in self.config.keys():
+            N = self.config["downsampling_factor"]
+            #use a less precise but faster downsampling method that preserves integer values
+            input_masks = downsample_img_pxs(input_masks, N = N)
+
+        #get input masks
+        nucleus_mask = input_masks[0, :, :]
+        cytosol_mask = input_masks[1, :, :]
+
+        nucleus_cytosol_pairs = self.match_nucleus_id_to_cytosol(nucleus_mask, cytosol_mask)
+
+        #save results
+        self.save_classes(classes = nucleus_cytosol_pairs)
+
+        #cleanup TEMP directories if not done during individual tile runs
+        if hasattr(self, "TEMP_DIR_NAME"):
+            shutil.rmtree(self.TEMP_DIR_NAME)
+
+class multithreaded_filtering_match_nucleus_to_cytosol(TiledSegmentationFilter):
+    method = filtering_match_nucleus_to_cytosol
\ No newline at end of file
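A toy example of the matching rule implemented in match_nucleus_id_to_cytosol, re-implemented standalone for illustration (it does not call into the class and assumes a filtering_threshold of 0.5): nucleus 1 is fully covered by cytosol 5 and is kept, nucleus 2 overlaps background in the cytosol mask and is discarded.

import numpy as np

nucleus_mask = np.array([[1, 1, 0],
                         [1, 1, 0],
                         [0, 0, 2]])
cytosol_mask = np.array([[5, 5, 5],
                         [5, 5, 5],
                         [5, 5, 0]])

pairs = {}
for nid in np.unique(nucleus_mask)[1:]:                  # skip background label 0
    covering = cytosol_mask[nucleus_mask == nid]         # cytosol ids under this nucleus
    if np.all(covering != 0):                            # nucleus must not touch cytosol background
        ids, counts = np.unique(covering, return_counts=True)
        proportions = counts / counts.sum()
        if np.any(proportions >= 0.5):                   # assumed filtering_threshold
            pairs[int(nid)] = int(ids[np.argmax(proportions >= 0.5)])

print(pairs)  # {1: 5} -- nucleus 2 is dropped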
diff --git a/src/sparcscore/pipeline/project.py b/src/sparcscore/pipeline/project.py
index e1b4556..c49f354 100644
--- a/src/sparcscore/pipeline/project.py
+++ b/src/sparcscore/pipeline/project.py
@@ -54,6 +54,7 @@ class Project(Logable):
     """
     DEFAULT_CONFIG_NAME = "config.yml"
     DEFAULT_SEGMENTATION_DIR_NAME = "segmentation"
+    DEFAULT_SEGMENTATION_FILTERING_DIR_NAME = "segmentation/filtering"
     DEFAULT_EXTRACTION_DIR_NAME = "extraction"
     DEFAULT_CLASSIFICATION_DIR_NAME = "classification"
     DEFAULT_SELECTION_DIR_NAME = "selection"
@@ -79,6 +80,7 @@ def __init__(
         debug=False,
         overwrite=False,
         segmentation_f=None,
+        segmentation_filtering_f = None,
         extraction_f=None,
         classification_f=None,
         selection_f=None,
@@ -91,6 +93,7 @@ def __init__(
         self.intermediate_output = intermediate_output
 
         self.segmentation_f = segmentation_f
+        self.segmentation_filtering_f = segmentation_filtering_f
         self.extraction_f = extraction_f
         self.classification_f = classification_f
         self.selection_f = selection_f
@@ -160,6 +163,26 @@ def __init__(
         else:
             self.segmentation_f = None
 
+        # ==== setup filtering of segmentation ====
+        if segmentation_filtering_f is not None:
+            if segmentation_filtering_f.__name__ not in self.config:
+                raise ValueError(
+                    f"Config for {segmentation_filtering_f.__name__} is missing from the config file"
+                )
+
+            filter_seg_directory = os.path.join(
+                self.project_location, self.DEFAULT_SEGMENTATION_FILTERING_DIR_NAME
+            )
+
+            self.segmentation_filtering_f = segmentation_filtering_f(
+                self.config[segmentation_filtering_f.__name__],
+                filter_seg_directory,
+                project_location = self.project_location,
+                debug=self.debug,
+                overwrite=self.overwrite,
+                intermediate_output=self.intermediate_output,
+            )
+
         # === setup extraction ===
         if extraction_f is not None:
             extraction_directory = os.path.join(
@@ -568,7 +591,6 @@ def segment(self, *args, **kwargs):
             self.log("No input image loaded. Trying to read file from disk.")
             try:
                 self.load_input_image()
-                self.segmentation_f(self.input_image, *args, **kwargs)
             except:
                 raise ValueError("No input image loaded and no file found to load image from.")
             self.segmentation_f(self.input_image, *args, **kwargs)
@@ -588,13 +610,24 @@ def complete_segmentation(self, *args, **kwargs):
             self.log("No input image loaded. Trying to read file from disk.")
             try:
                 self.load_input_image()
-                self.segmentation_f.complete_segmentation(self.input_image, *args, **kwargs)
             except:
                 raise ValueError("No input image loaded and no file found to load image from.")
             self.segmentation_f.complete_segmentation(self.input_image, *args, **kwargs)
 
         elif self.input_image is not None:
             self.segmentation_f.complete_segmentation(self.input_image, *args, **kwargs)
+
+    def filter_segmentation(self, *args, **kwargs):
+        """Execute the workflow to filter generated segmentation masks so that only cells
+        which fulfill the filtering criteria are selected for extraction.
+        """
+        self.log("Filtering generated segmentation masks for cells that fulfill the required criteria")
+
+        if self.segmentation_filtering_f is None:
+            raise ValueError("No filtering method for refining segmentation masks defined.")
+
+        input_segmentation = self.segmentation_f.get_output()
+        self.segmentation_filtering_f(input_segmentation, *args, **kwargs)
 
     def extract(self, *args, **kwargs):
         """
diff --git a/src/sparcscore/pipeline/segmentation.py b/src/sparcscore/pipeline/segmentation.py
index 8df23e1..ebe06ab 100644
--- a/src/sparcscore/pipeline/segmentation.py
+++ b/src/sparcscore/pipeline/segmentation.py
@@ -37,6 +37,7 @@ class Segmentation(ProcessingStep):
 
         DEFAULT_OUTPUT_FILE (str, default ``segmentation.h5``)
         DEFAULT_FILTER_FILE (str, default ``classes.csv``)
+        DEFAULT_FILTER_ADDTIONAL_FILE (str, default ``needs_additional_filtering.txt``)
         PRINT_MAPS_ON_DEBUG (bool, default ``False``)
 
         identifier (int, default ``None``): Only set if called by :class:`ShardedSegmentation`. Unique index of the shard.
@@ -65,6 +66,7 @@ def process(self):
     """
    DEFAULT_OUTPUT_FILE = "segmentation.h5"
    DEFAULT_FILTER_FILE = "classes.csv"
+   DEFAULT_FILTER_ADDTIONAL_FILE = "needs_additional_filtering.txt"
    PRINT_MAPS_ON_DEBUG = True
 
    DEFAULT_INPUT_IMAGE_NAME = "input_image.ome.zarr"
@@ -223,7 +225,21 @@ def save_segmentation(self, channels, labels, classes):
 
         # save classes
         self.save_classes(classes)
 
+        #check filter status in config
+        if "filter_status" in self.config.keys():
+            filter_status = self.config["filter_status"]
+        else:
+            filter_status = True #always assumes that filtering is performed by default. Needs to be manually turned off if not desired.
+
+        if not filter_status:
+            #define path where the empty file should be generated
+            filtered_path = os.path.join(self.directory, self.DEFAULT_FILTER_ADDTIONAL_FILE)
+            with open(filtered_path, "w") as myfile:
+                myfile.write("\n")
+
+            self.log(f"Generated empty file at {filtered_path}. This marks that no filtering has been performed during segmentation and an additional step needs to be performed to populate this file with nucleus_id:cytosol_id matchings before running the extraction.")
+
         self.log("=== finished segmentation ===")
         self.save_segmentation_zarr(labels = labels)
@@ -993,6 +1009,11 @@ def initializer_function(gpu_id_list):
 
         #make sure to cleanup temp directories
         self.log("=== completed segmentation === ")
+
+#############################################
+###### TIMECOURSE/BATCHED METHODS ###########
+#############################################
+
 class TimecourseSegmentation(Segmentation):
     """Segmentation helper class used for creating segmentation workflows working with timecourse data."""
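A minimal sketch of how the new filtering step could be wired into a project, assuming the existing Project(project_location, config_path=...) call signature and a config.yml containing a section for the chosen filtering workflow; the segmentation and extraction workflow classes and all paths are placeholders, not part of this diff.

from sparcscore.pipeline.project import Project
from sparcscore.pipeline.filtering_workflows import multithreaded_filtering_match_nucleus_to_cytosol
# segmentation/extraction workflows below are placeholders for whatever the project normally uses
from sparcscore.pipeline.workflows import CytosolSegmentationCellpose   # placeholder
from sparcscore.pipeline.extraction import HDF5CellExtraction           # placeholder

project = Project(
    "path/to/project",                                                   # placeholder project location
    config_path="path/to/config.yml",
    segmentation_f=CytosolSegmentationCellpose,
    segmentation_filtering_f=multithreaded_filtering_match_nucleus_to_cytosol,
    extraction_f=HDF5CellExtraction,
)

project.segment()               # writes segmentation.h5 and classes.csv
project.filter_segmentation()   # writes nucleus_id:cytosol_id pairs to filtered_classes.csv
project.extract()               # picks up the filtered class list when the needs-filtering marker file is present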