From e7b441b0f0af782ea1c6339d77061d321ef76dbb Mon Sep 17 00:00:00 2001 From: Mehrtash Babadi Date: Thu, 12 Sep 2024 17:24:35 +0000 Subject: [PATCH] expose root node --- cellarium/cas/postprocessing/ontology_aware.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/cellarium/cas/postprocessing/ontology_aware.py b/cellarium/cas/postprocessing/ontology_aware.py index 4f60e7b..c809ab1 100644 --- a/cellarium/cas/postprocessing/ontology_aware.py +++ b/cellarium/cas/postprocessing/ontology_aware.py @@ -300,10 +300,14 @@ def _get_subtree_phyloxml_string(subtree_dict: OrderedDict, node_name: str, leve def get_most_granular_top_k_calls( - aggregated_scores: AggregatedCellOntologyScores, cl: CellOntologyCache, min_acceptable_score: float, top_k: int = 1 + aggregated_scores: AggregatedCellOntologyScores, + cl: CellOntologyCache, + min_acceptable_score: float, + top_k: int = 1, + root_note: str = CL_EUKARYOTIC_CELL_ROOT_NODE ) -> t.List[tuple]: depth_list = list( - map(cl.get_longest_path_lengths_from_target(CL_EUKARYOTIC_CELL_ROOT_NODE).get, aggregated_scores.cl_names) + map(cl.get_longest_path_lengths_from_target(root_note).get, aggregated_scores.cl_names) ) sorted_score_and_depth_list = sorted( list( @@ -319,7 +323,7 @@ def get_most_granular_top_k_calls( trunc_list = sorted_score_and_depth_list[:top_k] # pad with root node if necessary for _ in range(len(trunc_list) - top_k): - trunc_list.append((1.0, 0, CL_EUKARYOTIC_CELL_ROOT_NODE)) + trunc_list.append((1.0, 0, root_note)) return trunc_list @@ -329,6 +333,7 @@ def compute_most_granular_top_k_calls_single( min_acceptable_score: float, top_k: int = 3, obs_prefix: str = "cas_cell_type", + root_note: str = CL_EUKARYOTIC_CELL_ROOT_NODE ): top_k_calls_dict = defaultdict(list) scores_array_nc = adata.obsm[CAS_CL_SCORES_ANNDATA_OBSM_KEY].toarray() @@ -358,7 +363,7 @@ def compute_most_granular_top_k_calls_single( for i_cell in range(adata.n_obs): aggregated_scores.aggregated_scores_c = scores_array_nc[i_cell] - top_k_output = get_most_granular_top_k_calls(aggregated_scores, cl, min_acceptable_score, top_k) + top_k_output = get_most_granular_top_k_calls(aggregated_scores, cl, min_acceptable_score, top_k, root_note) for k in range(top_k): top_k_calls_dict[f"{obs_prefix}_score_{k + 1}"].append(top_k_output[k][0]) top_k_calls_dict[f"{obs_prefix}_name_{k + 1}"].append(top_k_output[k][2]) @@ -378,6 +383,7 @@ def compute_most_granular_top_k_calls_cluster( aggregation_score_threshod: float = 1e-4, top_k: int = 3, obs_prefix: str = "cas_cell_type", + root_note: str = CL_EUKARYOTIC_CELL_ROOT_NODE ): top_k_calls_dict = dict() for k in range(top_k): @@ -394,7 +400,7 @@ def _update_list(target_list, indices, value): aggregated_scores = get_aggregated_cas_ontology_aware_scores( adata, obs_indices, aggregation_op, aggregation_domain, aggregation_score_threshod ) - top_k_output = get_most_granular_top_k_calls(aggregated_scores, cl, min_acceptable_score, top_k) + top_k_output = get_most_granular_top_k_calls(aggregated_scores, cl, min_acceptable_score, top_k, root_note) for k in range(top_k): _update_list(top_k_calls_dict[f"{obs_prefix}_score_{k + 1}"], obs_indices, top_k_output[k][0]) _update_list(top_k_calls_dict[f"{obs_prefix}_name_{k + 1}"], obs_indices, top_k_output[k][2])