Skip to content

Commit

Permalink
expose root node
Browse files Browse the repository at this point in the history
  • Loading branch information
mbabadi committed Sep 12, 2024
1 parent 72f500f commit e7b441b
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions cellarium/cas/postprocessing/ontology_aware.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,14 @@ def _get_subtree_phyloxml_string(subtree_dict: OrderedDict, node_name: str, leve


def get_most_granular_top_k_calls(
aggregated_scores: AggregatedCellOntologyScores, cl: CellOntologyCache, min_acceptable_score: float, top_k: int = 1
aggregated_scores: AggregatedCellOntologyScores,
cl: CellOntologyCache,
min_acceptable_score: float,
top_k: int = 1,
root_note: str = CL_EUKARYOTIC_CELL_ROOT_NODE
) -> t.List[tuple]:
depth_list = list(
map(cl.get_longest_path_lengths_from_target(CL_EUKARYOTIC_CELL_ROOT_NODE).get, aggregated_scores.cl_names)
map(cl.get_longest_path_lengths_from_target(root_note).get, aggregated_scores.cl_names)
)
sorted_score_and_depth_list = sorted(
list(
Expand All @@ -319,7 +323,7 @@ def get_most_granular_top_k_calls(
trunc_list = sorted_score_and_depth_list[:top_k]
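# pad with root node if necessary
# NOTE(review): trunc_list is a [:top_k] slice, so len(trunc_list) - top_k is never positive and this loop never pads; the intended count is presumably top_k - len(trunc_list) -- confirm.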
for _ in range(len(trunc_list) - top_k):
trunc_list.append((1.0, 0, CL_EUKARYOTIC_CELL_ROOT_NODE))
trunc_list.append((1.0, 0, root_note))
return trunc_list


Expand All @@ -329,6 +333,7 @@ def compute_most_granular_top_k_calls_single(
min_acceptable_score: float,
top_k: int = 3,
obs_prefix: str = "cas_cell_type",
root_note: str = CL_EUKARYOTIC_CELL_ROOT_NODE
):
top_k_calls_dict = defaultdict(list)
scores_array_nc = adata.obsm[CAS_CL_SCORES_ANNDATA_OBSM_KEY].toarray()
Expand Down Expand Up @@ -358,7 +363,7 @@ def compute_most_granular_top_k_calls_single(

for i_cell in range(adata.n_obs):
aggregated_scores.aggregated_scores_c = scores_array_nc[i_cell]
top_k_output = get_most_granular_top_k_calls(aggregated_scores, cl, min_acceptable_score, top_k)
top_k_output = get_most_granular_top_k_calls(aggregated_scores, cl, min_acceptable_score, top_k, root_note)
for k in range(top_k):
top_k_calls_dict[f"{obs_prefix}_score_{k + 1}"].append(top_k_output[k][0])
top_k_calls_dict[f"{obs_prefix}_name_{k + 1}"].append(top_k_output[k][2])
Expand All @@ -378,6 +383,7 @@ def compute_most_granular_top_k_calls_cluster(
aggregation_score_threshod: float = 1e-4,
top_k: int = 3,
obs_prefix: str = "cas_cell_type",
root_note: str = CL_EUKARYOTIC_CELL_ROOT_NODE
):
top_k_calls_dict = dict()
for k in range(top_k):
Expand All @@ -394,7 +400,7 @@ def _update_list(target_list, indices, value):
aggregated_scores = get_aggregated_cas_ontology_aware_scores(
adata, obs_indices, aggregation_op, aggregation_domain, aggregation_score_threshod
)
top_k_output = get_most_granular_top_k_calls(aggregated_scores, cl, min_acceptable_score, top_k)
top_k_output = get_most_granular_top_k_calls(aggregated_scores, cl, min_acceptable_score, top_k, root_note)
for k in range(top_k):
_update_list(top_k_calls_dict[f"{obs_prefix}_score_{k + 1}"], obs_indices, top_k_output[k][0])
_update_list(top_k_calls_dict[f"{obs_prefix}_name_{k + 1}"], obs_indices, top_k_output[k][2])
Expand Down

0 comments on commit e7b441b

Please sign in to comment.