Skip to content

Commit

Permalink
update documentation, simplified model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
tony-kuo committed Feb 12, 2024
1 parent 7cf3b99 commit f3753a8
Show file tree
Hide file tree
Showing 14 changed files with 2,624 additions and 1,112 deletions.
3 changes: 1 addition & 2 deletions src/scimilarity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,4 @@
from .cell_annotation import CellAnnotation
from .cell_query import CellQuery
from .interpreter import Interpreter
from .utils import align_dataset
from .zarr_dataset import ZarrDataset
from .utils import align_dataset, lognorm_counts
186 changes: 85 additions & 101 deletions src/scimilarity/cell_annotation.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,6 @@
import json
import operator
import os
import time
from collections import defaultdict
from typing import Optional, Union, List, Set, Tuple

import anndata
import hnswlib
import numpy as np
import pandas as pd
import pegasusio as pgio
from tqdm import tqdm

from scimilarity.cell_embedding import CellEmbedding
from scimilarity.ontologies import import_cell_ontology, get_id_mapper
from scimilarity.utils import check_dataset, lognorm_counts, align_dataset
from scimilarity.zarr_dataset import ZarrDataset


class CellAnnotation(CellEmbedding):
Expand Down Expand Up @@ -49,6 +34,8 @@ def __init__(
>>> ca = CellAnnotation(model_path="/opt/data/model")
"""

import os

super().__init__(
model_path=model_path,
use_gpu=use_gpu,
Expand All @@ -60,11 +47,13 @@ def __init__(
if filenames is None:
filenames = {}

self.annotation_path = os.path.join(model_path, "annotation")
self.filenames["knn"] = os.path.join(
model_path, filenames.get("knn", "labelled_kNN.bin")
self.annotation_path, filenames.get("knn", "labelled_kNN.bin")
)
self.filenames["celltype_labels"] = os.path.join(
model_path, filenames.get("celltype_labels", "reference_labels.tsv")
self.annotation_path,
filenames.get("celltype_labels", "reference_labels.tsv"),
)

# get knn
Expand All @@ -74,12 +63,13 @@ def __init__(
with open(self.filenames["celltype_labels"], "r") as fh:
self.idx2label = {i: line.strip() for i, line in enumerate(fh)}

self.classes = set(self.label2int.keys())
self.safelist = None
self.blocklist = None

def build_kNN(
self,
input_data: Union[anndata.AnnData, pgio.MultimodalData, pgio.UnimodalData, str],
input_data: Union["anndata.AnnData", str],
knn_filename: str = "labelled_kNN.bin",
celltype_labels_filename: str = "reference_labels.tsv",
obs_field: str = "celltype_name",
Expand All @@ -91,15 +81,15 @@ def build_kNN(
Parameters
----------
input_data: Union[anndata.AnnData, pegasusio.MultimodalData, pegasusio.UnimodalData, str],
input_data: Union[anndata.AnnData, str],
If a string, the filename of h5ad data file or directory containing zarr stores.
Otherwise, the annotated data matrix with rows for cells and columns for genes.
knn_filename: str, default: "labelled_kNN.bin"
Filename of the kNN index.
celltype_labels_filename: str, default: "reference_labels.tsv"
Filename of the cell type reference labels.
obs_field: str, default: "celltype_name"
The obs key name of celltype labels.
The obs column name of celltype labels.
ef_construction: int, default: 1000
The size of the dynamic list for the nearest neighbors.
See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md
Expand All @@ -114,6 +104,15 @@ def build_kNN(
>>> ca.build_knn(filename="/opt/data/train/train.h5ad")
"""

import anndata
import hnswlib
import numpy as np
import os
import pandas as pd
from scimilarity.utils import align_dataset
from scimilarity.zarr_dataset import ZarrDataset
from tqdm import tqdm

if isinstance(input_data, str) and os.path.isdir(input_data):
data_list = [
f for f in os.listdir(input_data) if f.endswith(".aligned.zarr")
Expand All @@ -125,7 +124,7 @@ def build_kNN(
obs = pd.DataFrame({obs_field: dataset.get_obs(obs_field)})
obs.index = obs.index.astype(str)
data = anndata.AnnData(
X=dataset.X_copy,
X=dataset.get_X(in_mem=True),
obs=obs,
var=pd.DataFrame(index=dataset.var_index),
dtype=np.float32,
Expand All @@ -143,7 +142,7 @@ def build_kNN(
embeddings = np.concatenate(embeddings_list)
else:
if isinstance(input_data, str) and os.path.isfile(input_data):
data = pgio.read_input(input_data)
data = anndata.read_h5ad(input_data)
else:
data = input_data

Expand All @@ -170,14 +169,14 @@ def build_kNN(
self.knn.set_ef(ef_construction)
self.knn.add_items(embeddings, range(len(embeddings)))

knn_fullpath = os.path.join(self.model_path, knn_filename)
knn_fullpath = os.path.join(self.annotation_path, knn_filename)
if os.path.isfile(knn_fullpath): # backup existing
os.rename(knn_fullpath, knn_fullpath + ".bak")
self.knn.save_index(knn_fullpath)

# save labels
celltype_labels_fullpath = os.path.join(
self.model_path, celltype_labels_filename
self.annotation_path, celltype_labels_filename
)
if os.path.isfile(celltype_labels_fullpath): # backup existing
os.rename(
Expand All @@ -199,6 +198,9 @@ def reset_kNN(self):
>>> ca.reset_kNN()
"""

self.blocklist = None
self.safelist = None

# hnswlib does not have a marked status, so we need to unmark all
for i in self.idx2label:
try: # throws an expection if not already marked
Expand All @@ -216,22 +218,21 @@ def blocklist_celltypes(self, labels: Union[List[str], Set[str]]):
Notes
-----
Blocking a celltype will persist for this instance of the class
and subsequent predictions will have this blocklist.
Blocking a celltype will persist for this instance of the class and subsequent predictions will have this blocklist.
Blocklists and safelists are mutually exclusive, setting one will clear the other.
Examples
--------
>>> ca.blocklist_celltypes(["T cell"])
"""

self.blocklist = set(labels) if isinstance(labels, list) else labels
self.blocklist = set(labels)
self.safelist = None

self.reset_kNN()
for i in [
idx for idx in self.idx2label if self.idx2label[idx] in self.blocklist
]:
self.knn.mark_deleted(i)
for i, celltype_name in self.idx2label.items():
if celltype_name in self.blocklist:
self.knn.mark_deleted(i) # mark blocklist

def safelist_celltypes(self, labels: Union[List[str], Set[str]]):
"""Safelist celltypes.
Expand All @@ -243,8 +244,7 @@ def safelist_celltypes(self, labels: Union[List[str], Set[str]]):
Notes
-----
Safelisting a celltype will persist for this instance of the class
and subsequent predictions will have this safelist.
Safelisting a celltype will persist for this instance of the class and subsequent predictions will have this safelist.
Blocklists and safelists are mutually exclusive, setting one will clear the other.
Examples
Expand All @@ -253,24 +253,24 @@ def safelist_celltypes(self, labels: Union[List[str], Set[str]]):
"""

self.blocklist = None
self.safelist = set(labels) if isinstance(labels, list) else labels
for i in range(len(self.idx2label)): # mark all
self.safelist = set(labels)

for i in self.idx2label: # mark all
try: # throws an exception if already marked
self.knn.mark_deleted(i)
except:
pass
for i in [
idx for idx in self.idx2label if self.idx2label[idx] in self.safelist
]:
self.knn.unmark_deleted(i)
for i, celltype_name in self.idx2label.items():
if celltype_name in self.safelist:
self.knn.unmark_deleted(i) # unmark safelist

def get_predictions_kNN(
self,
embeddings: np.ndarray,
embeddings: "numpy.ndarray",
k: int = 50,
ef: int = 100,
weighting: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, pd.DataFrame]:
) -> Tuple["numpy.ndarray", "numpy.ndarray", "numpy.ndarray", "pandas.DataFrame"]:
"""Get predictions from kNN search results.
Parameters
Expand All @@ -294,23 +294,31 @@ def get_predictions_kNN(
nn_dists: numpy.ndarray
A 2D numpy array of nearest neighbor distances [num_cells x k].
stats: pandas.DataFrame
Prediction statistics columns:
Prediction statistics dataframe with columns:
"hits" is a json string with the count for every class in k cells.
"min_dist" is the minimum distance.
"max_dist" is the maximum distance
"vs2nd" is sum of best / (best + 2nd best).
"vsAll" is sum of best / (all hits).
"vs2nd" is sum(best) / sum(best + 2nd best).
"vsAll" is sum(best) / sum(all hits).
"hits_weighted" is a json string with the weighted count for every class in k cells.
"vs2nd_weighted" is weighted sum of best / (best + 2nd best).
"vsAll_weighted" is weighted sum of best / (all hits).
"vs2nd_weighted" is weighted sum(best) / sum(best + 2nd best).
"vsAll_weighted" is weighted sum(best) / sum(all hits).
Examples
--------
>>> ca = CellAnnotation(model_path="/opt/data/model")
>>> embedding = ca.get_embeddings(align_dataset(data, ca.gene_order).X)
>>> predictions, nn_idxs, nn_dists, nn_stats = ca.get_predictions_kNN(embeddings)
>>> embeddings = ca.get_embeddings(align_dataset(data, ca.gene_order).X)
>>> predictions, nn_idxs, nn_dists, stats = ca.get_predictions_kNN(embeddings)
"""

from collections import defaultdict
import json
import operator
import numpy as np
import pandas as pd
import time
from tqdm import tqdm

start_time = time.time()
nn_idxs, nn_dists = self.get_nearest_neighbors(
embeddings=embeddings, k=k, ef=ef
Expand Down Expand Up @@ -339,7 +347,7 @@ def get_predictions_kNN(
celltype_weighted = defaultdict(float)
for neighbor, dist in zip(nns, d_nns):
celltype[self.idx2label[neighbor]] += 1
celltype_weighted[self.idx2label[neighbor]] += 1 / dist
celltype_weighted[self.idx2label[neighbor]] += 1 / max(dist, 1e-6)
# predict based on consensus max occurrence
if weighting:
predictions.append(
Expand All @@ -356,7 +364,10 @@ def get_predictions_kNN(
stats["max_dist"].append(np.max(d_nns))

hits = sorted(celltype.values(), reverse=True)
hits_weighted = sorted(celltype_weighted.values(), reverse=True)
hits_weighted = [
max(x, 1e-6)
for x in sorted(celltype_weighted.values(), reverse=True)
]
if len(hits) > 1:
stats["vs2nd"].append(hits[0] / (hits[0] + hits[1]))
stats["vsAll"].append(hits[0] / sum(hits))
Expand All @@ -380,75 +391,48 @@ def get_predictions_kNN(

def annotate_dataset(
self,
dataset: Union[anndata.AnnData, pgio.MultimodalData, pgio.UnimodalData, str],
return_type: Optional[str] = None,
skip_preprocessing: bool = False,
) -> Union[anndata.AnnData, pgio.UnimodalData]:
"""Read a dataset, check validity, preprocess, and then annotate with celltype predictions.
dataset: Union["anndata.AnnData", str],
) -> "anndata.AnnData":
"""Annotate dataset with celltype predictions.
Parameters
----------
dataset: Union[pegasusio.MultimodalData, pegasusio.UnimodalData, anndata.AnnData, str]
dataset: Union[anndata.AnnData, str]
If a string, the filename of the h5ad file.
Otherwise, the annotated data matrix with rows for cells and columns for genes.
return_type: {"AnnData", "UnimodalData"}, optional
Data return type string. If None, then it will return the same type as the input dataset.
If a string was given for the dataset, defaults to UnimodalData as the return type.
skip_preprocessing: bool, default: False
Whether to skip preprocessing steps.
This function assumes the data has been log normalized (i.e. via lognorm_counts) accordingly and
aligned to the gene space (i.e. via align_dataset).
Returns
-------
Union["AnnData", "UnimodalData"]
A data object where the normalized data is in matrix/layer "lognorm",
celltype predictions are in obs["celltype_hint"],
and embeddings are in obs["X_triplet"].
anndata.AnnData
A data object where:
- celltype predictions are in obs["celltype_hint"]
- embeddings are in obs["X_scimilarity"].
Examples
--------
>>> ca = CellAnnotation(model_path="/opt/data/model")
>>> data = annotate_dataset("/opt/individual_anndatas/GSE124898/GSM3558026/GSM3558026.h5ad")
"""

valid_return_types = {"AnnData", "UnimodalData"}
if return_type is not None and return_type not in valid_return_types:
raise ValueError(
f"Unknown return_type {return_type}. Options are {valid_return_types}."
)
import anndata
from scimilarity.utils import align_dataset

if isinstance(dataset, str):
data = pgio.read_input(dataset)
if return_type is None:
return_type = "UnimodalData"
else:
data = dataset

if isinstance(data, anndata.AnnData):
return_type = "AnnData"
data = anndata.read_h5ad(dataset)
else:
return_type = "UnimodalData"
data = dataset.copy()

if skip_preprocessing:
normalized_data = data
else:
check_dataset(data, self.gene_order, gene_overlap_threshold=10000)
normalized_data = lognorm_counts(data)

embeddings = self.get_embeddings(
align_dataset(normalized_data, self.gene_order).X
)
normalized_data.obsm["X_triplet"] = embeddings
embeddings = self.get_embeddings(align_dataset(data, self.gene_order).X)
data.obsm["X_scimilarity"] = embeddings

predictions, _, _, nn_stats = self.get_predictions_kNN(embeddings)
normalized_data.obs["celltype_hint"] = predictions.values
normalized_data.obs["min_dist"] = nn_stats["min_dist"].values
normalized_data.obs["celltype_hits"] = nn_stats["hits"].values
normalized_data.obs["celltype_hits_weighted"] = nn_stats["hits_weighted"].values
normalized_data.obs["celltype_hint_stat"] = nn_stats["vsAll"].values
normalized_data.obs["celltype_hint_weighted_stat"] = nn_stats[
"vsAll_weighted"
].values

if return_type == "AnnData" and not isinstance(dataset, anndata.AnnData):
return normalized_data.to_anndata()
return normalized_data
data.obs["celltype_hint"] = predictions.values
data.obs["min_dist"] = nn_stats["min_dist"].values
data.obs["celltype_hits"] = nn_stats["hits"].values
data.obs["celltype_hits_weighted"] = nn_stats["hits_weighted"].values
data.obs["celltype_hint_stat"] = nn_stats["vsAll"].values
data.obs["celltype_hint_weighted_stat"] = nn_stats["vsAll_weighted"].values

return data
Loading

0 comments on commit f3753a8

Please sign in to comment.