Skip to content

Commit

Permalink
Merge pull request #85 from cellarium-ai/mb-expose-root-node
Browse files Browse the repository at this point in the history
Exposing root node in visualization app and logging fixes
  • Loading branch information
mbabadi authored Oct 7, 2024
2 parents f1cc5a3 + 148bb01 commit dda87f2
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 54 deletions.
3 changes: 2 additions & 1 deletion cellarium/cas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from cellarium.cas.client import CASClient

from . import constants, exceptions, postprocessing, preprocessing, service, settings, version, visualization
from . import constants, exceptions, logging, postprocessing, preprocessing, service, settings, version, visualization

__version__ = version.get_version()

Expand All @@ -15,4 +15,5 @@
"service",
"settings",
"version",
"logging",
]
3 changes: 2 additions & 1 deletion cellarium/cas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from anndata import ImplicitModificationWarning
from deprecated import deprecated

from cellarium.cas.logging import logger
from cellarium.cas.service import action_context_manager

from . import _io, constants, exceptions, models, preprocessing, service, settings, version
Expand Down Expand Up @@ -70,7 +71,7 @@ def __init__(
api_token=api_token, api_url=api_url, client_session_id=self.client_session_id
)

self.__print(f"Connecting to the Cellarium Cloud backend with session {self.client_session_id}...")
logger.info(f"Connecting to the Cellarium Cloud backend with session {self.client_session_id}...")
self.user_info = self.cas_api_service.validate_token()
username = self.user_info["username"]
self.__print(f"User is {username}")
Expand Down
17 changes: 17 additions & 0 deletions cellarium/cas/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import logging

from cellarium.cas import settings

# Create a custom logger for the package
logger = logging.getLogger("cellarium.cas")
logger.setLevel(settings.LOGGING_LEVEL)

# Create a handler
handler = logging.StreamHandler()

# Create a formatter and set it for the handler
formatter = logging.Formatter(fmt=settings.LOGGING_FORMAT, datefmt=settings.LOGGING_DATE_FORMAT)
handler.setFormatter(formatter)

# Add the handler to the logger
logger.addHandler(handler)
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import logging
import typing as t
from functools import lru_cache
from logging import log

import networkx as nx
import owlready2
from scipy import sparse as sp

from cellarium.cas import settings # noqa
from cellarium.cas._io import suppress_stderr
from cellarium.cas.logging import logger

