From f3753a84eb4668fca5590ebeb3b2a269f4f990d9 Mon Sep 17 00:00:00 2001 From: tony-kuo Date: Mon, 12 Feb 2024 15:52:54 -0800 Subject: [PATCH] update documentation, simplified model loading --- src/scimilarity/__init__.py | 3 +- src/scimilarity/cell_annotation.py | 186 ++--- src/scimilarity/cell_embedding.py | 79 +- src/scimilarity/cell_query.py | 482 ++++++------ src/scimilarity/data_models.py | 23 +- src/scimilarity/interpreter.py | 74 +- src/scimilarity/nn_models.py | 55 +- src/scimilarity/ontologies.py | 352 +++++++-- src/scimilarity/training_models.py | 232 ++++-- src/scimilarity/triplet_selector.py | 152 +++- src/scimilarity/utils.py | 1124 ++++++++++++++++++--------- src/scimilarity/visualizations.py | 276 ++++++- src/scimilarity/zarr_data_models.py | 148 +++- src/scimilarity/zarr_dataset.py | 550 +++++++++---- 14 files changed, 2624 insertions(+), 1112 deletions(-) diff --git a/src/scimilarity/__init__.py b/src/scimilarity/__init__.py index 481a782..0c25a84 100644 --- a/src/scimilarity/__init__.py +++ b/src/scimilarity/__init__.py @@ -20,5 +20,4 @@ from .cell_annotation import CellAnnotation from .cell_query import CellQuery from .interpreter import Interpreter -from .utils import align_dataset -from .zarr_dataset import ZarrDataset +from .utils import align_dataset, lognorm_counts diff --git a/src/scimilarity/cell_annotation.py b/src/scimilarity/cell_annotation.py index 1ffbd90..d54b071 100644 --- a/src/scimilarity/cell_annotation.py +++ b/src/scimilarity/cell_annotation.py @@ -1,21 +1,6 @@ -import json -import operator -import os -import time -from collections import defaultdict from typing import Optional, Union, List, Set, Tuple -import anndata -import hnswlib -import numpy as np -import pandas as pd -import pegasusio as pgio -from tqdm import tqdm - from scimilarity.cell_embedding import CellEmbedding -from scimilarity.ontologies import import_cell_ontology, get_id_mapper -from scimilarity.utils import check_dataset, lognorm_counts, align_dataset -from scimilarity.zarr_dataset import ZarrDataset class CellAnnotation(CellEmbedding): @@ -49,6 +34,8 @@ def __init__( >>> ca = CellAnnotation(model_path="/opt/data/model") """ + import os + super().__init__( model_path=model_path, use_gpu=use_gpu, @@ -60,11 +47,13 @@ def __init__( if filenames is None: filenames = {} + self.annotation_path = os.path.join(model_path, "annotation") self.filenames["knn"] = os.path.join( - model_path, filenames.get("knn", "labelled_kNN.bin") + self.annotation_path, filenames.get("knn", "labelled_kNN.bin") ) self.filenames["celltype_labels"] = os.path.join( - model_path, filenames.get("celltype_labels", "reference_labels.tsv") + self.annotation_path, + filenames.get("celltype_labels", "reference_labels.tsv"), ) # get knn @@ -74,12 +63,13 @@ def __init__( with open(self.filenames["celltype_labels"], "r") as fh: self.idx2label = {i: line.strip() for i, line in enumerate(fh)} + self.classes = set(self.label2int.keys()) self.safelist = None self.blocklist = None def build_kNN( self, - input_data: Union[anndata.AnnData, pgio.MultimodalData, pgio.UnimodalData, str], + input_data: Union["anndata.AnnData", str], knn_filename: str = "labelled_kNN.bin", celltype_labels_filename: str = "reference_labels.tsv", obs_field: str = "celltype_name", @@ -91,7 +81,7 @@ def build_kNN( Parameters ---------- - input_data: Union[anndata.AnnData, pegasusio.MultimodalData, pegasusio.UnimodalData, str], + input_data: Union[anndata.AnnData, str], If a string, the filename of h5ad data file or directory containing zarr stores. 
Otherwise, the annotated data matrix with rows for cells and columns for genes. knn_filename: str, default: "labelled_kNN.bin" @@ -99,7 +89,7 @@ def build_kNN( celltype_labels_filename: str, default: "reference_labels.tsv" Filename of the cell type reference labels. obs_field: str, default: "celltype_name" - The obs key name of celltype labels. + The obs column name of celltype labels. ef_construction: int, default: 1000 The size of the dynamic list for the nearest neighbors. See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md @@ -114,6 +104,15 @@ def build_kNN( >>> ca.build_knn(filename="/opt/data/train/train.h5ad") """ + import anndata + import hnswlib + import numpy as np + import os + import pandas as pd + from scimilarity.utils import align_dataset + from scimilarity.zarr_dataset import ZarrDataset + from tqdm import tqdm + if isinstance(input_data, str) and os.path.isdir(input_data): data_list = [ f for f in os.listdir(input_data) if f.endswith(".aligned.zarr") @@ -125,7 +124,7 @@ def build_kNN( obs = pd.DataFrame({obs_field: dataset.get_obs(obs_field)}) obs.index = obs.index.astype(str) data = anndata.AnnData( - X=dataset.X_copy, + X=dataset.get_X(in_mem=True), obs=obs, var=pd.DataFrame(index=dataset.var_index), dtype=np.float32, @@ -143,7 +142,7 @@ def build_kNN( embeddings = np.concatenate(embeddings_list) else: if isinstance(input_data, str) and os.path.isfile(input_data): - data = pgio.read_input(input_data) + data = anndata.read_h5ad(input_data) else: data = input_data @@ -170,14 +169,14 @@ def build_kNN( self.knn.set_ef(ef_construction) self.knn.add_items(embeddings, range(len(embeddings))) - knn_fullpath = os.path.join(self.model_path, knn_filename) + knn_fullpath = os.path.join(self.annotation_path, knn_filename) if os.path.isfile(knn_fullpath): # backup existing os.rename(knn_fullpath, knn_fullpath + ".bak") self.knn.save_index(knn_fullpath) # save labels celltype_labels_fullpath = os.path.join( - self.model_path, celltype_labels_filename + self.annotation_path, celltype_labels_filename ) if os.path.isfile(celltype_labels_fullpath): # backup existing os.rename( @@ -199,6 +198,9 @@ def reset_kNN(self): >>> ca.reset_kNN() """ + self.blocklist = None + self.safelist = None + # hnswlib does not have a marked status, so we need to unmark all for i in self.idx2label: try: # throws an expection if not already marked @@ -216,8 +218,7 @@ def blocklist_celltypes(self, labels: Union[List[str], Set[str]]): Notes ----- - Blocking a celltype will persist for this instance of the class - and subsequent predictions will have this blocklist. + Blocking a celltype will persist for this instance of the class and subsequent predictions will have this blocklist. Blocklists and safelists are mutually exclusive, setting one will clear the other. Examples @@ -225,13 +226,13 @@ def blocklist_celltypes(self, labels: Union[List[str], Set[str]]): >>> ca.blocklist_celltypes(["T cell"]) """ - self.blocklist = set(labels) if isinstance(labels, list) else labels + self.blocklist = set(labels) self.safelist = None + self.reset_kNN() - for i in [ - idx for idx in self.idx2label if self.idx2label[idx] in self.blocklist - ]: - self.knn.mark_deleted(i) + for i, celltype_name in self.idx2label.items(): + if celltype_name in self.blocklist: + self.knn.mark_deleted(i) # mark blocklist def safelist_celltypes(self, labels: Union[List[str], Set[str]]): """Safelist celltypes. 
@@ -243,8 +244,7 @@ def safelist_celltypes(self, labels: Union[List[str], Set[str]]): Notes ----- - Safelisting a celltype will persist for this instance of the class - and subsequent predictions will have this safelist. + Safelisting a celltype will persist for this instance of the class and subsequent predictions will have this safelist. Blocklists and safelists are mutually exclusive, setting one will clear the other. Examples @@ -253,24 +253,24 @@ def safelist_celltypes(self, labels: Union[List[str], Set[str]]): """ self.blocklist = None - self.safelist = set(labels) if isinstance(labels, list) else labels - for i in range(len(self.idx2label)): # mark all + self.safelist = set(labels) + + for i in self.idx2label: # mark all try: # throws an exception if already marked self.knn.mark_deleted(i) except: pass - for i in [ - idx for idx in self.idx2label if self.idx2label[idx] in self.safelist - ]: - self.knn.unmark_deleted(i) + for i, celltype_name in self.idx2label.items(): + if celltype_name in self.safelist: + self.knn.unmark_deleted(i) # unmark safelist def get_predictions_kNN( self, - embeddings: np.ndarray, + embeddings: "numpy.ndarray", k: int = 50, ef: int = 100, weighting: bool = False, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, pd.DataFrame]: + ) -> Tuple["numpy.ndarray", "numpy.ndarray", "numpy.ndarray", "pandas.DataFrame"]: """Get predictions from kNN search results. Parameters @@ -294,23 +294,31 @@ def get_predictions_kNN( nn_dists: numpy.ndarray A 2D numpy array of nearest neighbor distances [num_cells x k]. stats: pandas.DataFrame - Prediction statistics columns: + Prediction statistics dataframe with columns: "hits" is a json string with the count for every class in k cells. "min_dist" is the minimum distance. "max_dist" is the maximum distance - "vs2nd" is sum of best / (best + 2nd best). - "vsAll" is sum of best / (all hits). + "vs2nd" is sum(best) / sum(best + 2nd best). + "vsAll" is sum(best) / sum(all hits). "hits_weighted" is a json string with the weighted count for every class in k cells. - "vs2nd_weighted" is weighted sum of best / (best + 2nd best). - "vsAll_weighted" is weighted sum of best / (all hits). + "vs2nd_weighted" is weighted sum(best) / sum(best + 2nd best). + "vsAll_weighted" is weighted sum(best) / sum(all hits). 
Examples -------- >>> ca = CellAnnotation(model_path="/opt/data/model") - >>> embedding = ca.get_embeddings(align_dataset(data, ca.gene_order).X) - >>> predictions, nn_idxs, nn_dists, nn_stats = ca.get_predictions_kNN(embeddings) + >>> embeddings = ca.get_embeddings(align_dataset(data, ca.gene_order).X) + >>> predictions, nn_idxs, nn_dists, stats = ca.get_predictions_kNN(embeddings) """ + from collections import defaultdict + import json + import operator + import numpy as np + import pandas as pd + import time + from tqdm import tqdm + start_time = time.time() nn_idxs, nn_dists = self.get_nearest_neighbors( embeddings=embeddings, k=k, ef=ef @@ -339,7 +347,7 @@ def get_predictions_kNN( celltype_weighted = defaultdict(float) for neighbor, dist in zip(nns, d_nns): celltype[self.idx2label[neighbor]] += 1 - celltype_weighted[self.idx2label[neighbor]] += 1 / dist + celltype_weighted[self.idx2label[neighbor]] += 1 / max(dist, 1e-6) # predict based on consensus max occurrence if weighting: predictions.append( @@ -356,7 +364,10 @@ def get_predictions_kNN( stats["max_dist"].append(np.max(d_nns)) hits = sorted(celltype.values(), reverse=True) - hits_weighted = sorted(celltype_weighted.values(), reverse=True) + hits_weighted = [ + max(x, 1e-6) + for x in sorted(celltype_weighted.values(), reverse=True) + ] if len(hits) > 1: stats["vs2nd"].append(hits[0] / (hits[0] + hits[1])) stats["vsAll"].append(hits[0] / sum(hits)) @@ -380,29 +391,24 @@ def get_predictions_kNN( def annotate_dataset( self, - dataset: Union[anndata.AnnData, pgio.MultimodalData, pgio.UnimodalData, str], - return_type: Optional[str] = None, - skip_preprocessing: bool = False, - ) -> Union[anndata.AnnData, pgio.UnimodalData]: - """Read a dataset, check validity, preprocess, and then annotate with celltype predictions. + dataset: Union["anndata.AnnData", str], + ) -> "anndata.AnnData": + """Annotate dataset with celltype predictions. Parameters ---------- - dataset: Union[pegasusio.MultimodalData, pegasusio.UnimodalData, anndata.AnnData, str] + dataset: Union[anndata.AnnData, str] If a string, the filename of the h5ad file. Otherwise, the annotated data matrix with rows for cells and columns for genes. - return_type: {"AnnData", "UnimodalData"}, optional - Data return type string. If None, then it will return the same type as the input dataset. - If a string was given for the dataset, defaults to UnimodalData as the return type. - skip_preprocessing: bool, default: False - Whether to skip preprocessing steps. + This function assumes the data has been log normalized (i.e. via lognorm_counts) accordingly and + aligned to the gene space (i.e. via align_dataset). Returns ------- - Union["AnnData", "UnimodalData"] - A data object where the normalized data is in matrix/layer "lognorm", - celltype predictions are in obs["celltype_hint"], - and embeddings are in obs["X_triplet"]. + anndata.AnnData + A data object where: + - celltype predictions are in obs["celltype_hint"] + - embeddings are in obs["X_scimilarity"]. Examples -------- @@ -410,45 +416,23 @@ def annotate_dataset( >>> data = annotate_dataset("/opt/individual_anndatas/GSE124898/GSM3558026/GSM3558026.h5ad") """ - valid_return_types = {"AnnData", "UnimodalData"} - if return_type is not None and return_type not in valid_return_types: - raise ValueError( - f"Unknown return_type {return_type}. Options are {valid_return_types}." 
- ) + import anndata + from scimilarity.utils import align_dataset if isinstance(dataset, str): - data = pgio.read_input(dataset) - if return_type is None: - return_type = "UnimodalData" - else: - data = dataset - - if isinstance(data, anndata.AnnData): - return_type = "AnnData" + data = anndata.read_h5ad(dataset) else: - return_type = "UnimodalData" + data = dataset.copy() - if skip_preprocessing: - normalized_data = data - else: - check_dataset(data, self.gene_order, gene_overlap_threshold=10000) - normalized_data = lognorm_counts(data) - - embeddings = self.get_embeddings( - align_dataset(normalized_data, self.gene_order).X - ) - normalized_data.obsm["X_triplet"] = embeddings + embeddings = self.get_embeddings(align_dataset(data, self.gene_order).X) + data.obsm["X_scimilarity"] = embeddings predictions, _, _, nn_stats = self.get_predictions_kNN(embeddings) - normalized_data.obs["celltype_hint"] = predictions.values - normalized_data.obs["min_dist"] = nn_stats["min_dist"].values - normalized_data.obs["celltype_hits"] = nn_stats["hits"].values - normalized_data.obs["celltype_hits_weighted"] = nn_stats["hits_weighted"].values - normalized_data.obs["celltype_hint_stat"] = nn_stats["vsAll"].values - normalized_data.obs["celltype_hint_weighted_stat"] = nn_stats[ - "vsAll_weighted" - ].values - - if return_type == "AnnData" and not isinstance(dataset, anndata.AnnData): - return normalized_data.to_anndata() - return normalized_data + data.obs["celltype_hint"] = predictions.values + data.obs["min_dist"] = nn_stats["min_dist"].values + data.obs["celltype_hits"] = nn_stats["hits"].values + data.obs["celltype_hits_weighted"] = nn_stats["hits_weighted"].values + data.obs["celltype_hint_stat"] = nn_stats["vsAll"].values + data.obs["celltype_hint_weighted_stat"] = nn_stats["vsAll_weighted"].values + + return data diff --git a/src/scimilarity/cell_embedding.py b/src/scimilarity/cell_embedding.py index d8781be..b09f7b3 100644 --- a/src/scimilarity/cell_embedding.py +++ b/src/scimilarity/cell_embedding.py @@ -1,18 +1,5 @@ -import json -import os from typing import Optional, Tuple, Union -import hnswlib -import numpy as np -import pandas as pd -import torch -import zarr -from scipy.sparse import csr_matrix - -from scimilarity.b_colors import BColors -from scimilarity.nn_models import Encoder -from scimilarity.utils import align_dataset - class CellEmbedding: """A class that embeds cell gene expression data using a ML model.""" @@ -33,9 +20,9 @@ def __init__( Path to the directory containing model files. use_gpu: bool, default: False Use GPU instead of CPU. - parameters: dict, optional + parameters: dict, optional, default: None Use a dictionary of custom model parameters instead of infering from model files. - filenames: dict, optional + filenames: dict, optional, default: None Use a dictionary of custom filenames for model files instead default. residual: bool, default: False Use residual connections. @@ -45,6 +32,11 @@ def __init__( >>> ce = CellEmbedding(model_path="/opt/data/model") """ + import json + import os + import pandas as pd + from scimilarity.nn_models import Encoder + self.model_path = model_path self.use_gpu = use_gpu self.knn = None @@ -98,35 +90,18 @@ def __init__( )["0"].to_dict() self.label2int = {value: key for key, value in self.int2label.items()} - def load_knn_index(self, knn_file: str): - """Load the kNN index file - - Parameters - ---------- - knn_file: str - Filename of the kNN index. 
- """ - if os.path.isfile(knn_file): - self.knn = hnswlib.Index(space="cosine", dim=self.model.latent_dim) - self.knn.load_index(knn_file) - else: - print( - f"{BColors.WARNING}Warning: No KNN index found at {knn_file}{BColors.ENDC}" - ) - self.knn = None - def get_embeddings( self, - X: Union[csr_matrix, np.ndarray], + X: Union["scipy.sparse.csr_matrix", "numpy.ndarray"], num_cells: int = -1, buffer_size: int = 10000, - ) -> np.ndarray: + ) -> "numpy.ndarray": """Calculate embeddings for lognormed gene expression matrix. Parameters ---------- X: scipy.sparse.csr_matrix, numpy.ndarray - Gene expression matrix. + Gene space aligned and log normalized (tp10k) gene expression matrix. num_cells: int, default: -1 The number of cells to embed, starting from index 0. A value of -1 will embed all cells. @@ -140,11 +115,18 @@ def get_embeddings( Examples -------- - >>> from scimilarity.utils import align_dataset + >>> from scimilarity.utils import align_dataset, lognorm_counts >>> ce = CellEmbedding(model_path="/opt/data/model") - >>> embedding = ce.get_embeddings(align_dataset(data, ce.gene_order).X) + >>> data = align_dataset(data, ce.gene_order) + >>> data = lognorm_counts(data) + >>> embeddings = ce.get_embeddings(data.X) """ + import numpy as np + from scipy.sparse import csr_matrix + import torch + import zarr + if num_cells == -1: num_cells = X.shape[0] @@ -194,9 +176,28 @@ def get_embeddings( return embedding + def load_knn_index(self, knn_file: str): + """Load the kNN index file + + Parameters + ---------- + knn_file: str + Filename of the kNN index. + """ + + import hnswlib + import os + + if os.path.isfile(knn_file): + self.knn = hnswlib.Index(space="cosine", dim=self.model.latent_dim) + self.knn.load_index(knn_file) + else: + print(f"Warning: No KNN index found at {knn_file}") + self.knn = None + def get_nearest_neighbors( - self, embeddings: np.ndarray, k: int = 50, ef: int = 100 - ) -> Tuple[np.ndarray, np.ndarray]: + self, embeddings: "numpy.ndarray", k: int = 50, ef: int = 100 + ) -> Tuple["numpy.ndarray", "numpy.ndarray"]: """Get nearest neighbors. Used by classes that inherit from CellEmbedding and have an instantiated kNN. diff --git a/src/scimilarity/cell_query.py b/src/scimilarity/cell_query.py index e0b6331..c5fae08 100644 --- a/src/scimilarity/cell_query.py +++ b/src/scimilarity/cell_query.py @@ -1,15 +1,6 @@ -import os from typing import Dict, List, Optional, Tuple, Union, Set -import anndata -import numpy as np -import pandas as pd -import pegasusio as pgio -import tiledb - from scimilarity.cell_embedding import CellEmbedding -from scimilarity.utils import get_cluster_centroids -from scimilarity.visualizations import aggregate_counts, assign_size, circ_dict2data, draw_circles class CellQuery(CellEmbedding): @@ -18,7 +9,6 @@ class CellQuery(CellEmbedding): def __init__( self, model_path: str, - cellsearch_path: str, use_gpu: bool = False, parameters: Optional[dict] = None, filenames: Optional[dict] = None, @@ -32,9 +22,7 @@ def __init__( Parameters ---------- model_path: str - Path to the directory containing model files. - cellsearch_path: str - Path to the directory containing cell search files. + Path to the model directory. use_gpu: bool, default: False Use GPU instead of CPU. 
parameters: dict, optional, default: None @@ -55,6 +43,12 @@ def __init__( >>> cq = CellQuery(model_path="/opt/data/model") """ + import os + import numpy as np + import pandas as pd + import tiledb + from scimilarity.utils import write_tiledb_array, optimize_tiledb_array + super().__init__( model_path=model_path, use_gpu=use_gpu, @@ -62,16 +56,20 @@ def __init__( filenames=filenames, residual=residual, ) - self.cellsearch_path = cellsearch_path + self.cellsearch_path = os.path.join(model_path, "cellsearch") if filenames is None: filenames = {} self.filenames["knn"] = os.path.join( - cellsearch_path, filenames.get("knn", "full_kNN.bin") + self.cellsearch_path, filenames.get("knn", "full_kNN.bin") ) self.filenames["cell_metadata"] = os.path.join( - cellsearch_path, filenames.get("cell_metadata", "full_kNN_meta.csv") + self.cellsearch_path, filenames.get("cell_metadata", "full_kNN_meta.csv") + ) + self.filenames["cell_embeddings"] = os.path.join( + self.cellsearch_path, + filenames.get("cell_embeddings", "full_kNN_embedding.npy"), ) # get knn @@ -81,22 +79,22 @@ def __init__( # get cell metadata: create tiledb storage if it does not exist # NOTE: process for creating this file is not hardened, no guarantee index column is unique - metadata_tiledb_uri = os.path.join(cellsearch_path, metadata_tiledb_uri) + metadata_tiledb_uri = os.path.join(self.cellsearch_path, metadata_tiledb_uri) if not os.path.isdir(metadata_tiledb_uri): print(f"Configuring tiledb dataframe: {metadata_tiledb_uri}") cell_metadata = ( pd.read_csv( self.filenames["cell_metadata"], header=0, + dtype=str, ) .fillna("NA") - .astype(str) .reset_index(drop=True) ) cell_metadata = cell_metadata.rename(columns={"Unnamed: 0": "index"}) convert_dict = { "index": int, - "nn_dist": float, + "prediction_nn_dist": float, "fm_signature_score": float, "total_counts": float, "n_genes_by_counts": float, @@ -108,127 +106,24 @@ def __init__( self.cell_metadata = tiledb.open_dataframe(metadata_tiledb_uri) # get cell embeddings: create tiledb storage if it does not exist - embedding_tiledb_uri = os.path.join(cellsearch_path, embedding_tiledb_uri) + embedding_tiledb_uri = os.path.join(self.cellsearch_path, embedding_tiledb_uri) if not os.path.isdir(embedding_tiledb_uri): - if os.path.isfile(os.path.join(cellsearch_path, "schub_ood_embedding.npy")): - npy_list = [ - "ood_embedding.npy", - "schub_ood_embedding.npy", - "train_embedding.npy", - "test_embedding.npy", - ] - else: - npy_list = [ - "ood_embedding.npy", - "train_embedding.npy", - "test_embedding.npy", - ] - data_list = [os.path.join(cellsearch_path, f) for f in npy_list] - self.create_tiledb_array(embedding_tiledb_uri, data_list) - self.optimize_tiledb_array(embedding_tiledb_uri) + cell_embeddings = np.load( + os.path.join(cellsearch_path, self.filenames["cell_embeddings"]) + ) + write_tiledb_array(embedding_tiledb_uri, cell_embeddings) + optimize_tiledb_array(embedding_tiledb_uri) self.cell_embedding = tiledb.open(embedding_tiledb_uri) - self.study_sample_cells = self.cell_metadata.groupby(["study", "sample"]).size() - self.study_cells = self.cell_metadata.groupby("study").size() - self.study_sample_index = self.cell_metadata.groupby( - ["study", "sample", "train_type"] - )["index"].min() - self.study_index = self.cell_metadata.groupby(["study", "train_type"])[ - "index" - ].min() - - def create_tiledb_array( - self, tiledb_array_uri: str, data_list: List[str], batch_size: int = 10000 - ): - """Create TileDB Array - - Parameters - ---------- - tiledb_array_uri: str - URI for the TileDB 
array. - data_list: List[str] - List of data values. - batch_size: int, default: 10000 - Batch size for the tiles. - """ - print(f"Configuring tiledb array: {tiledb_array_uri}") - - xdimtype = np.int32 - ydimtype = np.int32 - value_type = np.float32 - - xdim = tiledb.Dim( - name="x", - domain=(0, self.cell_metadata.shape[0] - 1), - tile=batch_size, - dtype=xdimtype, - ) - ydim = tiledb.Dim( - name="y", - domain=(0, self.latent_dim - 1), - tile=self.latent_dim, - dtype=ydimtype, - ) - dom = tiledb.Domain(xdim, ydim) - - attr = tiledb.Attr( - name="vals", - dtype=value_type, - filters=tiledb.FilterList([tiledb.GzipFilter()]), - ) - - schema = tiledb.ArraySchema( - domain=dom, - sparse=False, - cell_order="row-major", - tile_order="row-major", - attrs=[attr], + self.study_sample_index = ( + self.cell_metadata.groupby(["study", "sample", "data_type"], observed=True)[ + "index" + ] + .min() + .sort_values() ) - tiledb.Array.create(tiledb_array_uri, schema) - - tdbfile = tiledb.open(tiledb_array_uri, "w") - previous_shape = None - for f in data_list: - if previous_shape is None: - paging_idx = 0 - else: - paging_idx += previous_shape[0] - - arr = np.load(f) - previous_shape = arr.shape - tbd_slice = slice(paging_idx, paging_idx + arr.shape[0]) - tdbfile[tbd_slice, 0 : self.latent_dim] = arr - tdbfile.close() - - def optimize_tiledb_array(self, tiledb_array_uri: str, verbose: bool = True): - """Optimize TileDB Array - - Parameters - ---------- - tiledb_array_uri: str - URI for the TileDB array. - verbose: bool - Boolean indicating whether to use verbose printing. - """ - if verbose: - print(f"Optimizing {tiledb_array_uri}") - - frags = tiledb.array_fragments(tiledb_array_uri) - if verbose: - print("Fragments before consolidation: {}".format(len(frags))) - - cfg = tiledb.Config() - cfg["sm.consolidation.step_min_frags"] = 1 - cfg["sm.consolidation.step_max_frags"] = 200 - tiledb.consolidate(tiledb_array_uri, config=cfg) - tiledb.vacuum(tiledb_array_uri) - - frags = tiledb.array_fragments(tiledb_array_uri) - if verbose: - print("Fragments after consolidation: {}".format(len(frags))) - - def get_precomputed_embeddings(self, idx: List[int]) -> np.ndarray: + def get_precomputed_embeddings(self, idx: List[int]) -> "numpy.ndarray": """Fast get of embeddings from the cell_embedding tiledb array. Parameters @@ -243,13 +138,45 @@ def get_precomputed_embeddings(self, idx: List[int]) -> np.ndarray: Examples -------- - >>> array = cq.get_tiledb_array([0, 1, 100]) + >>> array = cq.get_precomputed_embeddings([0, 1, 100]) """ return self.cell_embedding.query(attrs=["vals"], coords=True).multi_index[idx][ "vals" ] - def compile_sample_metadata(self, nn_idxs: np.ndarray) -> pd.DataFrame: + def annotate_cell_index(self, metadata: "pandas.DataFrame") -> "pandas.DataFrame": + """Annotate a metadata dataframe with the cell index in sample datasets. + + Parameters + ---------- + metadata: pandas.DataFrame + A pandas dataframe containing columns: study, sample, and index. + Where index is the cell query index (i.e. from cq.cell_metadata). + + Returns + ------- + pandas.DataFrame + A pandas dataframe containing the "cell_index" column which is the cell index + per sample dataset. 
+ + Examples + -------- + >>> metadata = cq.annotate_cell_index(metadata) + """ + + cell_index = [] + for i, row in metadata.iterrows(): + study = row["study"] + sample = row["sample"] + if "data_type" not in row: + raise RuntimeError("Required column: 'data_type'") + data_type = row["data_type"] + index_start = self.study_sample_index.loc[study, sample, data_type] + cell_index.append(row["index"] - int(index_start)) + metadata["cell_index"] = cell_index + return metadata + + def compile_sample_metadata(self, nn_idxs: "numpy.ndarray") -> "pandas.DataFrame": """Compile sample metadata for nearest neighbors. Parameters @@ -269,11 +196,13 @@ def compile_sample_metadata(self, nn_idxs: np.ndarray) -> pd.DataFrame: >>> sample_metadata = cq.compile_sample_metadata(nn_idxs) """ + import pandas as pd + levels = ["tissue", "disease", "study", "sample"] df = pd.concat( [ self.cell_metadata.loc[hits] - .groupby(levels) + .groupby(levels, observed=True) .size() .reset_index(name="cells") for hits in nn_idxs @@ -281,45 +210,20 @@ def compile_sample_metadata(self, nn_idxs: np.ndarray) -> pd.DataFrame: axis=0, ).reset_index(drop=True) + study_sample_cells = self.cell_metadata.groupby( + ["study", "sample"], observed=True + ).size() + fraction = [] for i, row in df.iterrows(): - total_cells = self.study_sample_cells.loc[(row["study"], row["sample"])] + total_cells = study_sample_cells.loc[(row["study"], row["sample"])] fraction.append(row["cells"] / total_cells) df["fraction"] = fraction return df - def visualize_sample_metadata( - self, - sample_metadata: pd.DataFrame, - fig_size: Tuple[int, int] = (10, 10), - filename: Optional[str] = None, - ): - """Visualize sample metadata as circle plots for tissue and disease. - - Parameters - ---------- - sample_metadata: pandas.DataFrame - A pandas dataframe containing sample metadata for nearest neighbors. - figsize: Tuple[int, int], default: (10, 10) - Figure size, width x height - filename: str, optional - Filename to save the figure. - - Examples - -------- - >>> cq.visualize_sample_metadata(sample_metadata) - """ - - levels = ["tissue", "disease"] - - circ_dict = aggregate_counts(sample_metadata, levels) - circ_dict = assign_size( - circ_dict, sample_metadata, levels, size_column="cells", name_column="study" - ) - circ_data = circ_dict2data(circ_dict) - draw_circles(circ_data, fig_size=fig_size, filename=filename) - - def groupby_studies(self, sample_metadata: pd.DataFrame) -> pd.DataFrame: + def groupby_studies( + self, sample_metadata: "pandas.DataFrame" + ) -> "pandas.DataFrame": """Performs a groupby studies operation on sample metadata. Parameters @@ -340,66 +244,28 @@ def groupby_studies(self, sample_metadata: pd.DataFrame) -> pd.DataFrame: levels = ["tissue", "disease", "study"] df = ( sample_metadata[levels + ["cells"]] - .groupby(levels)["cells"] + .groupby(levels, observed=True)["cells"] .sum() .reset_index(name="cells") ) + + study_cells = self.cell_metadata.groupby("study", observed=True).size() + fraction = [] for i, row in df.iterrows(): - total_cells = self.study_cells.loc[row["study"]] + total_cells = study_cells.loc[row["study"]] fraction.append(row["cells"] / total_cells) df["fraction"] = fraction return df - def annotate_cell_index(self, metadata: pd.DataFrame) -> pd.DataFrame: - """Annotate a metadata dataframe with the cell index in sample datasets. - - Parameters - ---------- - metadata: pandas.DataFrame - A pandas dataframe containing columns: study, sample, and index. - Where index is the cell query index (i.e. from cq.cell_metadata). 
- aggregated: bool, default: False - Whether the training and test datasets are aggregated. - - Returns - ------- - pandas.DataFrame - A pandas dataframe containing the cell_index column which is the cell index - per sample dataset. - - Examples - -------- - >>> metadata = cq.annotate_cell_index(metadata) - """ - cell_index = [] - for _, row in metadata.iterrows(): - study = row["study"] - sample = row["sample"] - - if "train_type" not in row: - raise RuntimeError("Required column: 'train_type'") - train_type = row["train_type"] - - if train_type == "ood" or train_type == "schub_ood": - index_start = self.study_sample_index.loc[study, sample, train_type] - elif train_type == "train" or train_type == "test": - index_start = self.study_index.loc[study, train_type] - else: - raise RuntimeError(f"{train_type}: Unknown train type.") - - cell_index.append(row["index"] - int(index_start)) - metadata["cell_index"] = cell_index - return metadata - def search( self, - embeddings: np.ndarray, + embeddings: "numpy.ndarray", k: int = 1000, ef: int = None, max_dist: float = None, exclude_studies: Optional[List[str]] = None, - ) -> Tuple[List[np.ndarray], List[np.ndarray], pd.DataFrame]: + ) -> Tuple[List["numpy.ndarray"], List["numpy.ndarray"], "pandas.DataFrame"]: """Performs a cell query search against the kNN. Parameters @@ -414,7 +280,7 @@ def search( max_dist: float, optional Assume k=1000000, then filter for cells that are within the max distance to the query. Overwrites the k parameter. - exclude_studies: List[str], optional + exclude_studies: List[str], optional, default: None A list of studies to exclude from the search, given as a list of str study names. WARNING: If you do not use max_dist, you will potentially get less than k hits as the study exclusion is performed after the search. @@ -435,6 +301,8 @@ def search( >>> nn_idxs, nn_dists, metadata = cq.search(embedding) """ + import pandas as pd + if max_dist is not None: k = 1000000 @@ -469,16 +337,124 @@ def search( nn_idxs = new_nn_idxs nn_dists = new_nn_dists - metadata = pd.concat( - [self.cell_metadata.loc[hits].reset_index(drop=True) for hits in nn_idxs], - axis=0, - ).reset_index(drop=True) + metadata = [] + for i in range(len(nn_idxs)): + hits = nn_idxs[i] + df = self.cell_metadata.loc[hits].reset_index(drop=True) + df["embedding_idx"] = i + df["query_nn_dist"] = nn_dists[i] + metadata.append(df) + metadata = pd.concat(metadata).reset_index(drop=True) return nn_idxs, nn_dists, metadata - def search_centroids( + def search_centroid( + self, + data: "anndata.AnnData", + centroid_key: str, + k: int = 1000, + ef: int = None, + max_dist: float = None, + exclude_studies: Optional[List[str]] = None, + qc: bool = True, + qc_params: dict = {"k_clusters": 10}, + ) -> Tuple[ + "numpy.ndarray", + List["numpy.ndarray"], + List["numpy.ndarray"], + "pandas.DataFrame", + dict, + ]: + """Performs a cell query search for a centroid constructed from marked cells. + + Parameters + ---------- + data: anndata.AnnData + Annotated data matrix with rows for cells and columns for genes. + Requires a layers["counts"]. + centroid_key: str + The obs column key that marks cells to centroid as 1, otherwise 0. + k: int, default: 1000 + The number of nearest neighbors. + ef: int, default: None + The size of the dynamic list for the nearest neighbors. Defaults to k if None. + See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md + max_dist: float, optional + Assume k=1000000, then filter for cells that are within the max distance to the + query. 
Overwrites the k parameter. + exclude_studies: List[str], optional, default: None + A list of studies to exclude from the search, given as a list of str study names. + WARNING: If you do not use max_dist, you will potentially get less than k hits as + the study exclusion is performed after the search. + qc: bool, default: True + Whether to perform QC on the query + qc_params: dict, default: {'k_clusters': 10} + Parameters for the QC: + k_clusters: the number of clusters in kmeans clustering + + Returns + ------- + centroid_embedding: numpy.ndarray + A 2D numpy array of the centroid embedding. + nn_idxs: List[numpy.ndarray] + A list of 2D numpy array of nearest neighbor indices. + One entry for every cell (row) in embeddings + nn_dists: List[numpy.ndarray] + A list of 2D numpy array of nearest neighbor distances. + One entry for every cell (row) in embeddings + metadata: pandas.DataFrame + A pandas dataframe containing cell metadata for nearest neighbors. + qc_stats: dict + A dictionary of stats for QC. + + Examples + -------- + >>> cells_used_in_query = adata.obs["celltype_name"] == "macrophage" + >>> adata.obs["used_in_query"] = cells_used_in_query.astype(int) + >>> centroid_embedding, nn_idxs, nn_dists, metadata, qc_stats = cq.search_centroid(adata, 'used_in_query') + """ + + import numpy as np + from scipy.cluster.vq import kmeans + from scipy.spatial.distance import cdist, pdist, squareform + from scimilarity.utils import get_centroid, get_dist2centroid + + cells = data[data.obs[centroid_key] == 1].copy() + centroid = get_centroid(cells.layers["counts"]) + centroid_embedding = self.get_embeddings(centroid) + + if max_dist is not None: + k = 1000000 + + if ef is None: + ef = k + + (nn_idxs, nn_dists, metadata) = self.search( + centroid_embedding, + k=k, + ef=ef, + max_dist=max_dist, + exclude_studies=exclude_studies, + ) + + qc_stats = {} + if qc: + cells_embedding = self.get_embeddings(cells.X) + k_clusters = qc_params.get("k_clusters", 10) + cluster_centroids = kmeans(cells_embedding, k_clusters)[0] + + cell_nn_idxs, _, _ = self.search(cluster_centroids, k=100) + query_overlap = [] + for i in range(len(cell_nn_idxs)): + overlap = [x for x in cell_nn_idxs[i] if x in nn_idxs[0]] + query_overlap.append(len(overlap)) + qc_stats["query_stability"] = np.mean(query_overlap) + + return centroid_embedding, nn_idxs, nn_dists, metadata, qc_stats + + def search_cluster_centroids( self, - data: Union[anndata.AnnData, pgio.UnimodalData, pgio.MultimodalData], + data: "anndata.AnnData", cluster_key: str, cluster_label: Optional[str] = None, k: int = 1000, @@ -487,21 +463,22 @@ def search_centroids( max_dist: float = None, exclude_studies: Optional[List[str]] = None, ) -> Tuple[ - np.ndarray, + "numpy.ndarray", list, - Dict[str, np.ndarray], - Dict[str, np.ndarray], - pd.DataFrame, + Dict[str, "numpy.ndarray"], + Dict[str, "numpy.ndarray"], + "pandas.DataFrame", ]: """Performs a cell query search for cluster centroids against the kNN. Parameters ---------- - data: pegasusio.MultimodalData, pegasusio.UnimodalData, anndata.AnnData + data: anndata.AnnData Annotated data matrix with rows for cells and columns for genes. + Requires a layers["counts"]. cluster_key: str The obs column key that contains cluster labels. - cluster_label: optional, str + cluster_label: str, optional, default: None The cluster label of interest. 
If None, then get the centroids of all clusters, otherwise get only the centroid for the cluster of interest @@ -515,7 +492,7 @@ def search_centroids( max_dist: float, optional Assume k=1000000, then filter for cells that are within the max distance to the query. Overwrites the k parameter. - exclude_studies: List[str], optional + exclude_studies: List[str], optional, default = None A list of studies to exclude from the search, given as a list of str study names. WARNING: If you do not use max_dist, you will potentially get less than k hits as the study exclusion is performed after the search. @@ -536,9 +513,11 @@ def search_centroids( Examples -------- - >>> centroid_embeddings, cluster_idx, nn_idx, nn_dists, all_metadata = cq.search_centroids(data, "leidan") + >>> centroid_embeddings, cluster_idx, nn_idx, nn_dists, all_metadata = cq.search_cluster_centroids(data, "leidan") """ + from scimilarity.utils import get_cluster_centroids + centroids, cluster_idx = get_cluster_centroids( data, self.gene_order, cluster_key, cluster_label, skip_null=skip_null ) @@ -551,31 +530,28 @@ def search_centroids( if ef is None: ef = k - nn_idxs = {} - nn_dists = {} - metadata = {} - for row in range(centroid_embeddings.shape[0]): - ( - nn_idxs[cluster_idx[row]], - nn_dists[cluster_idx[row]], - metadata[cluster_idx[row]], - ) = self.search( - centroid_embeddings[row], - k=k, - ef=ef, - max_dist=max_dist, - exclude_studies=exclude_studies, - ) - metadata[cluster_idx[row]]["centroid"] = cluster_idx[row] - metadata[cluster_idx[row]]["nn_dist"] = nn_dists[cluster_idx[row]][0] + (nn_idxs, nn_dists, metadata) = self.search( + centroid_embeddings, + k=k, + ef=ef, + max_dist=max_dist, + exclude_studies=exclude_studies, + ) + + metadata["centroid"] = metadata["embedding_idx"].map( + {i: x for i, x in enumerate(cluster_idx)} + ) - all_metadata = pd.concat(metadata.values()) - all_metadata = all_metadata.set_index("index", drop=False) + nn_idxs_dict = {} + nn_dists_dict = {} + for i in range(len(cluster_idx)): + nn_idxs_dict[cluster_idx[i]] = [nn_idxs[i]] + nn_dists_dict[cluster_idx[i]] = [nn_dists[i]] return ( centroid_embeddings, cluster_idx, - nn_idxs, - nn_dists, - all_metadata, + nn_idxs_dict, + nn_dists_dict, + metadata, ) diff --git a/src/scimilarity/data_models.py b/src/scimilarity/data_models.py index c48c570..c75bdb0 100644 --- a/src/scimilarity/data_models.py +++ b/src/scimilarity/data_models.py @@ -1,16 +1,15 @@ -from collections import Counter -from typing import Optional - import anndata +from collections import Counter import numpy as np import pandas as pd import pytorch_lightning as pl import scanpy import torch from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler +from typing import Optional -from scimilarity.ontologies import import_cell_ontology, get_id_mapper from scimilarity.utils import align_dataset +from scimilarity.ontologies import import_cell_ontology, get_id_mapper class scDataset(Dataset): @@ -31,7 +30,7 @@ def __getitem__(self, idx): class MetricLearningDataModule(pl.LightningDataModule): - """A class to encapsulates the steps needed to model the data.""" + """A class to encapsulate the anndata needed to train the model.""" def __init__( self, @@ -49,9 +48,9 @@ def __init__( ---------- train_path: str Path to the training h5ad file. - val_path: str, optional + val_path: str, optional, default: None Path to the validataion h5ad file. - test_path: str, optional + test_path: str, optional, default: None Path to the test h5ad file. 
obs_field: str, default: "celltype_name" The obs key name containing celltype labels. @@ -145,6 +144,7 @@ def subset_valid_terms(self, data: anndata.AnnData) -> anndata.AnnData: An object containing the data whose celltype labels have valid ontology id. """ + valid_terms_idx = data.obs[self.obs_field].isin(self.name2id.keys()) if valid_terms_idx.any(): return data[valid_terms_idx] @@ -165,6 +165,7 @@ def get_data(self, filename: str, n_obs: Optional[int] = None): tuple A tuple containing the X matrix, ontology ids, and study. """ + data = anndata.read_h5ad(filename) if n_obs: # subset n_obs from data scanpy.pp.subsample(data, n_obs=n_obs) @@ -181,7 +182,7 @@ def get_data(self, filename: str, n_obs: Optional[int] = None): return data.X, data.obs.label_int.values, data.obs.study - def two_way_weighting(self, vec1, vec2): + def two_way_weighting(self, vec1, vec2) -> dict: """Two-way weighting. Parameters @@ -196,6 +197,7 @@ def two_way_weighting(self, vec1, vec2): dict A dictionary containing the two-way weighting. """ + counts = pd.crosstab(vec1, vec2) weights_matrix = (1 / counts).replace(np.inf, 0) return weights_matrix.unstack().to_dict() @@ -213,6 +215,7 @@ def get_sampler_weights(self, dataset: scDataset) -> WeightedRandomSampler: WeightedRandomSampler A WeightedRandomSampler object. """ + if dataset.study is None: class_sample_count = Counter(dataset.Y) sample_weights = torch.Tensor( @@ -245,6 +248,7 @@ def collate(self, batch): A Tuple[torch.Tensor, torch.Tensor, list] containing information on the collated tensors. """ + profiles, labels, studies = tuple( map(list, zip(*batch)) ) # tuple([list(t) for t in zip(*batch)]) @@ -262,6 +266,7 @@ def train_dataloader(self) -> DataLoader: DataLoader A DataLoader object containing the training dataset. """ + return DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -280,6 +285,7 @@ def val_dataloader(self) -> DataLoader: DataLoader A DataLoader object containing the validation dataset. """ + if self.val_dataset is None: return None return DataLoader( @@ -300,6 +306,7 @@ def test_dataloader(self) -> DataLoader: DataLoader A DataLoader object containing the test dataset. """ + if self.test_dataset is None: return None return DataLoader( diff --git a/src/scimilarity/interpreter.py b/src/scimilarity/interpreter.py index 23ce381..533ca5d 100644 --- a/src/scimilarity/interpreter.py +++ b/src/scimilarity/interpreter.py @@ -1,35 +1,26 @@ +from torch import nn from typing import Optional, Union -import matplotlib as mpl -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import seaborn as sns -import torch -from captum.attr import IntegratedGradients -from scipy.sparse import csr_matrix -mpl.rcParams["pdf.fonttype"] = 42 - - -class SimpleDist(torch.nn.Module): +class SimpleDist(nn.Module): """Calculates the distance between representations""" - def __init__(self, encoder: torch.nn.Module): + def __init__(self, encoder: "torch.nn.Module"): """Constructor. - Parameters - ---------- - encoder: torch.nn.Module - The encoder pytorch object. - """ + Parameters + ---------- + encoder: torch.nn.Module + The encoder pytorch object. + """ + super().__init__() self.encoder = encoder def forward( self, - anchors: torch.Tensor, - negatives: torch.Tensor, + anchors: "torch.Tensor", + negatives: "torch.Tensor", ): """Forward. @@ -45,6 +36,7 @@ def forward( float Sum of squares distance for the encoded tensors. 
""" + f_anc = self.encoder(anchors) f_neg = self.encoder(negatives) return ((f_neg - f_anc) ** 2).sum(dim=1) @@ -55,7 +47,7 @@ class Interpreter: def __init__( self, - encoder: torch.nn.Module, + encoder: "torch.nn.Module", gene_order: list, ): """Constructor. @@ -71,15 +63,18 @@ def __init__( -------- >>> interpreter = Interpreter(CellEmbedding("/opt/data/model").model) """ + + from captum.attr import IntegratedGradients + self.encoder = encoder self.dist_ig = IntegratedGradients(SimpleDist(self.encoder)) self.gene_order = gene_order def get_attributions( self, - anchors: Union[torch.Tensor, np.ndarray, csr_matrix], - negatives: Union[torch.Tensor, np.ndarray, csr_matrix], - ) -> np.ndarray: + anchors: Union["torch.Tensor", "numpy.ndarray", "scipy.sparse.csr_matrix"], + negatives: Union["torch.Tensor", "numpy.ndarray", "scipy.sparse.csr_matrix"], + ) -> "numpy.ndarray": """Returns attributions, which can later be aggregated. High attributions for genes that are expressed more highly in the anchor and that affect the distance between anchors and negatives strongly. @@ -101,6 +96,10 @@ def get_attributions( >>> attr = interpreter.get_attributions(anchors, negatives) """ + import numpy as np + from scipy.sparse import csr_matrix + import torch + assert anchors.shape == negatives.shape if isinstance(anchors, np.ndarray): @@ -135,7 +134,7 @@ def get_attributions( return attr.detach().cpu().numpy() return attr.detach().numpy() - def get_ranked_genes(self, attrs: np.ndarray) -> pd.DataFrame: + def get_ranked_genes(self, attrs: "numpy.ndarray") -> "pandas.DataFrame": """Get the ranked gene list based on highest attributions. Parameters @@ -153,6 +152,9 @@ def get_ranked_genes(self, attrs: np.ndarray) -> pd.DataFrame: >>> attrs_df = interpreter.get_ranked_genes(attrs) """ + import numpy as np + import pandas as pd + mean_attrs = attrs.mean(axis=0) idx = mean_attrs.argsort()[::-1] df = { @@ -166,7 +168,7 @@ def get_ranked_genes(self, attrs: np.ndarray) -> pd.DataFrame: def plot_ranked_genes( self, - attrs_df: pd.DataFrame, + attrs_df: "pandas.DataFrame", n_plot: int = 15, filename: Optional[str] = None, ): @@ -186,13 +188,29 @@ def plot_ranked_genes( >>> interpreter.plot_ranked_genes(attrs_df) """ + import matplotlib.pyplot as plt + import matplotlib as mpl + import numpy as np + import seaborn as sns + + mpl.rcParams["pdf.fonttype"] = 42 + df = attrs_df.head(n_plot) ci = 1.96 * df["attribution_std"] / np.sqrt(df["cells"]) - fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 3), dpi=200) - ax = sns.barplot(data=df, x="gene", y="attribution", yerr=ci, ax=ax) + fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 2), dpi=200) + sns.barplot(ax=ax, data=df, x="gene", y="attribution", hue="gene", dodge=False) ax.set_yticks([]) plt.tick_params(axis="x", which="major", labelsize=8, labelrotation=90) + ax.errorbar( + df["gene"].values, + df["attribution"].values, + yerr=ci, + ecolor="black", + fmt="none", + ) + ax.get_legend().remove() + if filename: # save the figure fig.savefig(filename, bbox_inches="tight") diff --git a/src/scimilarity/nn_models.py b/src/scimilarity/nn_models.py index 1314ae5..e96a982 100644 --- a/src/scimilarity/nn_models.py +++ b/src/scimilarity/nn_models.py @@ -3,15 +3,15 @@ These are all you need for inference. 
""" -from typing import List - import torch -import torch.nn.functional as F from torch import nn +import torch.nn.functional as F +from typing import List class Encoder(nn.Module): """A class that encapsulates the encoder.""" + def __init__( self, n_genes: int, @@ -40,6 +40,7 @@ def __init__( residual: bool, default: False Use residual connections. """ + super().__init__() self.latent_dim = latent_dim self.network = nn.ModuleList() @@ -68,7 +69,20 @@ def __init__( # output layer self.network.append(nn.Linear(hidden_dim[-1], latent_dim)) - def forward(self, x) -> F.Tensor: + def forward(self, x) -> torch.Tensor: + """Forward. + + Parameters + ---------- + x: torch.Tensor + Input tensor corresponding to input layer. + + Returns + ------- + torch.Tensor + Output tensor corresponding to output layer. + """ + for i, layer in enumerate(self.network): if self.residual and (0 < i < len(self.network) - 1): x = layer(x) + x @@ -77,13 +91,14 @@ def forward(self, x) -> F.Tensor: return F.normalize(x, p=2, dim=1) def save_state(self, filename: str): - """Save state dictionary. + """Save model state. Parameters ---------- filename: str - Filename to save the state dictionary. + Filename to save the model state. """ + torch.save({"state_dict": self.state_dict()}, filename) def load_state(self, filename: str, use_gpu: bool = False): @@ -93,9 +108,10 @@ def load_state(self, filename: str, use_gpu: bool = False): ---------- filename: str Filename containing the model state. - use_gpu: bool + use_gpu: bool, default: False Boolean indicating whether or not to use GPUs. """ + if not use_gpu: ckpt = torch.load(filename, map_location=torch.device("cpu")) else: @@ -131,6 +147,7 @@ def __init__( residual: bool, default: False Use residual connections. """ + super().__init__() self.latent_dim = latent_dim self.network = nn.ModuleList() @@ -158,7 +175,19 @@ def __init__( # reconstruction layer self.network.append(nn.Linear(hidden_dim[-1], n_genes)) - def forward(self, x): + def forward(self, x) -> torch.Tensor: + """Forward. + + Parameters + ---------- + x: torch.Tensor + Input tensor corresponding to input layer. + + Returns + ------- + torch.Tensor + Output tensor corresponding to output layer. + """ for i, layer in enumerate(self.network): if self.residual and (0 < i < len(self.network) - 1): x = layer(x) + x @@ -167,13 +196,14 @@ def forward(self, x): return x def save_state(self, filename: str): - """Save state dictionary. + """Save model state. Parameters ---------- filename: str - Filename to save the state dictionary. + Filename to save the model state. """ + torch.save({"state_dict": self.state_dict()}, filename) def load_state(self, filename: str, use_gpu: bool = False): @@ -183,9 +213,10 @@ def load_state(self, filename: str, use_gpu: bool = False): ---------- filename: str Filename containing the model state. - use_gpu: bool - Boolean indicating whether to use GPUs. + use_gpu: bool, default: False + Boolean indicating whether or not to use GPUs. 
""" + if not use_gpu: ckpt = torch.load(filename, map_location=torch.device("cpu")) else: diff --git a/src/scimilarity/ontologies.py b/src/scimilarity/ontologies.py index 5b9501d..e04db45 100644 --- a/src/scimilarity/ontologies.py +++ b/src/scimilarity/ontologies.py @@ -1,32 +1,52 @@ -import itertools -from typing import Tuple - import networkx as nx -import numpy as np import obonet -import pandas as pd -from scipy.spatial.distance import cdist +from typing import Union, Tuple, List + +def subset_nodes_to_set(nodes, restricted_set: Union[list, set]) -> nx.DiGraph: + """Restrict nodes to a given set. + + Parameters + ---------- + nodes: networkx.DiGraph + Node graph. + restricted_set: list, set + Restricted node list. + + Returns + ------- + networkx.DiGraph + Node graph of restricted set. + + Examples + -------- + >>> subset_nodes_to_set(nodes, node_list) + """ -def subset_nodes_to_set(nodes, restricted_set): return {node for node in nodes if node in restricted_set} def import_cell_ontology( - url="http://purl.obolibrary.org/obo/cl/cl-basic.obo", + url="/gstore/data/omni/scdb/cell-ontology-2022-09-15/cl-basic.obo", + # url="http://purl.obolibrary.org/obo/cl/cl-basic.obo", ) -> nx.DiGraph: - """Read the taxrank ontology. + """Import taxrank cell ontology. Parameters ---------- - url: str - URL for the cell ontology. + url: str, default: "/gstore/data/omni/scdb/cell-ontology-2022-09-15/cl-basic.obo" + The url of the ontology obo file. Returns ------- networkx.DiGraph - DiGraph containing the cell ontology. + Node graph of ontology. + + Examples + -------- + >>> onto = import_cell_ontology() """ + graph = obonet.read_obo(url).reverse() # flip for intuitiveness return nx.DiGraph(graph) # return as graph @@ -34,18 +54,23 @@ def import_cell_ontology( def import_uberon_ontology( url="http://purl.obolibrary.org/obo/uberon/basic.obo", ) -> nx.DiGraph: - """Read the uberon ontology. + """Import uberon tissue ontology. Parameters ---------- - url: str - URL for the uberon ontology. + url: str, default: "http://purl.obolibrary.org/obo/uberon/basic.obo" + The url of the ontology obo file. Returns ------- networkx.DiGraph - DiGraph containing the uberon ontology. + Node graph of ontology. + + Examples + -------- + >>> onto = import_uberon_ontology() """ + graph = obonet.read_obo(url).reverse() # flip for intuitiveness return nx.DiGraph(graph) # return as graph @@ -53,18 +78,23 @@ def import_uberon_ontology( def import_doid_ontology( url="http://purl.obolibrary.org/obo/doid.obo", ) -> nx.DiGraph: - """Read the doid ontology. + """Import doid disease ontology. Parameters ---------- - url: str - URL for the doid ontology. + url: str, default: "http://purl.obolibrary.org/obo/doid.obo" + The url of the ontology obo file. Returns ------- networkx.DiGraph - DiGraph containing the doid ontology. + Node graph of ontology. + + Examples + -------- + >>> onto = import_doid_ontology() """ + graph = obonet.read_obo(url).reverse() # flip for intuitiveness return nx.DiGraph(graph) # return as graph @@ -72,18 +102,23 @@ def import_doid_ontology( def import_mondo_ontology( url="http://purl.obolibrary.org/obo/mondo.obo", ) -> nx.DiGraph: - """Read the mondo ontology. + """Import mondo disease ontology. Parameters ---------- - url: str - URL for the mondo ontology. + url: str, default: "http://purl.obolibrary.org/obo/mondo.obo" + The url of the ontology obo file. Returns ------- networkx.DiGraph - DiGraph containing the mondo ontology. + Node graph of ontology. 
+ + Examples + -------- + >>> onto = import_mondo_ontology() """ + graph = obonet.read_obo(url).reverse() # flip for intuitiveness return nx.DiGraph(graph) # return as graph @@ -94,39 +129,132 @@ def get_id_mapper(graph) -> dict: Parameters ---------- graph: networkx.DiGraph - onotology graph. + Node graph. Returns ------- dict - Dictionary containing the term ID to name mapper. + The id to name mapping dictionary. + + Examples + -------- + >>> id2name = get_id_mapper(onto) """ + return {id_: data.get("name") for id_, data in graph.nodes(data=True)} -def get_children(graph, node, node_list=None): +def get_children(graph, node, node_list=None) -> nx.DiGraph: + """Get children nodes of a given node. + + Parameters + ---------- + graph: networkx.DiGraph + Node graph. + node: str + ID of given node. + node_list: list, set, optional, default: None + A restricted node list for filtering. + + Returns + ------- + networkx.DiGraph + Node graph of children. + + Examples + -------- + >>> children = get_children(onto, id) + """ + children = {item[1] for item in graph.out_edges(node)} if node_list is None: return children return subset_nodes_to_set(children, node_list) -def get_parents(graph, node, node_list=None): +def get_parents(graph, node, node_list=None) -> nx.DiGraph: + """Get parent nodes of a given node. + + Parameters + ---------- + graph: networkx.DiGraph + Node graph. + node: str + ID of given node. + node_list: list, set, optional, default: None + A restricted node list for filtering. + + Returns + ------- + networkx.DiGraph + Node graph of parents. + + Examples + -------- + >>> parents = get_parents(onto, id) + """ + parents = {item[0] for item in graph.in_edges(node)} if node_list is None: return parents return subset_nodes_to_set(parents, node_list) -def get_siblings(graph, node): +def get_siblings(graph, node, node_list=None) -> nx.DiGraph: + """Get sibling nodes of a given node. + + Parameters + ---------- + graph: networkx.DiGraph + Node graph. + node: str + ID of given node. + node_list: list, set, optional, default: None + A restricted node list for filtering. + + Returns + ------- + networkx.DiGraph + Node graph of siblings. + + Examples + -------- + >>> siblings = get_siblings(onto, id) + """ + parents = get_parents(graph, node) siblings = set.union( *[set(get_children(graph, parent)) for parent in parents] ) - set([node]) - return siblings + if node_list is None: + return siblings + return subset_nodes_to_set(siblings, node_list) + + +def get_all_ancestors(graph, node, node_list=None, inclusive=False) -> nx.DiGraph: + """Get all ancestor nodes of a given node. + Parameters + ---------- + graph: networkx.DiGraph + Node graph. + node: str + ID of given node. + node_list: list, set, optional, default: None + A restricted node list for filtering. + inclusive: bool, default: False + Whether to include the given node in the results. + + Returns + ------- + networkx.DiGraph + Node graph of ancestors. + + Examples + -------- + >>> ancestors = get_all_ancestors(onto, id) + """ -def get_all_ancestors(graph, node, node_list=None, inclusive=False): ancestors = nx.ancestors(graph, node) if inclusive: ancestors = ancestors | {node} @@ -136,7 +264,30 @@ def get_all_ancestors(graph, node, node_list=None, inclusive=False): return subset_nodes_to_set(ancestors, node_list) -def get_all_descendants(graph, nodes, node_list=None, inclusive=False): +def get_all_descendants(graph, nodes, node_list=None, inclusive=False) -> nx.DiGraph: + """Get all descendant nodes of given node(s). 
+ + Parameters + ---------- + graph: networkx.DiGraph + Node graph. + nodes: str, list + ID of given node or a list of node IDs. + node_list: list, set, optional, default: None + A restricted node list for filtering. + inclusive: bool, default: False + Whether to include the given node in the results. + + Returns + ------- + networkx.DiGraph + Node graph of descendants. + + Examples + -------- + >>> descendants = get_all_descendants(onto, id) + """ + if isinstance(nodes, str): # one term id descendants = nx.descendants(graph, nodes) else: # list of term ids @@ -150,45 +301,150 @@ def get_all_descendants(graph, nodes, node_list=None, inclusive=False): return subset_nodes_to_set(descendants, node_list) -def get_lowest_common_ancestor(graph, node1, node2): +def get_lowest_common_ancestor(graph, node1, node2) -> nx.DiGraph: + """Get the lowest common ancestor of two nodes. + + Parameters + ---------- + graph: networkx.DiGraph + Node graph. + node1: str + ID of node1. + node2: str + ID of node2. + + Returns + ------- + networkx.DiGraph + Node graph of descendants. + + Examples + -------- + >>> common_ancestor = get_lowest_common_ancestor(onto, id1, id2) + """ + return nx.algorithms.lowest_common_ancestors.lowest_common_ancestor( graph, node1, node2 ) -def ontology_similarity(graph, term1, term2, blacklisted_terms=None): - common_ancestors = get_all_ancestors(graph, term1).intersection( - get_all_ancestors(graph, term2) +def ontology_similarity(graph, node1, node2, restricted_set=None) -> int: + """Get the ontology similarity of two terms based on the number of common ancestors. + + Parameters + ---------- + graph: networkx.DiGraph + Node graph. + node1: str + ID of node1. + node2: str + ID of node2. + restricted_set: set + Set of restricted nodes to remove from their common ancestors. + + Returns + ------- + int + Number of common ancestors. + + Examples + -------- + >>> onto_sim = ontology_similarity(onto, id1, id2) + """ + + common_ancestors = get_all_ancestors(graph, node1).intersection( + get_all_ancestors(graph, node2) ) - if blacklisted_terms is not None: - common_ancestors -= blacklisted_terms + if restricted_set is not None: + common_ancestors -= restricted_set return len(common_ancestors) -def all_pair_similarities(graph, used_terms, blacklisted_terms=None): - node_pairs = itertools.combinations(used_terms, 2) - similarity_df = pd.DataFrame(0, index=used_terms, columns=used_terms) - for (term1, term2) in node_pairs: +def all_pair_similarities(graph, nodes, restricted_set=None) -> "pandas.DataFrame": + """Get the ontology similarity of all pairs in a node list. + + Parameters + ---------- + graph: networkx.DiGraph + Node graph. + nodes: list, set + List of nodes. + restricted_set: set + Set of restricted nodes to remove from their common ancestors. + + Returns + ------- + pandas.DataFrame + A pandas dataframe showing similarity for all node pairs. 
+ + Examples + -------- + >>> onto_sim = all_pair_similarities(onto, id1, id2) + """ + + import itertools + import pandas as pd + + node_pairs = itertools.combinations(nodes, 2) + similarity_df = pd.DataFrame(0, index=nodes, columns=nodes) + for node1, node2 in node_pairs: s = ontology_similarity( - graph, term1, term2, blacklisted_terms=blacklisted_terms + graph, node1, node2, restricted_set=restricted_set ) # too slow, cause recomputes each ancestor - similarity_df.at[term1, term2] = s + similarity_df.at[node1, node2] = s return similarity_df + similarity_df.T def ontology_silhouette_width( - embeddings: np.ndarray, - labels: list, + embeddings: "numpy.ndarray", + labels: List[str], onto: nx.DiGraph, name2id: dict, metric: str = "cosine", -) -> Tuple[float, pd.DataFrame]: +) -> Tuple[float, "pandas.DataFrame"]: + """Get the average silhouette width of celltypes, being aware of cell ontology such that + ancestors are not considered inter-cluster and descendants are considered intra-cluster. + + Parameters + ---------- + embeddings: numpy.ndarray + Cell embeddings. + labels: List[str] + Celltype names. + onto: + Cell ontology graph object. + name2id: dict + A mapping dictionary of celltype name to id + metric: str, default: "cosine" + The distance metric to use for scipy.spatial.distance.cdist(). + + Returns + ------- + asw: float + The average silhouette width. + asw_df: pandas.DataFrame + A dataframe containing silhouette width as well as + inter and intra cluster distances for all cell types. + + Examples + -------- + >>> asw, asw_df = ontology_silhouette_width( + embeddings, labels, onto, name2id, metric="cosine" + ) + """ + + import numpy as np + import pandas as pd + from scipy.spatial.distance import cdist + data = {"label": [], "intra": [], "inter": [], "sw": []} for i, name1 in enumerate(labels): term_id1 = name2id[name1] ancestors = get_all_ancestors(onto, term_id1) descendants = get_all_descendants(onto, term_id1) - distances = cdist(embeddings[i].reshape(1, -1), embeddings, metric=metric).flatten() + distances = cdist( + embeddings[i].reshape(1, -1), embeddings, metric=metric + ).flatten() a_i = [] b_i = {} diff --git a/src/scimilarity/training_models.py b/src/scimilarity/training_models.py index 9ba1997..d9a4a8d 100644 --- a/src/scimilarity/training_models.py +++ b/src/scimilarity/training_models.py @@ -1,17 +1,16 @@ -import json -import os from datetime import datetime -from typing import List, Optional - import hnswlib +import json +import os import pandas as pd import pytorch_lightning as pl import torch -import torch.nn.functional as F from torch import nn +import torch.nn.functional as F +from typing import Optional, List from scimilarity.triplet_selector import TripletSelector -from scimilarity.nn_models import Decoder, Encoder +from scimilarity.nn_models import Encoder, Decoder class TripletLoss(torch.nn.TripletMarginLoss): @@ -215,8 +214,12 @@ def __init__( self.scheduler = None + self.val_step_outputs = [] + self.test_step_outputs = [] + def configure_optimizers(self): """Configure optimizers.""" + optimizer = torch.optim.AdamW(self.parameters(), self.lr, weight_decay=self.l2) self.scheduler = { "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR( @@ -231,34 +234,95 @@ def configure_optimizers(self): } # pytorch-lightning required format def forward(self, x): + """Forward. + + Parameters + ---------- + x: torch.Tensor + Input tensor corresponding to input layer. + + Returns + ------- + z: torch.Tensor + Output tensor corresponding to the last encoder layer. 
+ x_hat: torch.Tensor + Output tensor corresponding to the last decoder layer. + """ + z = self.encoder(x) x_hat = self.decoder(z) return z, x_hat def get_losses(self, batch, use_studies: bool = True): + """Calculate the triplet and reconstruction loss. + + Parameters + ---------- + batch: + A batch as defined by a pytorch DataLoader. + use_studies: bool, default: True + Whether to use studies metadata in mining triplets and calculating triplet loss + + Returns + ------- + triplet_loss: torch.Tensor + Triplet loss. + reconstruction_loss: torch.Tensor + reconstruction loss + num_hard_triplets: torch.Tensor + Number of hard triplets. + num_viable_triplets: torch.Tensor + Number of viable triplets. + """ + cells, labels, studies = batch if not use_studies: studies = None embedding, reconstruction = self(cells) - triplet_losses, num_hard_triplets, num_viable_triplets = self.triplet_loss_fn( + triplet_loss, num_hard_triplets, num_viable_triplets = self.triplet_loss_fn( embedding, labels, self.trainer.datamodule.int2label, studies ) reconstruction_loss = self.mse_loss_fn(cells, reconstruction) return ( - triplet_losses, + triplet_loss, reconstruction_loss, num_hard_triplets, num_viable_triplets, ) - def mixed_loss(self, triplet_loss, reconstruction_loss): + def get_mixed_loss(self, triplet_loss, reconstruction_loss): + """Calculate the mixed loss. + + Parameters + ---------- + triplet_loss: torch.Tensor + Triplet loss. + reconstruction_loss: torch.Tensor + reconstruction loss + + Returns + ------- + torch.Tensor + Mixed loss. + """ + if self.alpha == 0: return reconstruction_loss if self.alpha == 1: return triplet_loss return (self.alpha * triplet_loss) + ((1.0 - self.alpha) * reconstruction_loss) - def training_step(self, batch, batch_idx): # pytorch-lightning required parameters + def training_step(self, batch, batch_idx): + """Pytorch-lightning training step. + + Parameters + ---------- + batch: + A batch as defined by a pytorch DataLoader. + batch_idx: + A batch index as defined by a pytorch-lightning. + """ + ( triplet_losses, reconstruction_loss, @@ -270,7 +334,7 @@ def training_step(self, batch, batch_idx): # pytorch-lightning required paramet num_nonzero_loss = (triplet_losses > 0).sum(dtype=torch.float).detach() hard_triplets = num_hard_triplets / num_viable_triplets - loss = self.mixed_loss(triplet_loss, reconstruction_loss) + loss = self.get_mixed_loss(triplet_loss, reconstruction_loss) current_lr = self.scheduler["scheduler"].get_last_lr()[0] @@ -324,29 +388,76 @@ def training_step(self, batch, batch_idx): # pytorch-lightning required paramet "train_num_viable_triplets": num_viable_triplets, } - def validation_step( - self, batch, batch_idx, dataloader_idx: Optional[int] = None - ): # pytorch-lightning required parameters + def on_validation_epoch_start(self): + """Pytorch-lightning validation epoch start.""" + super().on_validation_epoch_start() + self.val_step_outputs = [] + + def validation_step(self, batch, batch_idx): + """Pytorch-lightning validation step. + + Parameters + ---------- + batch: + A batch as defined by a pytorch DataLoader. + batch_idx: + A batch index as defined by a pytorch-lightning. 
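+
+        Returns
+        -------
+        dict
+            Step evaluation metrics from _eval_step, or an empty dict if there is no validation dataset.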
+ """ + if self.trainer.datamodule.val_dataset is None: return {} return self._eval_step(batch, prefix="val") - def validation_epoch_end(self, step_outputs: list): + def on_validation_epoch_end(self): + """Pytorch-lightning validation epoch end evaluation.""" + if self.trainer.datamodule.val_dataset is None: return {} - return self._eval_epoch(step_outputs, prefix="val") + return self._eval_epoch(prefix="val") + + def on_test_epoch_start(self): + """Pytorch-lightning test epoch start.""" + super().on_test_epoch_start() + self.test_step_outputs = [] + + def test_step(self, batch, batch_idx): + """Pytorch-lightning test step. + + Parameters + ---------- + batch: + A batch as defined by a pytorch DataLoader. + batch_idx: + A batch index as defined by a pytorch-lightning. + """ - def test_step(self, batch, batch_idx): # pytorch-lightning required parameters if self.trainer.datamodule.test_dataset is None: return {} return self._eval_step(batch, prefix="test") - def test_epoch_end(self, step_outputs: list): + def on_test_epoch_end(self): + """Pytorch-lightning test epoch end evaluation.""" + if self.trainer.datamodule.test_dataset is None: return {} - return self._eval_epoch(step_outputs, prefix="test") + return self._eval_epoch(prefix="test") def _eval_step(self, batch, prefix: str): + """Evaluation of validation or test step. + + Parameters + ---------- + batch: + A batch as defined by a pytorch DataLoader. + prefix: str + A string prefix to label validation versus test evaluation. + + Returns + ------- + dict + A dictionary containing step evaluation metrics. + """ + ( triplet_losses, reconstruction_loss, @@ -358,7 +469,7 @@ def _eval_step(self, batch, prefix: str): num_nonzero_loss = (triplet_losses > 0).sum() hard_triplets = num_hard_triplets / num_viable_triplets - loss = self.mixed_loss(triplet_loss, reconstruction_loss) + loss = self.get_mixed_loss(triplet_loss, reconstruction_loss) losses = { f"{prefix}_loss": loss, @@ -369,9 +480,32 @@ def _eval_step(self, batch, prefix: str): f"{prefix}_num_hard_triplets": num_hard_triplets, f"{prefix}_num_viable_triplets": num_viable_triplets, } + + if prefix == "val": + self.val_step_outputs.append(losses) + elif prefix == "test": + self.test_step_outputs.append(losses) return losses - def _eval_epoch(self, step_outputs: list, prefix: str): + def _eval_epoch(self, prefix: str): + """Evaluation of validation or test epoch. + + Parameters + ---------- + prefix: str + A string prefix to label validation versus test evaluation. + + Returns + ------- + dict + A dictionary containing epoch evaluation metrics. 
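+            Each metric is the mean over the step outputs collected during the epoch.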
+ """ + + if prefix == "val": + step_outputs = self.val_step_outputs + elif prefix == "test": + step_outputs = self.test_step_outputs + loss = torch.Tensor([step[f"{prefix}_loss"] for step in step_outputs]).mean() triplet_loss = torch.Tensor( [step[f"{prefix}_triplet_loss"] for step in step_outputs] @@ -411,42 +545,9 @@ def _eval_epoch(self, step_outputs: list, prefix: str): } return losses - def get_dataset_embedding(self, dataset): - embedding_parts = [] - labels_parts = [] - study_parts = [] - - self.encoder.eval() - with torch.inference_mode(): - buffer_size = 10000 - for i in range(0, len(dataset), buffer_size): - profiles, labels_part, studies = dataset[i : i + buffer_size] - profiles = torch.Tensor(profiles).cuda() - - embedding_parts.append(self.encoder(profiles).detach().cpu()) - labels_parts.extend(labels_part) - study_parts.extend(studies.values) - - return torch.vstack(embedding_parts), pd.Categorical(labels_parts), study_parts - - def build_knn_classifier( - self, embeddings: torch.Tensor, ef_construction=1000, M=80 - ): - n_cells, latent_dim = embeddings.shape - nn_reference = hnswlib.Index( - space="cosine", dim=latent_dim - ) # possible options are l2, cosine, or ip - nn_reference.init_index( - max_elements=n_cells, ef_construction=ef_construction, M=M - ) - nn_reference.set_ef(ef_construction) - nn_reference.add_items(embeddings, range(len(embeddings))) - return nn_reference - def save_all( self, model_path: str, - save_knn: bool = False, ef_construction: int = 1000, M: int = 80, ): @@ -506,15 +607,6 @@ def save_all( with open(os.path.join(model_path, "metadata.json"), "w") as f: f.write(json.dumps(meta_data)) - if save_knn: # build and save KNN model - embeddings, labels, _ = self.get_dataset_embedding( - self.trainer.datamodule.train_dataset - ) - knn = self.build_knn_classifier( - embeddings, ef_construction=ef_construction, M=M - ) - knn.save_index(os.path.join(model_path, "kNN")) - def load_state( self, encoder_filename: str, @@ -522,6 +614,20 @@ def load_state( use_gpu: bool = False, freeze: bool = False, ): + """Load model state. + + Parameters + ---------- + encoder_filename: str + Filename containing the encoder model state. + decoder_filename: str + Filename containing the decoder model state. + use_gpu: bool, default: False + Boolean indicating whether or not to use GPUs. + freeze: bool, default: False + Freeze all but bottleneck layer, used if pretraining the encoder. + """ + self.encoder.load_state(encoder_filename, use_gpu) self.decoder.load_state(decoder_filename, use_gpu) diff --git a/src/scimilarity/triplet_selector.py b/src/scimilarity/triplet_selector.py index a8a5094..d876c04 100644 --- a/src/scimilarity/triplet_selector.py +++ b/src/scimilarity/triplet_selector.py @@ -1,23 +1,20 @@ from itertools import combinations -from typing import Optional, Union - import numpy as np import random import torch +from typing import Union, Optional from scimilarity.ontologies import ( import_cell_ontology, + get_id_mapper, get_all_ancestors, get_all_descendants, - get_id_mapper, get_parents, ) class TripletSelector: - """ - For each anchor-positive pair, mine negative samples to create a triplet. - """ + """For each anchor-positive pair, mine negative samples to create a triplet.""" def __init__( self, @@ -26,6 +23,28 @@ def __init__( perturb_labels: bool = False, perturb_labels_fraction: float = 0.5, ): + """Constructor. + + Parameters + ---------- + margin: float + Triplet loss margin. 
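+            A negative is counted as hard when d(anchor, positive) - d(anchor, negative) + margin is greater than zero.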
+ negative_selection: str + Method for negative selection: {"semihard", "hardest", "random"} + perturb_labels: bool, default: False + Whether to perturb the ontology labels by coarse graining one level up. + perturb_labels_fraction: float, default: 0.5 + The fraction of labels to perturb + + Examples + -------- + >>> triplet_selector = TripletSelector(margin=0.05, + negative_selection="semihard", + perturb_labels=True, + perturb_labels_fraction=0.5, + ) + """ + self.margin = margin self.negative_selection = negative_selection @@ -43,6 +62,30 @@ def get_triplets_idx( int2label: dict, studies: Optional[Union[np.ndarray, torch.Tensor, list]] = None, ): + """Get triplets as anchor, positive, and negative cell indices. + + Parameters + ---------- + embeddings: numpy.ndarray, torch.Tensor + Cell embeddings. + labels: numpy.ndarray, torch.Tensor + Cell labels in integer form. + int2label: dict + Dictionary to map labels in integer form to string + studies: numpy.ndarray, torch.Tensor, optional, default: None + Studies metadata for each cell. + + Returns + ------- + triplets: Tuple[List, List, List] + A tuple of lists containing anchor, positive, and negative cell indices. + num_hard_triplets: int + Number of hard triplets. + num_viable_triplets: int + Number of viable triplets. + ) + """ + if isinstance(embeddings, torch.Tensor): distance_matrix = self.pdist(embeddings.detach().cpu().numpy()) else: @@ -91,8 +134,8 @@ def get_triplets_idx( break # label perturbed, skip the rest of the ancestors triplets = [] - total_hard_triplets = 0 - total_viable_triplets = 0 + num_hard_triplets = 0 + num_viable_triplets = 0 for label in labels_set: term_id = self.name2id[int2label[label]] ancestors = get_all_ancestors(self.onto, term_id) @@ -126,8 +169,8 @@ def get_triplets_idx( - distance_matrix[[anchor_positive[0]], negative_indices] + self.margin ) - total_hard_triplets += (loss_values > 0).sum() - total_viable_triplets += loss_values.size + num_hard_triplets += (loss_values > 0).sum() + num_viable_triplets += loss_values.size # select one negative for anchor positive pair based on selection function if self.negative_selection == "semihard": @@ -158,8 +201,8 @@ def get_triplets_idx( positive_idx, negative_idx, ), - total_hard_triplets, - total_viable_triplets, + num_hard_triplets, + num_viable_triplets, ) def get_triplets( @@ -169,10 +212,33 @@ def get_triplets( int2label: dict, studies: Optional[Union[np.ndarray, torch.Tensor, list]] = None, ): + """Get triplets as anchor, positive, and negative cell embeddings. + + Parameters + ---------- + embeddings: numpy.ndarray, torch.Tensor + Cell embeddings. + labels: numpy.ndarray, torch.Tensor + Cell labels in integer form. + int2label: dict + Dictionary to map labels in integer form to string + studies: numpy.ndarray, torch.Tensor, optional, default: None + Studies metadata for each cell. + + Returns + ------- + triplets: Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray] + A tuple of numpy arrays containing anchor, positive, and negative cell embeddings. + num_hard_triplets: int + Number of hard triplets. + num_viable_triplets: int + Number of viable triplets. 
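+
+        Examples
+        --------
+        >>> triplets, num_hard, num_viable = triplet_selector.get_triplets(
+                embeddings, labels, int2label
+            )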
+ """ + ( triplets_idx, - total_hard_triplets, - total_viable_triplets, + num_hard_triplets, + num_viable_triplets, ) = self.get_triplets_idx(embeddings, labels, int2label, studies) anchor_idx, positive_idx, negative_idx = triplets_idx return ( @@ -181,11 +247,24 @@ def get_triplets( embeddings[positive_idx], embeddings[negative_idx], ), - total_hard_triplets, - total_viable_triplets, + num_hard_triplets, + num_viable_triplets, ) - def pdist(self, vectors): + def pdist(self, vectors: np.ndarray): + """Get pair-wise distance between all cell embeddings. + + Parameters + ---------- + vectors: numpy.ndarray + Cell embeddings. + + Returns + ------- + numpy.ndarray + Distance matrix of cell embeddings. + """ + vectors_squared_sum = (vectors**2).sum(axis=1) distance_matrix = ( -2 * np.matmul(vectors, np.matrix.transpose(vectors)) @@ -195,14 +274,53 @@ def pdist(self, vectors): return distance_matrix def hardest_negative(self, loss_values): + """Get hardest negative. + + Parameters + ---------- + loss_values: numpy.ndarray + Triplet loss of all negatives for given anchor positive pair. + + Returns + ------- + int + Index of selection. + """ + hard_negative = np.argmax(loss_values) return hard_negative if loss_values[hard_negative] > 0 else None def random_negative(self, loss_values): + """Get random negative. + + Parameters + ---------- + loss_values: numpy.ndarray + Triplet loss of all negatives for given anchor positive pair. + + Returns + ------- + int + Index of selection. + """ + hard_negatives = np.where(loss_values > 0)[0] return np.random.choice(hard_negatives) if len(hard_negatives) > 0 else None def semihard_negative(self, loss_values): + """Get a random semihard negative. + + Parameters + ---------- + loss_values: numpy.ndarray + Triplet loss of all negatives for given anchor positive pair. + + Returns + ------- + int + Index of selection. + """ + semihard_negatives = np.where( np.logical_and(loss_values < self.margin, loss_values > 0) )[0] diff --git a/src/scimilarity/utils.py b/src/scimilarity/utils.py index c75dda2..a3a6172 100644 --- a/src/scimilarity/utils.py +++ b/src/scimilarity/utils.py @@ -1,79 +1,208 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Union, Tuple, List -import anndata -import numpy as np -import pandas as pd -import pegasusio as pgio -import scanpy as sc -from numba import njit -from scipy.sparse import csr_matrix +def get_pseudobulk_values( + data: "anndata.AnnData", +) -> ["numpy.ndarray", "numpy.ndarray", "numpy.ndarray"]: + """Get pseudobulk values from AnnData as numpy arrays. -def check_dataset( - data: Union[anndata.AnnData, pgio.UnimodalData, pgio.MultimodalData], - target_gene_order: np.ndarray, - gene_overlap_threshold: int = 10000, + Parameters + ---------- + data: anndata.AnnData + Annotated data matrix with rows for cells and columns for genes. + + Returns + ------- + counts: numpy.ndarray + 1 x n_genes numpy array of sum of values in layer "counts". + detection: numpy.ndarray + 1 x n_genes numpy array of mean detection values based on layer "counts". 
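+        Detection values are the fraction of cells in which each gene has a nonzero count.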
+ + Examples + -------- + >>> counts, detection = get_pseudobulk_values(data) + """ + + import numpy as np + + if "counts" not in data.layers: + raise ValueError(f"Raw counts matrix not found in layers['counts'].") + + counts = np.array(data.layers["counts"].sum(axis=0)).reshape(-1, data.shape[1]) + + detection = data.layers["counts"].copy() + detection[detection > 0] = 1.0 + detection[detection < 0] = 0.0 + detection = np.array(detection.mean(axis=0)).reshape(-1, data.shape[1]) + + return counts, detection + + +def pseudobulk_anndata( + data: "anndata.AnnData", + pseudobulk_label: str, + groupby_labels: Optional[list] = None, + qc_filters: Optional[dict] = None, + min_num_cells: int = 1, + only_orig_genes: bool = False, ): - """Check dataset to see if it able to be processed. + """Pseudobulk an AnnData and return a new AnnData. Parameters ---------- - data: pegasusio.MultimodalData, pegasusio.UnimodalData, anndata.AnnData + data: anndata.AnnData Annotated data matrix with rows for cells and columns for genes. - target_gene_order: numpy.ndarray - An array containing the gene space. - gene_overlap_threshold: int, default: 10000 - The minimum number of genes in common between data and target_gene_order to be valid. + pseudobulk_label: str + Column label for basis of pseudobulk, typically the celltype name column. + groupby_labels: list, optional, default: None + Optional list of labels to groupby prior to pseudobulking based on pseudobulk_label. + We will always add pseudobulk_label into groupby_labels if it does not exist. + For example: ["sample", "tissue", "disease", "celltype_name"] + will groupby these columns and perform pseudobulking based on these groups. + qc_filters: dict, optional, default: None + Dictionary containing cell filters to perform prior to pseudobulking: + "mito_percent": max percent of reads in mitochondrial genes + "min_counts": min read count for cell + "min_genes": min number of genes with reads for cell + "max_nn_dist": max nearest neighbor distance to a reference label for predicted labels. + min_num_cells: int, default: 1 + The minimum number of cells in a pseudobulk in order to be considered. + only_orig_genes: bool, default: False + Account for an aligned gene space and mask non original genes to the dataset with NaN as + their pseudobulk. Assumes the original gene list is in data.uns["orig_genes"]. 
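+
+    Returns
+    -------
+    anndata.AnnData
+        A new AnnData of pseudobulk values, with summed counts in layers["counts"] and
+        detection fractions in layers["detection"]. Returns None if no pseudobulk group
+        passes the filters.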
Examples -------- - >>> ca = CellAnnotation(model_path="/opt/data/model") - >>> check_dataset(data, ca.gene_order, gene_overlap_threshold=10000) + >>> pseudobulk_label = "celltype_name" + >>> groupby_labels = ["sample", "tissue_raw", pseudobulk_label] + >>> qc_filters = {"mito_percent": 20.0, "min_counts": 1000, "min_genes": 500, "max_nn_dist": 0.03, "max_nn_dist_col": "min_dist"} + >>> pseudobulk = pseudobulk_anndata(data, pseudobulk_label, groupby_labels, qc_filters=qc_filters, only_orig_genes=True) """ - if type(data) not in [anndata.AnnData, pgio.UnimodalData, pgio.MultimodalData]: - raise ValueError(f"Unknown data type {type(data)}.") + import anndata + from collections import Counter + import numpy as np + import pandas as pd + import scanpy as sc + from scipy.sparse import csr_matrix - # check gene overlap - n_genes = sum(data.var.index.isin(target_gene_order)) - - if n_genes < gene_overlap_threshold: - raise RuntimeError( - f"Dataset incompatible: gene overlap less than {gene_overlap_threshold}" - ) + if "counts" not in data.layers: + raise ValueError(f"Raw counts matrix not found in layers['counts'].") - # check if count matrix exists - counts_exist = False - if isinstance(data, pgio.MultimodalData): - data = data.get_data(data.list_data()[0]) - if isinstance(data, pgio.UnimodalData): - counts_exist = "counts" in data.list_keys() or "raw.X" in data.list_keys() - if isinstance(data, anndata.AnnData): - counts_exist = "counts" in data.layers + if qc_filters is not None: + # determine prefix for mitochondrial genes + mito_prefix = "MT-" + if any(data.var.index.str.startswith("mt-")) is True: + mito_prefix = "mt-" - if not counts_exist: - raise RuntimeError("Dataset incompatible: no counts matrix found") + mito_percent = qc_filters.get("mito_percent", 100.0) + min_counts = qc_filters.get("min_counts", None) + min_genes = qc_filters.get("min_genes", None) + max_nn_dist = qc_filters.get("max_nn_dist", 0.03) + max_nn_dist_col = qc_filters.get("max_nn_dist_col", "nn_dist") + + data = data.copy() + data.var["mt"] = data.var_names.str.startswith(mito_prefix) + sc.pp.calculate_qc_metrics( + data, + qc_vars=["mt"], + percent_top=None, + log1p=False, + inplace=True, + layer="counts", + ) + data = data[data.obs["pct_counts_mt"] <= mito_percent].copy() + if min_counts is not None: + data = data[data.obs["total_counts"] >= min_counts].copy() + if min_genes is not None: + data = data[data.obs["n_genes_by_counts"] >= min_genes].copy() + if max_nn_dist_col in data.obs.columns: + data = data[data.obs[max_nn_dist_col] <= max_nn_dist].copy() + + counts_list = [] + detection_list = [] + obs_list = [] + if groupby_labels is None: + classes = Counter(data.obs[pseudobulk_label]) + for c in classes: + subset = data[data.obs[pseudobulk_label] == c] + counts, detection = get_pseudobulk_values(subset) + counts_list.append(counts) + detection_list.append(detection) + + # construct the adata + meta = pd.DataFrame( + {pseudobulk_label: [c], "cells": [classes[c]]} + ).set_index(pseudobulk_label, drop=False) + meta.index = meta.index.astype(str) + obs_list.append(meta) + else: + if pseudobulk_label not in groupby_labels: + groupby_labels.append(pseudobulk_label) + + # group by labels + df_sample = data.obs.groupby(groupby_labels, observed=True).size() + df_sample = df_sample[df_sample > 0].reset_index(name="cells") + + # use groups to perform pseudobulk + for i, row in df_sample.iterrows(): + num_cells = row["cells"] + row = row.drop("cells") + subset = data[(data.obs[list(row.index)] == row).all(axis=1)] + counts, 
detection = get_pseudobulk_values(subset) + counts_list.append(counts) + detection_list.append(detection) + + # Construct the adata + meta = ( + pd.DataFrame(row) + .transpose() + .astype("category") + .set_index(pseudobulk_label, drop=False) + ) + meta["cells"] = num_cells + meta.index = meta.index.astype(str) + obs_list.append(meta) + + if len(counts_list) == 0: + return None + + counts = np.vstack(counts_list) + detection = np.vstack(detection_list) + adata = anndata.AnnData( + X=csr_matrix(counts.shape), + obs=pd.concat(obs_list), + var=pd.DataFrame(index=data.var.index), + ) + adata.layers["counts"] = counts + adata.layers["detection"] = detection + + if min_num_cells > 1: + adata = adata[adata.obs["cells"] >= min_num_cells].copy() + if only_orig_genes and "uns" in dir(data) and "orig_genes" in data.uns: + orig_genes = set(data.uns["orig_genes"]) + not_orig_genes_idx = [ + i for i, x in enumerate(data.var.index.tolist()) if x not in orig_genes + ] + adata[:, not_orig_genes_idx] = np.nan + adata.layers["counts"][:, not_orig_genes_idx] = np.nan + adata.layers["detection"][:, not_orig_genes_idx] = np.nan + return adata def lognorm_counts( - data: Union[anndata.AnnData, pgio.UnimodalData, pgio.MultimodalData], - clip_threshold: Optional[float] = None, - clip_threshold_percentile: Optional[float] = None, -) -> Union[anndata.AnnData, pgio.UnimodalData]: + data: "anndata.AnnData", +) -> "anndata.AnnData": """Log normalize the gene expression raw counts (per 10k). Parameters ---------- - data: pegasusio.MultimodalData, pegasusio.UnimodalData, anndata.AnnData + data: anndata.AnnData Annotated data matrix with rows for cells and columns for genes. - clip_threshold: float, optional - Clip the data to the given max value. - clip_threshold_percentile: float, optional - Clip the data to the value at the given data percentile. Returns ------- - pegasusio.UnimodalData, anndata.AnnData + anndata.AnnData A data object with normalized data that is ready to be used in further processes. Examples @@ -81,54 +210,47 @@ def lognorm_counts( >>> data = lognorm_counts(data) """ - if type(data) not in [anndata.AnnData, pgio.UnimodalData, pgio.MultimodalData]: - raise ValueError(f"Unknown data type {type(data)}.") + import numpy as np + import scanpy as sc - return_unimodaldata = False - if isinstance(data, pgio.MultimodalData): - data = data.get_data(data.list_data()[0]) - if isinstance(data, pgio.UnimodalData): - return_unimodaldata = True - data = data.to_anndata() + if "counts" not in data.layers: + raise ValueError(f"Raw counts matrix not found in layers['counts'].") - if "counts" not in data.layers and "raw.X" not in data.layers: - raise ValueError(f"Raw counts matrix not found.") - - if "raw.X" in data.layers: - data.layers["counts"] = data.layers["raw.X"].copy() - del data.layers["raw.X"] data.X = data.layers["counts"].copy() - # winsorize data - if clip_threshold_percentile: - clip_threshold = np.percentile(data.X.data, clip_threshold_percentile) - if clip_threshold: - data.X[data.X > clip_threshold] = clip_threshold + # check for nan in expression data, zero + if isinstance(data.X, np.ndarray) and np.isnan(data.X).any(): + import warnings + + warnings.warn( + "NANs detected in counts. 
NANs will be zeroed before normalization in X.", + UserWarning, + ) + data.X = np.nan_to_num(data.X, nan=0.0) # log norm sc.pp.normalize_total(data, target_sum=1e4) sc.pp.log1p(data) + del data.uns["log1p"] - if return_unimodaldata: - data = pgio.UnimodalData(data) return data def filter_cells( - data: Union[anndata.AnnData, pgio.UnimodalData, pgio.MultimodalData], + data: "anndata.AnnData", min_genes: int = 400, - mito_prefix: str = None, + mito_prefix: Optional[str] = None, mito_percent: float = 30.0, -) -> Union[anndata.AnnData, pgio.MultimodalData]: - """QC filter the dataset from gene expression raw counts. +) -> "anndata.AnnData": + """QC filter cells in the dataset from gene expression raw counts. Parameters ---------- - data: pegasusio.MultimodalData, pegasusio.UnimodalData, anndata.AnnData + data: anndata.AnnData Annotated data matrix with rows for cells and columns for genes. min_genes: int, default: 400 The minimum number of expressed genes in order not to be filtered out. - mito_prefix: str, optional + mito_prefix: str, optional, default: None The prefix to represent mitochondria genes. Typically "MT-" or "mt-". If None, it will try to infer whether it is either "MT-" or "mt-". mito_percent: float, default: 30.0 @@ -136,7 +258,7 @@ def filter_cells( Returns ------- - pegasusio.MultimodalData, anndata.AnnData + anndata.AnnData A data object with cells filtered out based on QC metrics that is ready to be used in further processes. @@ -145,22 +267,10 @@ def filter_cells( >>> data = filter_cells(data) """ - if type(data) not in [anndata.AnnData, pgio.UnimodalData, pgio.MultimodalData]: - raise ValueError(f"Unknown data type {type(data)}.") + import scanpy as sc - return_unimodaldata = False - if isinstance(data, pgio.MultimodalData): - data = data.get_data(data.list_data()[0]) - if isinstance(data, pgio.UnimodalData): - return_unimodaldata = True - data = data.to_anndata() - - if "counts" not in data.layers and "raw.X" not in data.layers: - raise ValueError(f"Raw counts matrix not found.") - - if "raw.X" in data.layers: - data.layers["counts"] = data.layers["raw.X"].copy() - del data.layers["raw.X"] + if "counts" not in data.layers: + raise ValueError(f"Raw counts matrix not found in layers['counts'].") # determine between "MT-" and "mt-" if not mito_prefix: @@ -182,169 +292,90 @@ def filter_cells( cell_subset, _ = sc.pp.filter_cells(data, min_genes=min_genes, inplace=False) data = data[cell_subset].copy() - if return_unimodaldata: - data = pgio.UnimodalData(data) return data -def process_data( - data: Union[anndata.AnnData, pgio.UnimodalData, pgio.MultimodalData], - n_top_genes: int = 2000, - batch_key: Optional[str] = None, - resolution: float = 1.3, -) -> Union[anndata.AnnData, pgio.UnimodalData]: - """Process the dataset: hvf selection, pca, umap, clustering +def consolidate_duplicate_symbols( + data: "anndata.AnnData", +) -> "anndata.AnnData": + """Consolidate duplicate gene symbols with sum. Parameters ---------- - data: pegasusio.MultimodalData, pegasusio.UnimodalData, anndata.AnnData + data: anndata.AnnData Annotated data matrix with rows for cells and columns for genes. - n_top_genes: int, default: 2000 - The number of highly variable genes to select. - batch_key: str, optional - The obs key which holds batch information for highly variable gene selection. - resolution: float, default: 1.3 - The leiden clustering resolution. 
Returns ------- - pegasusio.UnimodalData, anndata.AnnData - A data object where highly variable genes are in obs["highly_variable_features"], - pca data is in obsm["X_pca"], umap data is in obsm["X_umap"], and clustering - data is in obs["leiden_labels"]. + anndata.AnnData + AnnData object with duplicate gene symbols consolidated. Examples -------- - >>> data = filter_cells(data) + >>> data = consolidate_duplicate_symbols(data) """ - return_unimodaldata = False - if isinstance(data, pgio.MultimodalData): - data = data.get_data(data.list_data()[0]) - if isinstance(data, pgio.UnimodalData): - return_unimodaldata = True - data = data.to_anndata() + import anndata + from collections import Counter + import pandas as pd + from scipy.sparse import csr_matrix - # pca - sc.pp.highly_variable_genes(data, n_top_genes=n_top_genes, batch_key=batch_key) - sc.tl.pca(data) + if "counts" not in data.layers: + raise ValueError(f"Raw counts matrix not found in layers['counts'].") - # umap - sc.pp.neighbors(data, use_rep="X_pca") - sc.tl.umap(data) - - # clustering - sc.tl.leiden(data, resolution=resolution) + gene_count = Counter(data.var.index.values) + dup_genes = {k for k in gene_count if gene_count[k] > 1} + if len(dup_genes) == 0: + return data - if return_unimodaldata: - data = pgio.UnimodalData(data) + dup_genes_data = [] + for k in dup_genes: + idx = [i for i, x in enumerate(data.var.index.values) if x == k] + X = csr_matrix(data.layers["counts"][:, idx].sum(axis=1)) + gene_data = anndata.AnnData( + X=X, + var=pd.DataFrame(index=[k]), + ) + gene_data.layers["counts"] = X.copy() + dup_genes_data.append(gene_data.copy()) + del gene_data + + obs = data.obs.copy() + dup_genes_data = anndata.concat(dup_genes_data, axis=1) + dup_genes_data.obs = obs.reset_index(drop=True) + dup_genes_data.obs.index = dup_genes_data.obs.index.astype(str) + + data.obs = obs.reset_index(drop=True) + data.obs.index = data.obs.index.astype(str) + data = anndata.concat( + [data[:, ~data.var.index.isin(dup_genes)].copy(), dup_genes_data], axis=1 + ) + data.obs = obs.copy() return data -def switch_gene_symbols( - data: Union[anndata.AnnData, pgio.UnimodalData, pgio.MultimodalData], - var_key: str, -) -> Union[anndata.AnnData, pgio.UnimodalData]: - """Switch to a different set of gene symbols, contained in data.var - - Parameters - ---------- - data: pegasusio.MultimodalData, pegasusio.UnimodalData, anndata.AnnData - Annotated data matrix with rows for cells and columns for genes. - var_key: str - The var key which holds the symbol information. - - Returns - ------- - pegasusio.UnimodalData, anndata.AnnData - A data object where the var index is set to those in var_key, with - nulls and duplicates removed. - - Examples - -------- - >>> data = switch_gen_symbols(data, "symbol") - """ - data.var = data.var.set_index(var_key, drop=False) - return data[:, ~(data.var.index.isnull() | data.var.index.duplicated())].copy() - - -@njit(fastmath=True, cache=True) -def select_csr( - data: np.ndarray, - indices: np.ndarray, - indptr: np.ndarray, - indexer: np.ndarray, - new_size: int, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Subset a scipy.sparse.csr_matrix based on index. - - Parameters - ---------- - data: numpy.ndarray - Data array of the matrix. - indices: numpy.ndarray - Index array of the matrix. - indptr: numpy.ndarray - Index pointer array of the matrix. - indexer: numpy.ndarray - The subset index array. - new_size: int - The size of the new matrix. - - Returns - ------- - numpy.ndarray - The new data array. 
- numpy.ndarray - The new index array. - numpy.ndarray - The new index pointer array. - - Examples - -------- - >>> data = filter_cells(data) - """ - - data_new = np.zeros_like(data[0:new_size]) - indices_new = np.zeros_like(indices[0:new_size]) - indptr_new = np.zeros_like(indptr) - - cnt = 0 - for i in range(indptr.size - 1): - indptr_new[i] = cnt - for j in range(indptr[i], indptr[i + 1]): - new_idx = indexer[indices[j]] - if new_idx >= 0: - data_new[cnt] = data[j] - indices_new[cnt] = new_idx - cnt += 1 - indptr_new[indptr.size - 1] = cnt - - return data_new, indices_new, indptr_new - - def align_dataset( - data: Union[anndata.AnnData, pgio.UnimodalData, pgio.MultimodalData], - target_gene_order: np.ndarray, - keep_obsm: bool = False, + data: "anndata.AnnData", + target_gene_order: list, + keep_obsm: bool = True, gene_overlap_threshold: int = 5000, -) -> Union[anndata.AnnData, pgio.UnimodalData]: +) -> "anndata.AnnData": """Align the gene space to the target gene order. Parameters ---------- - data: pegasusio.MultimodalData, pegasusio.UnimodalData, anndata.AnnData + data: anndata.AnnData Annotated data matrix with rows for cells and columns for genes. target_gene_order: numpy.ndarray An array containing the gene space. - keep_obsm: bool, default: False + keep_obsm: bool, default: True Retain the original data's obsm matrices in output. - gene_overlap_threshold: int, default 5000 + gene_overlap_threshold: int, default: 5000 The minimum number of genes in common between data and target_gene_order to be valid. Returns ------- - pegasusio.UnimodalData, anndata.AnnData + anndata.AnnData A data object with aligned gene space ready to be used for embedding cells. Examples @@ -353,8 +384,10 @@ def align_dataset( >>> align_dataset(data, ca.gene_order) """ - if isinstance(data, pgio.MultimodalData): - data = data.get_data(data.list_data()[0]) + import anndata + import numpy as np + import pandas as pd + from scipy.sparse import csr_matrix # raise an error if not enough genes from target_gene_order exists if sum(data.var.index.isin(target_gene_order)) < gene_overlap_threshold: @@ -370,69 +403,17 @@ def align_dataset( if data.var.index.values.tolist() == target_gene_order: return data - shell = None - if isinstance(data, pgio.UnimodalData): - mat = data.X - obs_field = data.obs - var_field = pd.DataFrame(index=target_gene_order) - - indexer = var_field.index.get_indexer(data.var_names) - new_size = (indexer[mat.indices] >= 0).sum() - data_new, indices_new, indptr_new = select_csr( - mat.data, mat.indices, mat.indptr, indexer, new_size - ) - data_matrix = csr_matrix( - (data_new, indices_new, indptr_new), - shape=(mat.shape[0], len(target_gene_order)), - ) - data_matrix.sort_indices() - - obs_field.index.name = "barcodekey" - var_field.index.name = "featurekey" - shell = pgio.UnimodalData( - barcode_metadata=obs_field, - feature_metadata=var_field, - matrices={"X": data_matrix}, - ) - - if "counts" in data.list_keys(): - mat = data.get_matrix("counts") - data_new, indices_new, indptr_new = select_csr( - mat.data, mat.indices, mat.indptr, indexer, new_size - ) - shell.add_matrix( - "counts", - csr_matrix( - (data_new, indices_new, indptr_new), - shape=(mat.shape[0], len(target_gene_order)), - ), - ) - if "raw.X" in data.list_keys(): - mat = data.get_matrix("raw.X") - data_new, indices_new, indptr_new = select_csr( - mat.data, mat.indices, mat.indptr, indexer, new_size - ) - shell.add_matrix( - "raw.X", - csr_matrix( - (data_new, indices_new, indptr_new), - shape=(mat.shape[0], 
len(target_gene_order)), - ), - ) - if keep_obsm and hasattr(data, "obsm"): - shell.obsm = data.obsm - - if isinstance(data, anndata.AnnData): - shell = anndata.AnnData( - X=csr_matrix((0, len(target_gene_order))), - var=pd.DataFrame(index=target_gene_order), - dtype=np.float32, - ) - shell = anndata.concat( - (shell, data[:, data.var.index.isin(shell.var.index)]), join="outer" - ) - if not keep_obsm and hasattr(data, "obsm"): - delattr(shell, "obsm") + orig_genes = data.var.index.values # record original gene list before alignment + shell = anndata.AnnData( + X=csr_matrix((0, len(target_gene_order))), + var=pd.DataFrame(index=target_gene_order), + ) + shell = anndata.concat( + (shell, data[:, data.var.index.isin(shell.var.index)]), join="outer" + ) + shell.uns["orig_genes"] = orig_genes + if not keep_obsm and hasattr(data, "obsm"): + delattr(shell, "obsm") if data.var.shape[0] == 0: raise RuntimeError(f"Empty gene space detected.") @@ -440,48 +421,94 @@ def align_dataset( return shell -def get_centroid(sparse_counts_mat: csr_matrix) -> np.ndarray: - """Get the centroid for a raw counts matrix in scipy.sparse.csr_matrix format. +def get_centroid( + counts: Union["scipy.sparse.csr_matrix", "numpy.ndarray"] +) -> "numpy.ndarray": + """Get the centroid for a raw counts matrix. Parameters ---------- - sparse_counts_mat: scipy.sparse.csr_matrix - Sparse matrix of raw gene expression counts. + counts: scipy.sparse.csr_matrix, numpy.ndarray + Raw gene expression counts. Returns ------- numpy.ndarray - A 2D numpy array of the log normalized (1e4) centroid. + A 2D numpy array of the log normalized (1e4) for the centroid. Examples -------- >>> centroid = get_centroid(data.get_matrix("counts")) >>> centroid = get_centroid(data.layers["counts"]) """ - summed_counts = sparse_counts_mat.sum(axis=0).A - normalization_factor = sparse_counts_mat.sum(axis=1).A.sum() - centroid = np.log(1 + 1e4 * summed_counts / normalization_factor) + + import numpy as np + from scipy.sparse import csr_matrix + + if isinstance(counts, np.ndarray): + counts = csr_matrix(counts) + + sum_counts = counts.sum(axis=0).A + normalization_factor = counts.sum() + centroid = np.log(1 + 1e4 * sum_counts / normalization_factor) + return centroid +def get_dist2centroid( + centroid_embedding: "numpy.ndarray", + X: Union["scipy.sparse.csr_matrix", "numpy.ndarray"], +) -> "numpy.ndarray": + """Get the centroid for a raw counts matrix in sparse csr_matrix format. + + Parameters + ---------- + centroid_embedding: numpy.ndarray + The embedding of the centroid. + X: scipy.sparse.csr_matrix, numpy.ndarray + The embedding of SCimilarity log normalized gene expression values or + SCimilarity log normalized gene expression values. + embed: bool, default: False + Whether to embed X. + + Returns + ------- + float + The mean distance of cells in X to the centroid embedding. 
+ + Examples + -------- + >>> distances = cq.get_dist2centroid(centroid_embedding, X) + """ + + from scipy.spatial.distance import cdist + from scipy.sparse import csr_matrix + + if isinstance(X, csr_matrix): + X = X.A + distances = cdist(centroid_embedding.reshape(1, -1), X, metric="cosine").flatten() + + return distances + + def get_cluster_centroids( - data: Union[anndata.AnnData, pgio.UnimodalData, pgio.MultimodalData], - target_gene_order: np.ndarray, + data: "anndata.AnnData", + target_gene_order: "numpy.ndarray", cluster_key: str, cluster_label: Optional[str] = None, skip_null: bool = True, -) -> Tuple[np.ndarray, list]: +) -> Tuple["numpy.ndarray", list]: """Get centroids of clusters based on raw read counts. Parameters ---------- - data: pegasusio.MultimodalData, pegasusio.UnimodalData, anndata.AnnData + data: anndata.AnnData Annotated data matrix with rows for cells and columns for genes. target_gene_order: numpy.ndarray An array containing the gene space. cluster_key: str - The obs column key that contains cluster labels. - cluster_label: optional, str + The obs column name that contains cluster labels. + cluster_label: str, optional, default: None The cluster label of interest. If None, then get the centroids of all clusters, otherwise get only the centroid for the cluster of interest @@ -490,16 +517,21 @@ def get_cluster_centroids( Returns ------- - numpy.ndarray + centroids: numpy.ndarray A 2D numpy array of the log normalized (1e4) cluster centroids. - list + cluster_idx: list A list of cluster labels corresponding to the order returned in centroids. Examples -------- - >>> centroids, cluster_idx = get_cluster_centroids(data, gene_order, "leiden_labels") + >>> centroids, cluster_idx = get_cluster_centroids(data, gene_order, "cluster_label") """ + import numpy as np + + if "counts" not in data.layers: + raise ValueError(f"Raw counts matrix not found in layers['counts'].") + centroids = [] cluster_idx = [] @@ -508,53 +540,19 @@ def get_cluster_centroids( aligned_data = aligned_data[aligned_data.obs[cluster_key].notnull()].copy() aligned_data.obs[cluster_key] = aligned_data.obs[cluster_key].astype(str) - if isinstance(aligned_data, pgio.UnimodalData): - for i in set(aligned_data.obs[cluster_key]): - if cluster_label is not None: - i = cluster_label - - cluster_idx.append(i) - if "counts" in aligned_data.list_keys(): - centroids.append( - get_centroid( - aligned_data[aligned_data.obs[cluster_key] == i] - .copy() - .get_matrix("counts") - ) - ) - elif "raw.X" in aligned_data.list_keys(): - centroids.append( - get_centroid( - aligned_data[aligned_data.obs[cluster_key] == i] - .copy() - .get_matrix("raw.X") - ) - ) - else: - raise RuntimeError("Dataset incompatible: no counts matrix found") - - if cluster_label is not None: - break - - if isinstance(aligned_data, anndata.AnnData): - for i in set(aligned_data.obs[cluster_key]): - if cluster_label is not None: - i = cluster_label - - cluster_idx.append(i) - if "counts" in aligned_data.layers: - centroids.append( - get_centroid( - aligned_data[aligned_data.obs[cluster_key] == i] - .copy() - .layers["counts"] - ) - ) - else: - raise RuntimeError("Dataset incompatible: no counts matrix found") - - if cluster_label is not None: - break + for i in set(aligned_data.obs[cluster_key]): + if cluster_label is not None: + i = cluster_label + + cluster_idx.append(i) + centroids.append( + get_centroid( + aligned_data[aligned_data.obs[cluster_key] == i].copy().layers["counts"] + ) + ) + + if cluster_label is not None: + break centroids = 
np.vstack(centroids) @@ -562,3 +560,431 @@ def get_cluster_centroids( raise RuntimeError(f"NaN detected in centroids.") return centroids, cluster_idx + + +def write_tiledb_array( + tiledb_array_uri: str, arr: "numpy.ndarray", batch_size: int = 100000 +): + """Write TileDB Array from a numpy array. + + Parameters + ---------- + tiledb_array_uri: str + URI for the TileDB array. + batch_size: int, default: 10000 + Batch size for the tiles. + """ + + import numpy as np + import tiledb + from tqdm import tqdm + + print(f"Configuring tiledb array: {tiledb_array_uri}") + + xdimtype = np.int32 + ydimtype = np.int32 + value_type = np.float32 + + xdim = tiledb.Dim( + name="x", domain=(0, arr.shape[0] - 1), tile=batch_size, dtype=xdimtype + ) + ydim = tiledb.Dim( + name="y", domain=(0, arr.shape[1] - 1), tile=arr.shape[1], dtype=ydimtype + ) + dom = tiledb.Domain(xdim, ydim) + + attr = tiledb.Attr( + name="vals", + dtype=value_type, + filters=tiledb.FilterList([tiledb.GzipFilter()]), + ) + + schema = tiledb.ArraySchema( + domain=dom, + sparse=False, + cell_order="row-major", + tile_order="row-major", + attrs=[attr], + ) + tiledb.Array.create(tiledb_array_uri, schema) + + tdbfile = tiledb.open(tiledb_array_uri, "w") + for row in tqdm(range(0, arr.shape[0], batch_size)): + mat_slice = slice(row, row + batch_size) + sub_matrix = np.array( + arr[mat_slice, :].astype(value_type).tolist(), dtype=value_type + ) + tdbfile[mat_slice, 0 : arr.shape[1]] = sub_matrix + tdbfile.close() + + +def create_tiledb_array( + tiledb_array_uri: str, + data_list: List[str], + nrows: int, + ncols: int, + batch_size: int = 10000, +): + """Create TileDB Array from a list of numpy data files. + + Parameters + ---------- + tiledb_array_uri: str + URI for the TileDB array. + data_list: List[str] + List of data files. + nrows: int + Number of total rows + ncols: int + Number of columns, must be consistent between files + batch_size: int, default: 10000 + Batch size for the tiles. + """ + + import numpy as np + import tiledb + + print(f"Configuring tiledb array: {tiledb_array_uri}") + + xdimtype = np.int32 + ydimtype = np.int32 + value_type = np.float32 + + xdim = tiledb.Dim( + name="x", + domain=(0, nrows - 1), + tile=batch_size, + dtype=xdimtype, + ) + ydim = tiledb.Dim( + name="y", + domain=(0, ncols - 1), + tile=ncols, + dtype=ydimtype, + ) + dom = tiledb.Domain(xdim, ydim) + + attr = tiledb.Attr( + name="vals", + dtype=value_type, + filters=tiledb.FilterList([tiledb.GzipFilter()]), + ) + + schema = tiledb.ArraySchema( + domain=dom, + sparse=False, + cell_order="row-major", + tile_order="row-major", + attrs=[attr], + ) + tiledb.Array.create(tiledb_array_uri, schema) + + tdbfile = tiledb.open(tiledb_array_uri, "w") + previous_shape = None + for f in data_list: + if previous_shape is None: + paging_idx = 0 + else: + paging_idx += previous_shape[0] + + arr = np.load(f) + previous_shape = arr.shape + + tbd_slice = slice(paging_idx, paging_idx + arr.shape[0]) + tdbfile[tbd_slice, 0:ncols] = arr + tdbfile.close() + + +def optimize_tiledb_array(tiledb_array_uri: str, verbose: bool = True): + """Optimize TileDB Array. + + Parameters + ---------- + tiledb_array_uri: str + URI for the TileDB array. + verbose: bool + Boolean indicating whether to use verbose printing. 
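+
+    Examples
+    --------
+    >>> optimize_tiledb_array("/path/to/tiledb_array")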
+ """ + + import tiledb + + if verbose: + print(f"Optimizing {tiledb_array_uri}") + + frags = tiledb.array_fragments(tiledb_array_uri) + if verbose: + print("Fragments before consolidation: {}".format(len(frags))) + + cfg = tiledb.Config() + cfg["sm.consolidation.step_min_frags"] = 1 + cfg["sm.consolidation.step_max_frags"] = 200 + tiledb.consolidate(tiledb_array_uri, config=cfg) + tiledb.vacuum(tiledb_array_uri) + + frags = tiledb.array_fragments(tiledb_array_uri) + if verbose: + print("Fragments after consolidation: {}".format(len(frags))) + + +def clean_tissues(tissues: "pandas.Series") -> "pandas.Series": + """Mapper to clean tissue names. + + Parameters + ---------- + tissues: pandas.Series + A pandas Series containing tissue names. + + Returns + ------- + pandas.Series + A pandas Series containing cleaned tissue names. + + Examples + -------- + >>> data.obs["tissue_simple"] = clean_tissues(data.obs["tissue"]).fillna("other tissue") + """ + + tissue_mapper = { + "adipose": { + "omentum", + "adipose tissue", + "Fat", + "omental fat pad", + "white adipose tissue", + }, + "adrenal gland": {"adrenal gland", "visceral fat"}, + "airway": { + "trachea", + "trachea;bronchus", + "Trachea", + "bronchus", + "nasopharynx", + "respiratory tract epithelium", + "bronchiole", + "inferior nasal concha", + "nose", + "nasal turbinal", + }, + "bone": {"bone", "bone tissue", "head of femur", "synovial fluid"}, + "bladder": {"urinary bladder", "Bladder", "bladder"}, + "blood": {"blood", "umbilical cord blood", "peripheral blood", "Blood"}, + "bone marrow": {"bone marrow", "Bone_Marrow"}, + "brain": { + "brain", + "cortex", + "prefrontal cortex", + "occipital cortex", + "cerebrospinal fluid", + "midbrain", + "spinal cord", + "superior frontal gyrus", + "entorhinal cortex", + "White Matter brain tissue", + "Entorhinal Cortex", + "cerebral hemisphere", + "brain white matter", + "cerebellum", + "hypothalamus", + }, + "breast": {"breast", "Mammary", "mammary gland"}, + "esophagus": { + "esophagus", + "esophagusmucosa", + "esophagusmuscularis", + "esophagus mucosa", + "esophagus muscularis mucosa", + }, + "eye": {"eye", "uvea", "corneal epithelium", "retina", "Eye"}, + "stomach": {"stomach"}, + "gut": { + "colon", + "ascending colon", + "sigmoid colon", + "large intestine", + "small intestine", + "intestine", + "Small_Intestine", + "Large_Intestine", + "ileum", + "right colon", + "left colon", + "transverse colon", + "digestive tract", + "caecum", + "jejunum", + "jejunum ", + "descending colon", + }, + "heart": { + "heart", + "aorta", + "cardiac muscle of left ventricle", + "Heart", + "heart left ventricle", + "pulmonary artery", + }, + "kidney": { + "adult mammalian kidney", + "kidney", + "Kidney", + "inner medulla of kidney", + "outer cortex of kidney", + }, + "liver": {"liver", "Liver", "caudate lobe of liver"}, + "lung": { + "lung", + "alveolar system", + "lung parenchyma", + "respiratory airway", + "trachea;respiratory airway", + "BAL", + "Lung", + "Parenchymal lung tissue", + "Distal", + "Proximal", + "Intermediate", + "lower lobe of lung", + "upper lobe of lung", + }, + "lymph node": { + "lymph node", + "axillary lymph node", + "Lymph_Node", + "craniocervical lymph node", + }, + "male reproduction": { + "male reproductive gland", + "testis", + "prostate gland", + "epididymis epithelium", + "Prostate", + "prostate", + "peripheral zone of prostate", + }, + "female reproduction": { + "ovary", + "tertiary ovarian follicle", + "ovarian follicle", + "fimbria of uterine tube", + "ampulla of uterine tube", + 
"isthmus of fallopian tube", + "fallopian tube", + "uterus", + "Uterus", + }, + "pancreas": {"pancreas", "Pancreas", "islet of Langerhans"}, + "skin": { + "skin of body", + "skin epidermis", + "skin of prepuce of penis", + "scrotum skin", + "Skin", + "skin", + }, + "spleen": {"spleen", "Spleen"}, + "thymus": {"thymus", "Thymus"}, + "vasculature": { + "vasculature", + "mesenteric artery", + "umbilical vein", + "Vasculature", + }, + } + term2simple = {} + for tissue_simplified, children in tissue_mapper.items(): + for child in children: + term2simple[child] = tissue_simplified + + return tissues.map(term2simple) + + +def clean_diseases(diseases: "pandas.Series") -> "pandas.Series": + """Mapper to clean disease names. + + Parameters + ---------- + diseases: pandas.Series + A pandas Series containing disease names. + + Returns + ------- + pandas.Series + A pandas Series containing cleaned disease names. + + Examples + -------- + >>> data.obs["disease_simple"] = clean_diseases(data.obs["disease"]).fillna("healthy") + """ + + disease_mapper = { + "healthy": {"healthy", "", "NA"}, + "Alzheimer's": { + "Alzheimer's disease", + }, + "COVID-19": { + "COVID-19", + }, + "ILD": { + "pulmonary fibrosis", + "idiopathic pulmonary fibrosis", + "interstitial lung disease", + "systemic scleroderma;interstitial lung disease", + "fibrosis", + "hypersensitivity pneumonitis", + }, + "cancer": { + "head and neck squamous cell carcinoma", + "renal cell adenocarcinoma", + "hepatocellular carcinoma", + "B-cell acute lymphoblastic leukemia", + "glioma", + "ovarian serous carcinoma", + "neuroblastoma", + "pancreatic carcinoma", + "melanoma", + "multiple myeloma", + "Gastrointestinal stromal tumor", + "neuroblastoma" "nasopharyngeal neoplasm", + "adenocarcinoma", + "pancreatic ductal adenocarcinoma", + "chronic lymphocytic leukemia", + "Uveal Melanoma", + "Myelofibrosis", + }, + "MS": { + "multiple sclerosis", + }, + "dengue": { + "dengue disease", + }, + "IBD": { + "Crohn's disease", + }, + "SLE": {"systemic lupus erythematosus"}, + "scleroderma": {"scleroderma"}, + "LCH": {"Langerhans Cell Histiocytosis"}, + "NAFLD": {"non-alcoholic fatty liver disease"}, + "Kawasaki disease": {"mucocutaneous lymph node syndrome"}, + "eczema": {"atopic eczema"}, + "sepsis": {"septic shock"}, + "obesity": {"obesity"}, + "DRESS": {"drug hypersensitivity syndrome"}, + "hidradenitis suppurativa": {"hidradenitis suppurativa"}, + "T2 diabetes": {"type II diabetes mellitus"}, + "non-alcoholic steatohepatitis": {"non-alcoholic steatohepatitis"}, + "Biliary atresia": {"Biliary atresia"}, + "essential thrombocythemia": {"essential thrombocythemia"}, + "HIV": {"HIV enteropathy"}, + "monoclonal gammopathy": {"monoclonal gammopathy"}, + "psoriatic arthritis": {"psoriatic arthritis"}, + "RA": {"rheumatoid arthritis"}, + "osteoarthritis": {"osteoarthritis"}, + "periodontitis": {"periodontitis"}, + "Lymphangioleiomyomatosis": {"Lymphangioleiomyomatosis"}, + } + term2simple = {} + for disease_simplified, children in disease_mapper.items(): + for child in children: + term2simple[child] = disease_simplified + + return diseases.map(term2simple) diff --git a/src/scimilarity/visualizations.py b/src/scimilarity/visualizations.py index 9c3b3a8..610e10d 100644 --- a/src/scimilarity/visualizations.py +++ b/src/scimilarity/visualizations.py @@ -1,21 +1,35 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Dict, Tuple, Optional -import circlify as circ -import matplotlib as mpl -import matplotlib.pyplot as plt -import numpy as np 
-import pandas as pd -import seaborn as sns -mpl.rcParams["pdf.fonttype"] = 42 +def aggregate_counts(data: "pandas.DataFrame", levels: List[str]) -> dict: + """Aggregates cell counts on sample metadata and compiles it into circlify format. + Parameters + ---------- + data: pandas.DataFrame + A pandas dataframe containing sample metadata. + levels: List[str] + Specify the groupby columns for grouping the sample metadata. + + Returns + ------- + dict + A circlify format dictionary containing grouped sample metadata. + + Examples + -------- + >>> circ_dict = aggregate_counts(sample_metadata, ["tissue", "disease"]) + """ -def aggregate_counts(data: pd.DataFrame, levels: List[str]): data_dict = {} for n in range(len(levels)): # construct a groupby dataframe to obtain counts columns = levels[0 : (n + 1)] - df = data.groupby(columns)[columns[0]].count().reset_index(name="count") + df = ( + data.groupby(columns, observed=True)[columns[0]] + .count() + .reset_index(name="count") + ) # construct a nested dict to handle children levels for r in df.index: @@ -41,13 +55,42 @@ def aggregate_counts(data: pd.DataFrame, levels: List[str]): def assign_size( data_dict: dict, - data: pd.DataFrame, + data: "pandas.DataFrame", levels: List[str], size_column: str, name_column: str, -): +) -> dict: + """Assigns circle sizes to a circlify format dictionary. + + Parameters + ---------- + data_dict: dict + A circlify format dictionary. + data: pandas.DataFrame + A pandas dataframe containing sample metadata. + levels: List[str] + Specify the groupby columns for grouping the sample metadata. + size_column: str + The name of the column that will be used for circle size. + name_column: str + The name of the column that will be used for circle name. + + Returns + ------- + dict + A circlify format dictionary. + + Examples + -------- + >>> circ_dict = assign_size(circ_dict, sample_metadata, ["tissue", "disease"], size_column="cells", name_column="study") + """ + df = data[levels + [size_column, name_column]] - df = df.groupby(levels + [name_column])[size_column].sum().reset_index(name="count") + df = ( + df.groupby(levels + [name_column], observed=True)[size_column] + .sum() + .reset_index(name="count") + ) for ( r ) in ( @@ -64,11 +107,36 @@ def assign_size( def assign_suffix( data_dict: dict, - data: pd.DataFrame, + data: "pandas.DataFrame", levels: List[str], suffix_column: str, name_column: str, -): +) -> dict: + """Assigns circle name and suffix to a circlify format dictionary. + + Parameters + ---------- + data_dict: dict + A circlify format dictionary. + data: pandas.DataFrame + A pandas dataframe containing sample metadata. + levels: List[str] + Specify the groupby columns for grouping the sample metadata. + suffix_column: str + The name of the column that will be used for the circle name suffix. + name_column: str + The name of the column that will be used for circle name. + + Returns + ------- + dict + A circlify format dictionary. + + Examples + -------- + >>> circ_dict = assign_suffix(circ_dict, sample_metadata, ["tissue", "disease"], suffix_column="cells", name_column="study") + """ + df = data[levels + [suffix_column, name_column]] for r in df.index: # find the deepest levels in data_dict and rename with suffix entry = data_dict[df.iloc[r, 0]] @@ -83,11 +151,36 @@ def assign_suffix( def assign_colors( data_dict: dict, - data: pd.DataFrame, + data: "pandas.DataFrame", levels: List[str], color_column: str, name_column: str, -): +) -> dict: + """Assigns circle name and color to a circlify format dictionary. 
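+    Values from color_column are written to the deepest level of the hierarchy as each circle's color.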
+ + Parameters + ---------- + data_dict: dict + A circlify format dictionary. + data: pandas.DataFrame + A pandas dataframe containing sample metadata. + levels: List[str] + Specify the groupby columns for grouping the sample metadata. + color_column: str + The name of the column that will be used for the circle color. + name_column: str + The name of the column that will be used for circle name. + + Returns + ------- + dict + A circlify format dictionary. + + Examples + -------- + >>> circ_dict = assign_colors(circ_dict, sample_metadata, ["tissue", "disease"], color_column="cells", name_column="study") + """ + df = data[levels + [color_column, name_column]] for r in df.index: # find the deepest levels in data_dict and rename with color entry = data_dict[df.iloc[r, 0]] @@ -100,7 +193,24 @@ def assign_colors( return data_dict -def get_children_data(data_dict: dict): +def get_children_data(data_dict: dict) -> List[dict]: + """Recursively get all children data for a given circle. + + Parameters + ---------- + data_dict: dict + A circlify format dictionary + + Returns + ------- + List[dict] + A list of children data. + + Examples + -------- + >>> children = get_children_data(circ_dict[i]["children"]) + """ + child_data = [] for i in data_dict: # recursively get all children data entry = {"id": i, "datum": data_dict[i]["datum"]} @@ -111,7 +221,24 @@ def get_children_data(data_dict: dict): return child_data -def circ_dict2data(circ_dict: dict): +def circ_dict2data(circ_dict: dict) -> List[dict]: + """Convert a circlify format dictionary to the list format expected by circlify. + + Parameters + ---------- + data_dict: dict + A circlify format dictionary + + Returns + ------- + List[dict] + A list of circle data. + + Examples + -------- + >>> circ_data = circ_dict2data(circ_dict) + """ + circ_data = [] for i in circ_dict: # convert dict to circlify list data entry = {"id": i, "datum": circ_dict[i]["datum"]} @@ -123,17 +250,53 @@ def circ_dict2data(circ_dict: dict): def draw_circles( - circ_data: list, + circ_data: List[dict], title: str = "", - fig_size: Tuple[int, int] = (10, 10), + figsize: Tuple[int, int] = (10, 10), filename: Optional[str] = None, use_colormap: Optional[str] = None, use_suffix: Optional[dict] = None, use_suffix_as_color: bool = False, ): + """Draw the circlify plot. + + Parameters + ---------- + circ_data: List[dict] + A circlify format list. + title: str, default: "" + The figure title. + figsize: Tuple[int, int], default: (10, 10) + The figure size in inches. + filename: str, optional, default: None + Filename to save the figure. + use_colormap: str, optional, default: None + The colormap identifier. + use_suffix: dict, optional, default: None + A mapping of suffix to color using a dictionary in the form {suffix: float} + use_suffix_as_color: bool, default: False + Use the suffix as the color. This expects the suffix to be a float. + + Examples + -------- + >>> draw_circles(circ_data) + """ + + try: + import circlify as circ + except: + raise ImportError( + "Package 'circlify' not found. Please install with 'pip install circlify'." 
+ ) + + import matplotlib.pyplot as plt + import matplotlib as mpl + + mpl.rcParams["pdf.fonttype"] = 42 + circles = circ.circlify(circ_data, show_enclosure=True) - fig, ax = plt.subplots(figsize=fig_size) + fig, ax = plt.subplots(figsize=figsize) if use_colormap: cmap = mpl.cm.get_cmap(use_colormap) @@ -235,26 +398,63 @@ def draw_circles( fig.savefig(filename, bbox_inches="tight") +def hits_circles( + metadata: "pandas.DataFrame", + levels: list = ["tissue", "disease"], + figsize: Tuple[int, int] = (10, 10), + filename: Optional[str] = None, +): + """Visualize sample metadata as circle plots for tissue and disease. + + Parameters + ---------- + metadata: pandas.DataFrame + A pandas dataframe containing sample metadata for nearest neighbors + with at least columns: ["study", "cells"], that represent the number + of circles and circle size respectively. + levels: list, default: ["tissue", "disease"] + The columns to uses as group levels in the circles hierarchy. + figsize: Tuple[int, int], default: (10, 10) + Figure size, width x height + filename: str, optional + Filename to save the figure. + + Examples + -------- + >>> hits_circles(metadata) + """ + + circ_dict = aggregate_counts(metadata, levels) + circ_dict = assign_size( + circ_dict, metadata, levels, size_column="cells", name_column="study" + ) + circ_data = circ_dict2data(circ_dict) + draw_circles(circ_data, figsize=figsize, filename=filename) + + def hits_heatmap( - sample_metadata: Dict[str, pd.DataFrame], + sample_metadata: Dict[str, "pandas.DataFrame"], x: str, y: str, count_type: str = "cells", + figsize: Tuple[int, int] = (10, 10), filename: Optional[str] = None, ): """Visualize a list of sample metadata objects as a heatmap. Parameters ---------- - sample_metadata: Dict[pandas.DataFrame] + sample_metadata: Dict[str, pandas.DataFrame] A dict where keys are cluster names and values are pandas dataframes containing - sample metadata for each cluster centroid. + sample metadata for each cluster centroid with columns: ["tissue", "disease", "study", "sample"]. x: str x-axis label key. This corresponds to cluster name values. y: str y-axis label key. This corresponds to the dataframe column to visualize. count_type: {"cells", "fraction"}, default: "cells" Count type to color in the heatmap. + figsize: Tuple[int, int], default: (10, 10) + Figure size, width x height filename: str, optional Filename to save the figure. 
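
For orientation, the circle-packing helpers above are meant to be chained in the order aggregate_counts -> assign_size -> circ_dict2data -> draw_circles, which is what hits_circles wraps. A minimal sketch with a made-up metadata table (the column names follow the docstrings; the values and output filename are assumptions, and circlify must be installed):

import pandas as pd

from scimilarity.visualizations import (
    aggregate_counts,
    assign_size,
    circ_dict2data,
    draw_circles,
)

# Made-up nearest-neighbor sample metadata.
metadata = pd.DataFrame(
    {
        "tissue": ["lung", "lung", "blood"],
        "disease": ["ILD", "healthy", "healthy"],
        "study": ["studyA", "studyB", "studyC"],
        "cells": [120, 45, 300],
    }
)

levels = ["tissue", "disease"]
circ_dict = aggregate_counts(metadata, levels)
circ_dict = assign_size(
    circ_dict, metadata, levels, size_column="cells", name_column="study"
)
circ_data = circ_dict2data(circ_dict)
draw_circles(circ_data, title="query hits", filename="hits_circles.png")
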
@@ -263,6 +463,14 @@ def hits_heatmap( >>> hits_heatmap(sample_metadata, "time", "disease") """ + import matplotlib.pyplot as plt + import matplotlib as mpl + import numpy as np + import pandas as pd + import seaborn as sns + + mpl.rcParams["pdf.fonttype"] = 42 + valid_count_types = {"cells", "fraction"} if count_type not in valid_count_types: raise ValueError( @@ -274,11 +482,18 @@ def hits_heatmap( df = pd.concat(sample_metadata).reset_index(drop=True) if count_type == "cells": - df_m = df.groupby([x, y])["cells"].sum().unstack(level=0).fillna(0) + df_m = ( + df.groupby([x, y], observed=True)["cells"].sum().unstack(level=0).fillna(0) + ) else: - df_m = df.groupby([x, y])["fraction"].mean().unstack(level=0).fillna(0) + df_m = ( + df.groupby([x, y], observed=True)["fraction"] + .mean() + .unstack(level=0) + .fillna(0) + ) - fig, ax = plt.subplots(figsize=(5, 5)) + fig, ax = plt.subplots(figsize=figsize) sns.heatmap( ax=ax, data=df_m, @@ -286,10 +501,11 @@ def hits_heatmap( yticklabels=True, square=True, cmap="Blues", - linewidth=0.3, + linewidth=0.01, + linecolor="gray", cbar_kws={"shrink": 0.5}, ) - plt.tick_params(axis="both", labelsize=4) + plt.tick_params(axis="both", labelsize=8, grid_alpha=0.0) # xticks ax.xaxis.tick_top() diff --git a/src/scimilarity/zarr_data_models.py b/src/scimilarity/zarr_data_models.py index d37047f..3390d73 100644 --- a/src/scimilarity/zarr_data_models.py +++ b/src/scimilarity/zarr_data_models.py @@ -1,7 +1,6 @@ -import os from collections import Counter - import numpy as np +import os import pandas as pd import pytorch_lightning as pl import torch @@ -13,7 +12,8 @@ class scDatasetFromList(Dataset): - """A class to encapsulation single cell datasets from list""" + """A class that represent a collection of single cell datasets in zarr format.""" + def __init__(self, data_list, obs_celltype="celltype_name", obs_study="study"): """Constructor. @@ -26,6 +26,7 @@ def __init__(self, data_list, obs_celltype="celltype_name", obs_study="study"): obs_study: str, default: "study" Study name. """ + self.data_list = data_list self.ncells_list = [data.shape[0] for data in data_list] self.ncells = sum(self.ncells_list) @@ -54,7 +55,8 @@ def __getitem__(self, idx): class MetricLearningZarrDataModule(pl.LightningDataModule): - """A class to encapsulate zarr data model.""" + """A class to encapsulate a collection of zarr datasets to train the model.""" + def __init__( self, train_path: str, @@ -71,14 +73,16 @@ def __init__( ---------- train_path: str Path to folder containing all training datasets. + All datasets should be in zarr format, aligned to a known gene space, and + cleaned to only contain valid cell ontology terms. gene_order: str Use a given gene order as described in the specified file. One gene symbol per line. IMPORTANT: the zarr datasets should already be in this gene order after preprocessing. - val_path: str, optional + val_path: str, optional, default: None Path to folder containing all validation datasets. - test_path: str, optional + test_path: str, optional, default: None Path to folder containing all test datasets. obs_field: str, default: "celltype_name" The obs key name containing celltype labels. 
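
To make the train_path convention concrete: the walk-based discovery in the next hunk expects one zarr store per sample, nested by study, i.e. <train_path>/<study>/<sample>/<sample>.aligned.zarr. A minimal construction sketch (the paths are placeholders and only arguments documented above are shown):

from scimilarity.zarr_data_models import MetricLearningZarrDataModule

# Placeholder paths; each sample lives at <path>/<study>/<sample>/<sample>.aligned.zarr
datamodule = MetricLearningZarrDataModule(
    train_path="/data/train",
    gene_order="/data/gene_order.tsv",
    val_path="/data/val",
    obs_field="celltype_name",
)
train_loader = datamodule.train_dataloader()
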
@@ -114,11 +118,23 @@ def __init__( train_data_list = [] self.train_Y = [] # text labels self.train_study = [] # text studies + + if self.train_path[-1] != os.sep: + self.train_path += os.sep + self.train_file_list = [ - f for f in os.listdir(self.train_path) if f.endswith(".aligned.zarr") + ( + root.replace(self.train_path, "").split(os.sep)[0], + dirs[0].replace(".aligned.zarr", ""), + ) + for root, dirs, files in os.walk(self.train_path) + if dirs and dirs[0].endswith(".aligned.zarr") ] - for filename in tqdm(self.train_file_list): - data_path = os.path.join(self.train_path, filename) + + for study, sample in tqdm(self.train_file_list): + data_path = os.path.join( + self.train_path, study, sample, sample + ".aligned.zarr" + ) if os.path.isdir(data_path): zarr_data = ZarrDataset(data_path) train_data_list.append(zarr_data) @@ -137,11 +153,23 @@ def __init__( val_data_list = [] self.val_Y = [] self.val_study = [] + + if self.val_path[-1] != os.sep: + self.val_path += os.sep + self.val_file_list = [ - f for f in os.listdir(self.val_path) if f.endswith(".aligned.zarr") + ( + root.replace(self.val_path, "").split(os.sep)[0], + dirs[0].replace(".aligned.zarr", ""), + ) + for root, dirs, files in os.walk(self.val_path) + if dirs and dirs[0].endswith(".aligned.zarr") ] - for filename in tqdm(self.val_file_list): - data_path = os.path.join(self.val_path, filename) + + for study, sample in tqdm(self.val_file_list): + data_path = os.path.join( + self.val_path, study, sample, sample + ".aligned.zarr" + ) if os.path.isdir(data_path): zarr_data = ZarrDataset(data_path) val_data_list.append(zarr_data) @@ -158,11 +186,23 @@ def __init__( test_data_list = [] self.test_Y = [] self.test_study = [] + + if self.test_path[-1] != os.sep: + self.test_path += os.sep + self.test_file_list = [ - f for f in os.listdir(self.test_path) if f.endswith(".aligned.zarr") + ( + root.replace(self.test_path, "").split(os.sep)[0], + dirs[0].replace(".aligned.zarr", ""), + ) + for root, dirs, files in os.walk(self.test_path) + if dirs and dirs[0].endswith(".aligned.zarr") ] - for filename in tqdm(self.test_file_list): - data_path = os.path.join(self.test_path, filename) + + for study, sample in tqdm(self.test_file_list): + data_path = os.path.join( + self.test_path, study, sample, sample + ".aligned.zarr" + ) if os.path.isdir(data_path): zarr_data = ZarrDataset(data_path) test_data_list.append(zarr_data) @@ -176,12 +216,42 @@ def __init__( # Lazy load test data from list of zarr datasets self.test_dataset = scDatasetFromList(test_data_list) - def two_way_weighting(self, vec1: list, vec2: list): + def two_way_weighting(self, vec1: list, vec2: list) -> dict: + """Two-way weighting. + + Parameters + ---------- + vec1 + Vector 1 + vec2 + Vector 2 + + Returns + ------- + dict + A dictionary containing the two-way weighting. + """ + counts = pd.crosstab(vec1, vec2) weights_matrix = (1 / counts).replace(np.inf, 0) return weights_matrix.unstack().to_dict() - def get_sampler_weights(self, labels: list, studies: Optional[list] = None): + def get_sampler_weights( + self, labels: list, studies: Optional[list] = None + ) -> WeightedRandomSampler: + """Get weighted random sampler. + + Parameters + ---------- + dataset: scDataset + Single cell dataset. + + Returns + ------- + WeightedRandomSampler + A WeightedRandomSampler object. 
+ """ + if studies is None: class_sample_count = Counter(labels) sample_weights = torch.Tensor([1.0 / class_sample_count[t] for t in labels]) @@ -199,6 +269,20 @@ def get_sampler_weights(self, labels: list, studies: Optional[list] = None): return WeightedRandomSampler(sample_weights, len(sample_weights)) def collate(self, batch): + """Collate tensors. + + Parameters + ---------- + batch: + Batch to collate. + + Returns + ------- + tuple + A Tuple[torch.Tensor, torch.Tensor, list] containing information + on the collated tensors. + """ + profiles, labels, studies = tuple( map(list, zip(*batch)) ) # tuple([list(t) for t in zip(*batch)]) @@ -208,7 +292,15 @@ def collate(self, batch): studies, ) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: + """Load the training dataset. + + Returns + ------- + DataLoader + A DataLoader object containing the training dataset. + """ + return DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -219,7 +311,15 @@ def train_dataloader(self): collate_fn=self.collate, ) - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: + """Load the validation dataset. + + Returns + ------- + DataLoader + A DataLoader object containing the validation dataset. + """ + if self.val_dataset is None: return None return DataLoader( @@ -232,7 +332,15 @@ def val_dataloader(self): collate_fn=self.collate, ) - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: + """Load the test dataset. + + Returns + ------- + DataLoader + A DataLoader object containing the test dataset. + """ + if self.test_dataset is None: return None return DataLoader( diff --git a/src/scimilarity/zarr_dataset.py b/src/scimilarity/zarr_dataset.py index 54a609a..9759ecf 100644 --- a/src/scimilarity/zarr_dataset.py +++ b/src/scimilarity/zarr_dataset.py @@ -1,10 +1,6 @@ +from scipy.sparse import csr_matrix, csc_matrix, coo_matrix from typing import Dict, List, Optional, Tuple, Union -import numcodecs -import numpy as np -import pandas as pd -import zarr -from scipy.sparse import csr_matrix, csc_matrix, coo_matrix ARRAY_FORMATS = { "csr_matrix": csr_matrix, @@ -14,12 +10,8 @@ class ZarrDataset: - """A class that reads zarr datasets saved by AnnData from disk. + """A class that reads and manipulates zarr datasets saved by AnnData from disk. Adapted from https://github.com/lilab-bcb/backedarray - - Example - ------- - zarr_data = ZarrDataset("/data/dataset.zarr") """ def __init__(self, store_path: str, mode: str = "r"): @@ -34,9 +26,11 @@ def __init__(self, store_path: str, mode: str = "r"): Examples -------- - >>> zarr_data = ZarrDataset("/data/dataset.zarr") + >>> zarr_data = ZarrDataset("/data/dataset.zarr") """ + import zarr + self.store_path = zarr.DirectoryStore(store_path) self.root = zarr.open_group( self.store_path, mode=mode, chunk_store=self.store_path @@ -48,7 +42,7 @@ def dataset_info(self) -> Dict[str, list]: Returns ------- - dict + d: dict A dict containing information on the content of the dataset, such as keys in the various object attributes. 
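
As a standalone illustration of the two-way weighting helper above: it inverts a label-by-study contingency table, so cells from over-represented (label, study) pairs receive smaller sampling weights. A small worked example of the same computation (the toy labels and studies are made up):

import numpy as np
import pandas as pd

labels = ["T cell", "T cell", "T cell", "B cell"]
studies = ["studyA", "studyA", "studyB", "studyB"]

# Same computation as two_way_weighting: 1 / crosstab, with empty cells set to 0.
counts = pd.crosstab(labels, studies)
weights = (1 / counts).replace(np.inf, 0).unstack().to_dict()
print(weights)
# {('studyA', 'B cell'): 0.0, ('studyA', 'T cell'): 0.5,
#  ('studyB', 'B cell'): 1.0, ('studyB', 'T cell'): 1.0}
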
@@ -58,41 +52,18 @@ def dataset_info(self) -> Dict[str, list]: """ d = {} - if "var" in self.root.keys(): - d["var"] = list(self.root["var"].keys()) - if "obs" in self.root.keys(): - d["obs"] = list(self.root["obs"].keys()) - if "X" in self.root.keys(): + if "var" in self.root: + d["var"] = list(self.root["var"]) + if "obs" in self.root: + d["obs"] = list(self.root["obs"]) + if "X" in self.root: d["shape"] = self.root["X"].attrs["shape"] if "layers" in self.root: - d["layers"] = list(self.root["layers"].keys()) + d["layers"] = list(self.root["layers"]) if "uns" in self.root: - d["uns"] = list(self.root["uns"].keys()) + d["uns"] = list(self.root["uns"]) return d - @property - def X(self): - if "X" in self.root.keys(): - X = self.root["X"] - return self.get_matrix(X) - return None - - @property - def X_copy(self): - if "X" in self.root.keys(): - X = self.root["X"] - return self.get_matrix(X, in_mem=True) - return None - - @property - def counts(self): - if "layers" in self.root.keys(): - layers = self.root["layers"] - if "counts" in layers.keys(): - counts = layers["counts"] - return self.get_matrix(counts) - return None - @property def shape(self) -> Tuple[int, int]: """Get the shape of the gene expression matrix. @@ -107,17 +78,17 @@ def shape(self) -> Tuple[int, int]: >>> zarr_data.shape """ - if "X" in self.root.keys(): + if "X" in self.root: return self.root["X"].attrs["shape"] return None @property - def var_index(self) -> pd.Index: + def var_index(self) -> "pandas.Index": """Get the var index. Returns ------- - pandas.Index + var_index: pandas.Index A pandas Index containing the var index. Examples @@ -125,12 +96,12 @@ def var_index(self) -> pd.Index: >>> zarr_data.var_index """ - if "var" in self.root.keys(): + if "var" in self.root: return self.get_annotation_index(self.root["var"]) return None @property - def var(self) -> pd.DataFrame: + def var(self) -> "pandas.DataFrame": """Get the var dataframe. Returns @@ -143,7 +114,9 @@ def var(self) -> pd.DataFrame: >>> zarr_data.var """ - if "var" in self.root.keys(): + import pandas as pd + + if "var" in self.root: var = pd.DataFrame( {x: self.get_var(x) for x in self.dataset_info["var"] if x != "_index"} ) @@ -152,7 +125,7 @@ def var(self) -> pd.DataFrame: return None @property - def obs_index(self) -> pd.Index: + def obs_index(self) -> "pandas.Index": """Get the obs index. Returns @@ -165,12 +138,12 @@ def obs_index(self) -> pd.Index: >>> zarr_data.obs_index """ - if "obs" in self.root.keys(): + if "obs" in self.root: return self.get_annotation_index(self.root["obs"]) return None @property - def obs(self) -> pd.DataFrame: + def obs(self) -> "pandas.DataFrame": """Get the obs dataframe. Returns @@ -183,7 +156,9 @@ def obs(self) -> pd.DataFrame: >>> zarr_data.obs """ - if "obs" in self.root.keys(): + import pandas as pd + + if "obs" in self.root: obs = pd.DataFrame( {x: self.get_obs(x) for x in self.dataset_info["obs"] if x != "_index"} ) @@ -191,13 +166,61 @@ def obs(self) -> pd.DataFrame: return obs return None - def set_X(self, matrix: Union[csr_matrix, csc_matrix, coo_matrix]) -> None: + def get_X(self, in_mem: bool = False): + """Get the X matrix backed by zarr storage. + + Parameters + ---------- + in_mem: bool, default: False + Return the full matrix in memory rather than a reference to zarr group. + + Returns + ------- + scipy.sparse.csr_matrix, scipy.sparse.csc_matrix, scipy.sparse.coo_matrix + The sparse X matrix. 
+
+        Examples
+        --------
+        >>> zarr_data.get_X()
+        """
+
+        if "X" in self.root:
+            X = self.root["X"]
+            return self.get_matrix(X, in_mem=in_mem)
+        return None
+
+    def get_counts(self, in_mem: bool = False):
+        """Get the count matrix backed by zarr storage.
+
+        Parameters
+        ----------
+        in_mem: bool, default: False
+            Return the full matrix in memory rather than a reference to zarr group.
+
+        Returns
+        -------
+        scipy.sparse.csr_matrix, scipy.sparse.csc_matrix, scipy.sparse.coo_matrix
+            The sparse counts matrix.
+
+        Examples
+        --------
+        >>> zarr_data.get_counts()
+        """
+
+        if "layers" in self.root:
+            layers = self.root["layers"]
+            if "counts" in layers:
+                counts = layers["counts"]
+                return self.get_matrix(counts, in_mem=in_mem)
+        return None
+
+    def set_X(self, matrix: Union[csr_matrix, csc_matrix, coo_matrix]):
         """Set the X sparse matrix.
         This will overwrite the current stored X.
 
         Parameters
         ----------
-        matrix: Union[scipy.sparse.csr_matrix, scipy.sparse.csc_matrix, scipy.sparse.coo_matrix]
+        matrix: csr_matrix, csc_matrix, coo_matrix
             The sparse matrix.
 
         Examples
@@ -210,28 +233,25 @@ def set_X(
 
     def append_X(
         self, matrix: Union[csr_matrix, csc_matrix], axis: Optional[int] = None
-    ) -> None:
+    ):
         """Append to the X sparse matrix.
+        Only row-wise concatenation for csr_matrix.
+        Only col-wise concatenation for csc_matrix.
 
         Parameters
         ----------
-        matrix: Union[scipy.sparse.csr_matrix, scipy.sparse.csc_matrix]
+        matrix: csr_matrix, csc_matrix
             The sparse matrix.
 
-        Notes
-        -----
-        Only row-wise concatentation for csr_matrix.
-        Only col-wise concatentation for csc_matrix.
-
         Examples
         --------
         >>> zarr_data.append_X(matrix)
         """
 
-        if "X" in self.root.keys():
+        if "X" in self.root:
             self.append_matrix(self.root["X"], matrix, axis)
 
-    def get_var(self, column: str) -> Union[np.ndarray, pd.Categorical]:
+    def get_var(self, column: str) -> Union["numpy.ndarray", "pandas.Categorical"]:
         """Get data.var[column] data.
 
         Parameters
@@ -249,17 +269,17 @@ def get_var(
         >>> zarr_data.get_var("symbol")
         """
 
-        if "var" in self.root.keys():
+        if "var" in self.root:
             return self.get_annotation_column(self.root["var"], column)
         return None
 
-    def get_obs(self, column: str) -> Union[np.ndarray, pd.Categorical]:
+    def get_obs(self, column: str) -> Union["numpy.ndarray", "pandas.Categorical"]:
         """Get data.obs[column] data.
 
         Parameters
         ----------
         column: str,
-            Column in obs.
+            Column name in obs.
 
         Returns
         -------
@@ -271,7 +291,7 @@ def get_obs(
         >>> zarr_data.get_obs("celltype_name")
         """
 
-        if "obs" in self.root.keys():
+        if "obs" in self.root:
             return self.get_annotation_column(self.root["obs"], column)
         return None
 
@@ -285,7 +305,7 @@ def get_uns(
 
         Returns
        -------
-        undefined
+        object
             The data in data.uns[key] in the format it was stored as.
 
         Examples
@@ -293,121 +313,191 @@ def get_uns(
         >>> zarr_data.get_uns("orig_genes")
         """
 
-        if "uns" in self.root.keys():
+        if "uns" in self.root:
             group = self.root["uns"]
             if key in group:
                 return group[key][...]
         return None
 
-    def get_cell(self, idx: int) -> Union[csr_matrix, csc_matrix]:
-        """Get gene expression data for one cell as sparse matrix.
+    def get_row(self, group, idx: int) -> Union[csr_matrix, coo_matrix]:
+        """Get sparse row data as sparse matrix.
 
         Parameters
         ----------
+        group: zarr.hierarchy.Group
+            A zarr group
         idx: int,
             Numerical index of the cell.
+ Returns + ------- + scipy.sparse.csr_matrix, scipy.sparse.coo_matrix + Row data as sparse matrix. + + Examples + -------- + >>> zarr_data.get_row(group, 42) + """ + + encoding_type = group.attrs["encoding-type"] + + if encoding_type == "csr_matrix": + return self.row_slice_csr(group, idx) + elif encoding_type == "coo_matrix": + return self.slice_coo(group, idx, axis=0) + raise RuntimeError( + f"Unsupported encoding-type for row slicing: {encoding_type}." + ) + + def get_col(self, group, idx: int) -> Union[csc_matrix, coo_matrix]: + """Get sparse column data as sparse matrix. + + Parameters + ---------- + group: zarr.hierarchy.Group + A zarr group + idx: int, + Numerical index of the cell. + + Returns + ------- + scipy.sparse.csc_matrix, scipy.sparse.coo_matrix + Column data as sparse matrix. + + Examples + -------- + >>> zarr_data.get_col(group, 42) + """ + + encoding_type = group.attrs["encoding-type"] + + if encoding_type == "csc_matrix": + return self.col_slice_csc(group, idx) + elif encoding_type == "coo_matrix": + return self.slice_coo(group, idx, axis=1) + raise RuntimeError( + f"Unsupported encoding-type for col slicing: {encoding_type}." + ) + + def get_cell(self, idx: int) -> Union[csr_matrix, coo_matrix]: + """Get gene expression data for one cell row as sparse matrix. + + Parameters + ---------- + idx: int, + Numerical index of the cell. + + Returns + ------- + scipy.sparse.csr_matrix, scipy.sparse.coo_matrix + Cell row data as sparse matrix. + Examples -------- >>> zarr_data.get_cell(42) """ - if "X" in self.root.keys(): - X = self.root["X"] - encoding_type = X.attrs["encoding-type"] - - if encoding_type == "csr_matrix": - return self.row_slice_csr(X, idx) - elif encoding_type == "coo_matrix": - return self.slice_coo(X, idx, axis=0) - raise RuntimeError( - f"Unsupported encoding-type for row slicing: {encoding_type}." - ) + if "X" in self.root: + return self.get_row(self.root["X"], idx) return None - def get_layer_cell(self, layer_key: str, idx: int) -> Union[csr_matrix, csc_matrix]: - """Get data for one cell from a layer as sparse matrix. + def get_layer_cell(self, layer_key: str, idx: int) -> Union[csr_matrix, coo_matrix]: + """Get data for one cell row from a layer as sparse matrix. Parameters ---------- idx: int, Numerical index of the cell. + Returns + ------- + scipy.sparse.csr_matrix, scipy.sparse.coo_matrix + Cell row data as sparse matrix. + Examples -------- >>> zarr_data.get_layer_cell(42) """ - if "layers" in self.root.keys(): - layers = self.root["layers"] - if layer_key in layers.keys(): - X = layers[layer_key] - encoding_type = X.attrs["encoding-type"] - - if encoding_type == "csr_matrix": - return self.row_slice_csr(X, idx) - elif encoding_type == "coo_matrix": - return self.slice_coo(X, idx, axis=0) - raise RuntimeError( - f"Unsupported encoding-type for row slicing: {encoding_type}." - ) + if "layers" in self.root: + if layer_key in self.root["layers"]: + return self.get_row(self.root["layers"][layer_key], idx) return None - def get_gene(self, idx: int) -> Union[csr_matrix, csc_matrix]: - """Get gene expression data for one gene as sparse matrix. + def get_gene(self, idx: int) -> Union[csc_matrix, coo_matrix]: + """Get gene expression data for one gene column as sparse matrix. Parameters ---------- idx: int, Numerical index of the gene. + Returns + ------- + scipy.sparse.csc_matrix, scipy.sparse.coo_matrix + Gene column data as sparse matrix. 
+ Examples -------- >>> zarr_data.get_gene(42) """ - if "X" in self.root.keys(): - X = self.root["X"] - encoding_type = X.attrs["encoding-type"] - - if encoding_type == "csc_matrix": - return self.col_slice_csc(X, idx) - elif encoding_type == "coo_matrix": - return self.slice_coo(X, idx, axis=1) - raise RuntimeError( - f"Unsupported encoding-type for col slicing: {encoding_type}." - ) + if "X" in self.root: + return self.get_col(self.root["X"], idx) return None - def get_layer_gene(self, layer_key: str, idx: int) -> Union[csr_matrix, csc_matrix]: - """Get data for one gene from a layer as sparse matrix. + def get_layer_gene(self, layer_key: str, idx: int) -> Union[csc_matrix, coo_matrix]: + """Get data for one gene column from a layer as sparse matrix. Parameters ---------- + layer_key: str + The layer name. idx: int, Numerical index of the cell. + Returns + ------- + scipy.sparse.csc_matrix, scipy.sparse.coo_matrix + Gene column data as sparse matrix. + Examples -------- >>> zarr_data.get_layer_gene(42) """ - if "layers" in self.root.keys(): - layers = self.root["layers"] - if layer_key in layers.keys(): - X = layers[layer_key] - encoding_type = X.attrs["encoding-type"] - - if encoding_type == "csc_matrix": - return self.col_slice_csc(X, idx) - elif encoding_type == "coo_matrix": - return self.slice_coo(X, idx, axis=1) - raise RuntimeError( - f"Unsupported encoding-type for col slicing: {encoding_type}." - ) + if "layers" in self.root: + if layer_key in self.root["layers"]: + return self.get_col(self.root["layers"][layer_key], idx) return None - def slice_with(self, group, idx: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + def slice_with( + self, group, idx: int + ) -> Tuple["numpy.ndarray", "numpy.ndarray", "numpy.ndarray"]: + """Slice a sparse matrix, with its directional specification. + i.e. row-wise for csr, column-wise for csc. + + Parameters + ---------- + group: zarr.hierarchy.Group + A zarr group. + idx: int, + Numerical index of the cell. + + Returns + ------- + data: numpy.ndarray + Sparse matrix data list. + indices: numpy.ndarray + Sparse matrix indices. + indptr: numpy.ndarray + Sparse matrix indptr. + + Examples + -------- + >>> zarr_data.slice_with(group, 42) + """ + data = group["data"] indices = group["indices"] indptr = group["indptr"] @@ -420,7 +510,32 @@ def slice_with(self, group, idx: int) -> Tuple[np.ndarray, np.ndarray, np.ndarra def slice_across( self, group, idx: int - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> Tuple["numpy.ndarray", "numpy.ndarray", "numpy.ndarray"]: + """Slice a sparse matrix, across its directional specification. + i.e. column-wise for csr, row-wise for csc. + This can be slow for large matrices. + + Parameters + ---------- + group: zarr.hierarchy.Group + A zarr group. + idx: int, + Numerical index of the cell. + + Returns + ------- + data: numpy.ndarray + Sparse matrix data list. + indices: numpy.ndarray + Sparse matrix indices. + indptr: numpy.ndarray + Sparse matrix indptr. + + Examples + -------- + >>> zarr_data.slice_across(group, 42) + """ + data = group["data"] indices = group["indices"] indptr = group["indptr"] @@ -439,16 +554,75 @@ def slice_across( return (new_data, new_indices, new_indptr) def row_slice_csr(self, group, idx: int) -> csr_matrix: + """Row slice a sparse csr matrix. + + Parameters + ---------- + group: zarr.hierarchy.Group + A zarr group. + idx: int, + Numerical index of the cell. + + Returns + ------- + scipy.sparse.csr_matrix + Sparse csr matrix slice for one row. 
+ + Examples + -------- + >>> zarr_data.row_slice_csr(group, 42) + """ + new_data, new_indices, new_indptr = self.slice_with(group, idx) shape = group.attrs["shape"] return csr_matrix((new_data, new_indices, new_indptr), shape=(1, shape[1])) - def col_slice_csc(self, group: csc_matrix, idx: int) -> csc_matrix: + def col_slice_csc(self, group, idx: int) -> csc_matrix: + """Column slice a sparse csc matrix. + + Parameters + ---------- + group: zarr.hierarchy.Group + A zarr group. + idx: int, + Numerical index of the cell. + + Returns + ------- + scipy.sparse.csc_matrix + Sparse csc matrix slice for one column. + + Examples + -------- + >>> zarr_data.col_slice_csc(group, 42) + """ + new_data, new_indices, new_indptr = self.slice_with(group, idx) shape = group.attrs["shape"] return csc_matrix((new_data, new_indices, new_indptr), shape=(shape[0], 1)) def slice_coo(self, group, idx: int, axis: int) -> coo_matrix: + """Slice a sparse coo matrix. + + Parameters + ---------- + group: zarr.hierarchy.Group + A zarr group. + idx: int, + Numerical index of the cell. + axis: int + The axis along which to slice. + + Returns + ------- + scipy.sparse.coo_matrix + Sparse coo matrix sliced for one row or column. + + Examples + -------- + >>> zarr_data.slice_coo(group, 42, 0) + """ + data = group["data"] row = group["row"] col = group["col"] @@ -472,6 +646,25 @@ def slice_coo(self, group, idx: int, axis: int) -> coo_matrix: def get_matrix( self, group, in_mem: bool = False ) -> Union[csr_matrix, csc_matrix, coo_matrix]: + """Get the sparse matrix from zarr group. + + Parameters + ---------- + group: zarr.hierarchy.Group + A zarr group. + in_mem: bool, default: False + Return the full matrix in memory rather than a reference to zarr group. + + Returns + ------- + scipy.sparse.csr_matrix, scipy.sparse.csc_matrix, scipy.sparse.coo_matrix + Sparse matrix. + + Examples + -------- + >>> zarr_data.get_matrix(group) + """ + encoding_type = group.attrs["encoding-type"] mtx = ARRAY_FORMATS[encoding_type]( tuple(group.attrs["shape"]), dtype=group["data"].dtype @@ -496,9 +689,22 @@ def get_matrix( raise RuntimeError(f"Unsupported encoding-type: {encoding_type}.") return mtx - def set_matrix( - self, group, matrix: Union[csr_matrix, csc_matrix, coo_matrix] - ) -> None: + def set_matrix(self, group, matrix: Union[csr_matrix, csc_matrix, coo_matrix]): + """Set the sparse matrix for a zarr group. + This will overwrite the current data. + + Parameters + ---------- + group: zarr.hierarchy.Group + A zarr group. + matrix: scipy.sparse.csr_matrix, scipy.sparse.csc_matrix, scipy.sparse.coo_matrix + A sparse matrix. + + Examples + -------- + >>> zarr_data.set_matrix(group, matrix) + """ + encoding_type = type(matrix).__name__ group.attrs.setdefault("encoding-type", encoding_type) group.attrs.setdefault("encoding-version", "0.1.0") @@ -519,7 +725,23 @@ def set_matrix( def append_matrix( self, group, matrix: Union[csr_matrix, csc_matrix], axis: Optional[int] = None - ) -> None: + ): + """Append a sparse matrix for a zarr group. + + Parameters + ---------- + group: zarr.hierarchy.Group + A zarr group. + matrix: scipy.sparse.csr_matrix, scipy.sparse.csc_matrix, scipy.sparse.coo_matrix + A sparse matrix. 
+ + Examples + -------- + >>> zarr_data.append_matrix(group, matrix) + """ + + import numpy as np + encoding_type = group.attrs["encoding-type"] shape = group.attrs["shape"] @@ -604,7 +826,26 @@ def append_matrix( col[orig_data_size:] = matrix.col + append_offset group.attrs["shape"] = new_shape - def get_annotation_index(self, group) -> pd.Index: + def get_annotation_index(self, group) -> "pandas.Index": + """Get the annotation index for a zarr group. + + Parameters + ---------- + group: zarr.hierarchy.Group + A zarr group. + + Returns + ------- + pandas.Index + The annotation index. + + Examples + -------- + >>> zarr_data.get_annotation_index(group) + """ + + import pandas as pd + group_index_field = group.attrs["_index"] idx = group[group_index_field][...] if pd.api.types.is_object_dtype(idx): @@ -613,13 +854,33 @@ def get_annotation_index(self, group) -> pd.Index: def get_annotation_column( self, group, column: str - ) -> Union[np.ndarray, pd.Categorical]: + ) -> Union["numpy.ndarray", "pandas.Categorical"]: + """Get an annotation column for a zarr group. + + Parameters + ---------- + group: zarr.hierarchy.Group + A zarr group. + column: str + The column name. + + Returns + ------- + numpy.ndarray, pandas.Categorical + The annotation column data, as a pandas categorical series + if the data is categorical, otherwise as a numpy ndarray. + + Examples + -------- + >>> zarr_data.get_annotation_column(group, "sample") + """ + + import pandas as pd + import zarr + if column in group: series = group[column] - if ( - isinstance(series, zarr.hierarchy.Group) - and "categories" in series.keys() - ): + if isinstance(series, zarr.hierarchy.Group) and "categories" in series: categories = series["categories"][...] if pd.api.types.is_object_dtype(categories): categories = categories.astype(str) @@ -645,9 +906,9 @@ def get_annotation_column( return values return None - def set_annotation(self, annotation: str, df: pd.DataFrame): + def set_annotation(self, annotation: str, df: "pandas.DataFrame"): """Store annotation (i.e. obs, var) from a dataframe. - This will overwrite the current stored annotation. + This will overwrite the current data. Parameters ---------- @@ -659,6 +920,9 @@ def set_annotation(self, annotation: str, df: pd.DataFrame): >>> zarr_data.set_annotation("obs", df) """ + import numcodecs + import pandas as pd + anno = self.root.create_group(annotation, overwrite=True) anno.attrs.setdefault("_index", "_index") anno.attrs.setdefault("column-order", list(df.columns)) @@ -708,7 +972,7 @@ def set_annotation(self, annotation: str, df: pd.DataFrame): anno[k].attrs.setdefault("encoding-type", "array") anno[k].attrs.setdefault("encoding-version", "0.2.0") - def append_annotation(self, annotation: str, df: pd.DataFrame): + def append_annotation(self, annotation: str, df: "pandas.DataFrame"): """Append annotation (i.e. obs, var) from a dataframe. Parameters @@ -721,8 +985,10 @@ def append_annotation(self, annotation: str, df: pd.DataFrame): >>> zarr_data.append_annotation("obs", df) """ + import pandas as pd + group = self.root[annotation] - columns = list(self.root[annotation].keys()) + columns = list(self.root[annotation]) current_df = pd.DataFrame( {x: self.get_annotation_column(group, x) for x in columns if x != "_index"} )
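
To make the CSR slicing above concrete, here is a self-contained sketch of the indptr-based row extraction that slice_with and row_slice_csr perform, using plain numpy arrays in place of the zarr groups (the toy matrix is an assumption for illustration):

import numpy as np
from scipy.sparse import csr_matrix

# A toy 3-cell x 4-gene matrix stored as CSR components, mimicking the zarr layout
# (group["data"], group["indices"], group["indptr"], attrs["shape"]).
data = np.array([5.0, 1.0, 2.0, 7.0])
indices = np.array([0, 2, 1, 3])
indptr = np.array([0, 2, 3, 4])
shape = (3, 4)

idx = 1  # cell row to fetch, as in get_cell(1)
s, e = indptr[idx], indptr[idx + 1]
row = csr_matrix((data[s:e], indices[s:e], np.array([0, e - s])), shape=(1, shape[1]))
print(row.toarray())  # [[0. 2. 0. 0.]]
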