Skip to content

Commit

Permalink
add framework to implement methods to filter a segmentation mask to on…
Browse files Browse the repository at this point in the history
…ly include cells that pass certain filtering thresholds

This framework currently does not amend the segmentation masks; instead, it outputs a new classes file containing the nucleus_id:cytosol_id pairs of all cells that pass filtering. This file is read in during extraction and used to determine which cells are extracted from the dataset.
  • Loading branch information
sophiamaedler committed Jan 5, 2024
1 parent fbbc756 commit cd107b1
Show file tree
Hide file tree
Showing 5 changed files with 554 additions and 7 deletions.
31 changes: 26 additions & 5 deletions src/sparcscore/pipeline/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,16 @@ def __init__(self,
base_directory = self.directory.replace("/extraction", "")

self.input_segmentation_path = os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, self.DEFAULT_SEGMENTATION_FILE)
self.filtered_classes_path = os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, "classes.csv")

#get path to filtered classes
if os.path.isfile(os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, "needs_filtering.txt")):
try:
self.filtered_classes_path = os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, "filtered/filtered_classes.csv")
except:
raise ValueError("Need to run segmentation_filtering method ")
else:
self.filtered_classes_path = os.path.join(base_directory, self.DEFAULT_SEGMENTATION_DIR, "classes.csv")

self.output_path = os.path.join(self.directory, self.DEFAULT_DATA_DIR, self.DEFAULT_DATA_FILE)

#extract required information for generating datasets
Expand Down Expand Up @@ -163,15 +172,19 @@ def parse_remapping(self):
def get_classes(self, filtered_classes_path):
    """Load the ids of the cells that should be extracted from a classes csv file.

    Only the first column of every row is read. Two file flavours are handled:

    * a path containing ``"filtered_classes.csv"``: each entry is kept as a
      string of the form ``"nucleus_id:cytosol_id"``;
    * any other path: each entry is parsed as a number and converted to an
      integer; the background id ``0`` is dropped.

    Duplicates are removed (the result is sorted by ``np.unique``) and
    ``self.num_classes`` is set to the number of remaining entries.

    Args:
        filtered_classes_path (str): path of the csv file to read.

    Returns:
        list: unique class ids; plain ``str`` objects for the paired format,
        numpy integers otherwise.
    """
    self.log(f"Loading filtered classes from {filtered_classes_path}")

    # the file name decides how the entries are interpreted
    paired_ids = "filtered_classes.csv" in filtered_classes_path

    # read everything inside a context manager so the file handle is closed;
    # blank rows are skipped instead of failing on row[0]
    with open(filtered_classes_path, "r", newline="") as csv_file:
        rows = [row for row in csv.reader(csv_file) if row]

    if paired_ids:
        # do not do int transform here as we expect a str of format "nucleus_id:cytosol_id"
        filtered_classes = [row[0] for row in rows]
    else:
        filtered_classes = [int(float(row[0])) for row in rows]

    self.log("Loaded {} filtered classes".format(len(filtered_classes)))

    # make sure they are all unique
    # NOTE: the former ``filtered_classes.astype(np.uint64)`` call discarded its
    # result and raised a ValueError for "nucleus_id:cytosol_id" strings, so it
    # has been dropped.
    unique_classes = np.unique(filtered_classes)
    self.log("After removing duplicates {} filtered classes remain.".format(len(unique_classes)))

    if paired_ids:
        # np.unique yields np.str_ elements; convert to plain str because the
        # extraction step tests ``type(cell_id) == str`` before splitting the id
        class_list = [str(el) for el in unique_classes]
    else:
        # remove background if still listed
        class_list = [el for el in unique_classes if el != 0]

    self.num_classes = len(class_list)

    return class_list
Expand Down Expand Up @@ -290,6 +303,14 @@ def _extract_classes(self, input_segmentation_path, px_center, arg):

index, save_index, cell_id, image_index, label_info = self._get_label_info(arg) #label_info not used in base case but relevant for flexibility for other classes

if type(cell_id) == str:
nucleus_id, cytosol_id = cell_id.split(":")
nucleus_id = int(float(nucleus_id)) #convert to int for further processing
cytosol_id = int(float(cytosol_id)) #convert to int for further processing
else:
nucleus_id = cell_id
cytosol_id = cell_id

#generate some progress output every 10000 cells
#relevant for benchmarking of time
if save_index % 10000 == 0:
Expand Down Expand Up @@ -321,7 +342,7 @@ def _extract_classes(self, input_segmentation_path, px_center, arg):
else:
nuclei_mask = hdf_labels[image_index, 0, window_y, window_x]

nuclei_mask = np.where(nuclei_mask == cell_id, 1, 0)
nuclei_mask = np.where(nuclei_mask == nucleus_id, 1, 0)

nuclei_mask_extended = gaussian(nuclei_mask, preserve_range=True, sigma=5)
nuclei_mask = gaussian(nuclei_mask, preserve_range=True, sigma=1)
Expand All @@ -344,7 +365,7 @@ def _extract_classes(self, input_segmentation_path, px_center, arg):
else:
cell_mask = hdf_labels[image_index, 1,window_y,window_x]

cell_mask = np.where(cell_mask == cell_id, 1, 0).astype(int)
cell_mask = np.where(cell_mask == cytosol_id, 1, 0).astype(int)
cell_mask = binary_fill_holes(cell_mask)

cell_mask_extended = dilation(cell_mask, footprint=disk(6))
Expand Down
Loading

0 comments on commit cd107b1

Please sign in to comment.