# Used in CZ CELLxGENE schema v5:
# https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/5.0.0/schema.md
Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(self, cl_owl_path: str = DEFAULT_CL_OWL_PATH):
"""

with suppress_stderr():
log(logging.INFO, f"Loading cell ontology OWL from:\n{cl_owl_path}")
logger.info(f"Loading cell ontology OWL from: {cl_owl_path}")
cl = owlready2.get_ontology(cl_owl_path).load()

# only keep CL classes with a singleton label
Expand Down
36 changes: 17 additions & 19 deletions cellarium/cas/postprocessing/ontology_aware.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from cellarium.cas.models import CellTypeOntologyAwareResults

from .cell_ontology.cell_ontology_cache import CL_CELL_ROOT_NODE, CL_EUKARYOTIC_CELL_ROOT_NODE, CellOntologyCache
from .cell_ontology.cell_ontology_cache import CL_CELL_ROOT_NODE, CellOntologyCache
from .common import get_obs_indices_for_cluster

# AnnData-related constants
Expand Down Expand Up @@ -63,7 +63,9 @@ def convert_cas_ontology_aware_response_to_score_matrix(


def insert_cas_ontology_aware_response_into_adata(
cas_ontology_aware_response: CellTypeOntologyAwareResults, adata: AnnData, cl: CellOntologyCache
cas_ontology_aware_response: CellTypeOntologyAwareResults,
adata: AnnData,
cl: CellOntologyCache = CellOntologyCache(),
) -> None:
"""
Inserts Cellarium CAS ontology aware response into `obsm` property of a provided AnnData file as a
Expand Down Expand Up @@ -180,7 +182,7 @@ def get_aggregated_cas_ontology_aware_scores(
def convert_aggregated_cell_ontology_scores_to_rooted_tree(
aggregated_scores: AggregatedCellOntologyScores,
cl: CellOntologyCache,
root_cl_name: str = CL_EUKARYOTIC_CELL_ROOT_NODE,
root_cl_name: str = CL_CELL_ROOT_NODE,
min_fraction: float = 0.0,
hidden_cl_names_set: t.Optional[t.Set[str]] = None,
) -> OrderedDict:
Expand Down Expand Up @@ -232,14 +234,6 @@ def build_subtree(node_dict: OrderedDict, node_name: str) -> OrderedDict:
# Validate that this is actually a rooted tree and if not recalculate with the base cell node
if len(tree_dict) == 1: # singly-rooted tree
return tree_dict
elif root_cl_name != CL_CELL_ROOT_NODE:
return convert_aggregated_cell_ontology_scores_to_rooted_tree(
aggregated_scores=aggregated_scores,
cl=cl,
root_cl_name=CL_CELL_ROOT_NODE,
min_fraction=min_fraction,
hidden_cl_names_set=hidden_cl_names_set,
)
else:
raise ValueError("The tree is not singly-rooted.")

Expand Down Expand Up @@ -300,11 +294,13 @@ def _get_subtree_phyloxml_string(subtree_dict: OrderedDict, node_name: str, leve


def get_most_granular_top_k_calls(
aggregated_scores: AggregatedCellOntologyScores, cl: CellOntologyCache, min_acceptable_score: float, top_k: int = 1
aggregated_scores: AggregatedCellOntologyScores,
cl: CellOntologyCache,
min_acceptable_score: float,
top_k: int = 1,
root_note: str = CL_CELL_ROOT_NODE,
) -> t.List[tuple]:
depth_list = list(
map(cl.get_longest_path_lengths_from_target(CL_EUKARYOTIC_CELL_ROOT_NODE).get, aggregated_scores.cl_names)
)
depth_list = list(map(cl.get_longest_path_lengths_from_target(root_note).get, aggregated_scores.cl_names))
sorted_score_and_depth_list = sorted(
list(
(score, depth, cl_name)
Expand All @@ -318,8 +314,8 @@ def get_most_granular_top_k_calls(
)
trunc_list = sorted_score_and_depth_list[:top_k]
# pad with root node if necessary
for _ in range(len(trunc_list) - top_k):
trunc_list.append((1.0, 0, CL_EUKARYOTIC_CELL_ROOT_NODE))
for _ in range(top_k - len(trunc_list)):
trunc_list.append((1.0, 0, root_note))
return trunc_list


Expand All @@ -329,6 +325,7 @@ def compute_most_granular_top_k_calls_single(
min_acceptable_score: float,
top_k: int = 3,
obs_prefix: str = "cas_cell_type",
root_note: str = CL_CELL_ROOT_NODE,
):
top_k_calls_dict = defaultdict(list)
scores_array_nc = adata.obsm[CAS_CL_SCORES_ANNDATA_OBSM_KEY].toarray()
Expand Down Expand Up @@ -358,7 +355,7 @@ def compute_most_granular_top_k_calls_single(

for i_cell in range(adata.n_obs):
aggregated_scores.aggregated_scores_c = scores_array_nc[i_cell]
top_k_output = get_most_granular_top_k_calls(aggregated_scores, cl, min_acceptable_score, top_k)
top_k_output = get_most_granular_top_k_calls(aggregated_scores, cl, min_acceptable_score, top_k, root_note)
for k in range(top_k):
top_k_calls_dict[f"{obs_prefix}_score_{k + 1}"].append(top_k_output[k][0])
top_k_calls_dict[f"{obs_prefix}_name_{k + 1}"].append(top_k_output[k][2])
Expand All @@ -378,6 +375,7 @@ def compute_most_granular_top_k_calls_cluster(
aggregation_score_threshod: float = 1e-4,
top_k: int = 3,
obs_prefix: str = "cas_cell_type",
root_note: str = CL_CELL_ROOT_NODE,
):
top_k_calls_dict = dict()
for k in range(top_k):
Expand All @@ -394,7 +392,7 @@ def _update_list(target_list, indices, value):
aggregated_scores = get_aggregated_cas_ontology_aware_scores(
adata, obs_indices, aggregation_op, aggregation_domain, aggregation_score_threshod
)
top_k_output = get_most_granular_top_k_calls(aggregated_scores, cl, min_acceptable_score, top_k)
top_k_output = get_most_granular_top_k_calls(aggregated_scores, cl, min_acceptable_score, top_k, root_note)
for k in range(top_k):
_update_list(top_k_calls_dict[f"{obs_prefix}_score_{k + 1}"], obs_indices, top_k_output[k][0])
_update_list(top_k_calls_dict[f"{obs_prefix}_name_{k + 1}"], obs_indices, top_k_output[k][2])
Expand Down
5 changes: 3 additions & 2 deletions cellarium/cas/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from aiohttp import client_exceptions

from cellarium.cas import constants, endpoints, exceptions, settings
from cellarium.cas.logging import logger

if settings.is_interactive_environment():
print("Running in an interactive environment, applying nest_asyncio")
logger.debug("Running in an interactive environment, applying nest_asyncio")
nest_asyncio.apply()

# Context variable to track the action id for the current context
Expand Down Expand Up @@ -210,7 +211,7 @@ async def _aiohttp_async_post(
except (json.decoder.JSONDecodeError, client_exceptions.ClientResponseError):
response_detail = await response.text()
except KeyError:
print("Response body doesn't have a 'detail' key, returning full response body")
logger.warning("Response body doesn't have a 'detail' key, returning full response body")
response_detail = str(await response.json())

self.raise_response_exception(status_code=status_code, detail=response_detail)
Expand Down
5 changes: 5 additions & 0 deletions cellarium/cas/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

NUM_ATTEMPTS_PER_CHUNK_DEFAULT = 7
MAX_NUM_REQUESTS_AT_A_TIME = 8
START_RETRY_DELAY = 5
Expand All @@ -6,6 +8,9 @@
AIOHTTP_READ_TIMEOUT_SECONDS = 730
MAX_CHUNK_SIZE_SEARCH_METHOD = 500
CELLARIUM_CLOUD_BACKEND_URL = "https://cellarium-cloud-api.cellarium.ai"
LOGGING_LEVEL = logging.INFO
LOGGING_FORMAT = "* [%(asctime)s.%(msecs)03d] %(message)s"
LOGGING_DATE_FORMAT = "%H:%M:%S"


def is_interactive_environment() -> bool:
Expand Down
4 changes: 2 additions & 2 deletions cellarium/cas/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
from cellarium.cas.logging import logger

try:
from .circular_tree_plot_umap_dash_app.app import CASCircularTreePlotUMAPDashApp # noqa
from .ui_utils import find_and_kill_process # noqa

__all__ = ["CASCircularTreePlotUMAPDashApp", "find_and_kill_process"]
except ImportError:
logging.warn(
logger.warning(
"""
Visualization dependencies not installed.
To install the Cellarium CAS Client with visualation dependencies, please run:
Expand Down
33 changes: 18 additions & 15 deletions cellarium/cas/visualization/circular_tree_plot_umap_dash_app/app.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import logging
import os
import tempfile
import typing as t
from collections import OrderedDict
from logging import log

import dash_bootstrap_components as dbc
import numpy as np
Expand All @@ -15,18 +13,18 @@
from dash.development.base_component import Component
from plotly.express.colors import sample_colorscale

from cellarium.cas.models import CellTypeOntologyAwareResults
from cellarium.cas.logging import logger
from cellarium.cas.postprocessing import (
CAS_CL_SCORES_ANNDATA_OBSM_KEY,
CAS_METADATA_ANNDATA_UNS_KEY,
CellOntologyScoresAggregationDomain,
CellOntologyScoresAggregationOp,
convert_aggregated_cell_ontology_scores_to_rooted_tree,
generate_phyloxml_from_scored_cell_ontology_tree,
get_aggregated_cas_ontology_aware_scores,
get_obs_indices_for_cluster,
insert_cas_ontology_aware_response_into_adata,
)
from cellarium.cas.postprocessing.cell_ontology import CL_EUKARYOTIC_CELL_ROOT_NODE, CellOntologyCache
from cellarium.cas.postprocessing.cell_ontology import CL_CELL_ROOT_NODE, CellOntologyCache
from cellarium.cas.visualization._components.circular_tree_plot import CircularTreePlot
from cellarium.cas.visualization.ui_utils import ConfigValue, find_and_kill_process

Expand Down Expand Up @@ -176,7 +174,6 @@ class CASCircularTreePlotUMAPDashApp:
def __init__(
self,
adata: AnnData,
cas_ontology_aware_response: CellTypeOntologyAwareResults,
cluster_label_obs_column: t.Optional[str] = None,
aggregation_op: CellOntologyScoresAggregationOp = CellOntologyScoresAggregationOp.MEAN,
aggregation_domain: CellOntologyScoresAggregationDomain = CellOntologyScoresAggregationDomain.OVER_THRESHOLD,
Expand All @@ -195,6 +192,7 @@ def __init__(
circular_tree_start_angle: int = 180,
circular_tree_end_angle: int = 360,
figure_height: int = 400,
root_node: str = CL_CELL_ROOT_NODE,
hidden_cl_names_set: set[str] = DEFAULT_HIDDEN_CL_NAMES_SET,
shown_cl_names_set: set[str] = DEFAULT_SHOWN_CL_NAMES_SET,
score_colorscale: t.Union[str, list] = "Viridis",
Expand All @@ -217,11 +215,19 @@ def __init__(
self.circular_tree_start_angle = circular_tree_start_angle
self.circular_tree_end_angle = circular_tree_end_angle
self.height = figure_height
self.root_node = root_node
self.hidden_cl_names_set = hidden_cl_names_set
self.shown_cl_names_set = shown_cl_names_set
self.score_colorscale = score_colorscale

assert "X_umap" in adata.obsm, "UMAP coordinates not found in adata.obsm['X_umap']"
assert "X_umap" in adata.obsm, (
"UMAP coordinates not found in adata.obsm['X_umap']. "
"This visualisation requires precomputed UMAP coordinates."
)
assert (CAS_CL_SCORES_ANNDATA_OBSM_KEY in adata.obsm) and (CAS_METADATA_ANNDATA_UNS_KEY in adata.uns), (
"Cell type ontology scores not found in the provided AnnData file. Please please run "
"`cellarium.cas.insert_cas_ontology_aware_response_into_adata` prior to running this visualisation."
)

# setup cell domains
self.cell_domain_map = OrderedDict()
Expand All @@ -243,29 +249,26 @@ def __init__(
# instantiate the cell type ontology cache
self.cl = CellOntologyCache()

# insert CA ontology-aware response into adata
insert_cas_ontology_aware_response_into_adata(cas_ontology_aware_response, adata, self.cl)

# instantiate the Dash app
self.app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.BOOTSTRAP])
self.server = self.app.server
self.app.layout = self.__create_layout()
self.__setup_initialization()
self.__setup_callbacks()

def run(self, port: int = 8050, **kwargs):
def run(self, port: int = 8050, jupyter_mode: str = "inline", **kwargs):
"""
Run the Dash application on the specified port.
:param port: The port on which to run the Dash application. |br|
`Default:` ``8050``
"""
log(logging.INFO, "Starting Dash application on port {port}...")
logger.info(f"Starting Dash application on port {port}...")
try:
self.app.run_server(port=port, jupyter_mode="inline", jupyter_height=self.height + 100, **kwargs)
self.app.run_server(port=port, jupyter_mode=jupyter_mode, jupyter_height=self.height + 100, **kwargs)
except OSError: # Dash raises OSError if the port is already in use
find_and_kill_process(port)
self.app.run_server(port=port, jupyter_mode="inline", jupyter_height=self.height + 100, **kwargs)
self.app.run_server(port=port, jupyter_mode=jupyter_mode, jupyter_height=self.height + 100, **kwargs)

def __instantiate_circular_tree_plot(self) -> CircularTreePlot:
# reduce scores over the provided cells
Expand All @@ -284,7 +287,7 @@ def __instantiate_circular_tree_plot(self) -> CircularTreePlot:
rooted_tree = convert_aggregated_cell_ontology_scores_to_rooted_tree(
aggregated_scores=aggregated_scores,
cl=self.cl,
root_cl_name=CL_EUKARYOTIC_CELL_ROOT_NODE,
root_cl_name=self.root_node,
min_fraction=self.min_cell_fraction.get(),
hidden_cl_names_set=self.hidden_cl_names_set,
)
Expand Down
Loading

0 comments on commit dda87f2

Please sign in to comment.