From e2853910438f8fdf86d6387722ae72cbebbbac10 Mon Sep 17 00:00:00 2001
From: Kai Waldrant
Date: Mon, 29 Apr 2024 21:18:16 +0200
Subject: [PATCH 01/13] initial config file

---
 src/common/library.bib                        |  8 ++++
 .../methods/scgpt/config.vsh.yaml             | 46 +++++++++++++++++++
 2 files changed, 54 insertions(+)
 create mode 100644 src/tasks/batch_integration/methods/scgpt/config.vsh.yaml

diff --git a/src/common/library.bib b/src/common/library.bib
index 313bfff56d..d8eb59eb17 100644
--- a/src/common/library.bib
+++ b/src/common/library.bib
@@ -329,6 +329,14 @@ @article{cover1967nearest
   url = {https://doi.org/10.1109/tit.1967.1053964}
 }
 
+@article{cui2023scGPT,
+title={scGPT: Towards Building a Foundation Model for Single-Cell Multi-omics Using Generative AI},
+author={Cui, Haotian and Wang, Chloe and Maan, Hassaan and Pang, Kuan and Luo, Fengning and Wang, Bo},
+journal={bioRxiv},
+year={2023},
+publisher={Cold Spring Harbor Laboratory}
+}
+
 @inproceedings{davis2006prauc,
   title = {The relationship between Precision-Recall and {ROC} curves},

diff --git a/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml b/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml
new file mode 100644
index 0000000000..a1778b1c1c
--- /dev/null
+++ b/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml
@@ -0,0 +1,46 @@
+__merge__: ../../api/comp_method_embedding.yaml
+functionality:
+  name: scgpt_embedding
+  info:
+    label: scGPT Embedding
+    summary: Generation of cell embeddings for the integration of single cell transcriptomic count data using scGPT.
+    description: Generation of cell embeddings for the integration of single cell transcriptomic count data using scGPT.
+    reference: cui2023scGPT
+    documentation_url: https://scgpt.readthedocs.io/en/latest
+    repository_url: https://github.com/bowang-lab/scGPT
+  arguments:
+    - name: "--model"
+      type: file
+      direction: input
+      required: true
+      example: best_model.pt
+      description: |
+        Path to scGPT model file.
+    - name: "--model_vocab"
+      type: file
+      direction: input
+      required: true
+      example: vocab.json
+      description: |
+        Path to scGPT model vocabulary file.
+    - name: "--model_config"
+      type: file
+      direction: input
+      required: true
+      example: args.json
+      description: |
+        Path to scGPT model config file.
+  resources:
+    - type: python_script
+      path: script.py
+platforms:
+  - type: docker
+    image: ghcr.io/openproblems-bio/base_pytorch_nvidia:1.0.4
+    setup:
+      - type: python
+        pypi:
+          - scgpt
+  - type: nextflow
+    directives:
+      label: [hightime, highmem, midcpu, gpu]
+ 
\ No newline at end of file

From 4abf355aa4ffe1e8066b1b188fc19912a7c12f11 Mon Sep 17 00:00:00 2001
From: Kai Waldrant
Date: Mon, 29 Apr 2024 21:37:45 +0200
Subject: [PATCH 02/13] WIP script

---
 .../batch_integration/methods/scgpt/script.py | 25 +++++++++++++++++++
 1 file changed, 25 insertions(+)
 create mode 100644 src/tasks/batch_integration/methods/scgpt/script.py

diff --git a/src/tasks/batch_integration/methods/scgpt/script.py b/src/tasks/batch_integration/methods/scgpt/script.py
new file mode 100644
index 0000000000..b8dc2099d5
--- /dev/null
+++ b/src/tasks/batch_integration/methods/scgpt/script.py
@@ -0,0 +1,25 @@
+import numpy as np
+import anndata as ad
+import json
+from scgpt.tokenizer.gene_tokenizer import GeneVocab
+from scgpt.model import TransformerModel
+from scgpt.utils.util import load_pretrained
+import torch
+
+## VIASH START
+
+par = {
+    'input': 'resources_test/batch_integration/pancreas/dataset.h5ad',
+    'model': 'resources_test/batch_integration/scgpt/pretrained_model',
+    'model_config': 'resources_test/batch_integration/scgpt/pretrained_model/config.json',
+    'model_vocab': 'resources_test/batch_integration/scgpt/pretrained_model/vocab.json',
+    'output': 'output.h5ad'
+}
+
+meta = {
+    'functionality_name' : 'scgpt',
+}
+
+## VIASH END
+
+

From 03f46bc4b89dfd66d0afba63b35b7a4fd1cc963a Mon Sep 17 00:00:00 2001
From: Kai Waldrant
Date: Tue, 30 Apr 2024 13:32:09 +0200
Subject: [PATCH 03/13] WIP script

---
 .../methods/scgpt/config.vsh.yaml             |  4 ++
 .../batch_integration/methods/scgpt/script.py | 72 +++++++++++++++++--
 2 files changed, 70 insertions(+), 6 deletions(-)

diff --git a/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml b/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml
index a1778b1c1c..cdaa5b6904 100644
--- a/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml
+++ b/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml
@@ -30,6 +30,10 @@ functionality:
       example: args.json
       description: |
         Path to scGPT model config file.
+    - name: --n_hvg
+      type: integer
+      default: 2000
+      description: Number of highly variable genes to use.
   resources:
     - type: python_script
       path: script.py

diff --git a/src/tasks/batch_integration/methods/scgpt/script.py b/src/tasks/batch_integration/methods/scgpt/script.py
index b8dc2099d5..0c0d2fcaf1 100644
--- a/src/tasks/batch_integration/methods/scgpt/script.py
+++ b/src/tasks/batch_integration/methods/scgpt/script.py
@@ -1,6 +1,8 @@
 import numpy as np
 import anndata as ad
 import json
+import scipy
+from sklearn.model_selection import train_test_split
 from scgpt.tokenizer.gene_tokenizer import GeneVocab
 from scgpt.model import TransformerModel
 from scgpt.utils.util import load_pretrained
@@ -9,17 +11,75 @@
 ## VIASH START
 
 par = {
     "input": "resources_test/batch_integration/pancreas/dataset.h5ad",
     "model": "resources_test/batch_integration/scgpt/pretrained_model",
     "model_config": "resources_test/batch_integration/scgpt/pretrained_model/config.json",
     "model_vocab": "resources_test/batch_integration/scgpt/pretrained_model/vocab.json",
     "output": "output.h5ad"
 }
 
 meta = {
     "functionality_name" : "scgpt",
 }
 
 ## VIASH END
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+print("Load data", flush=True)
+adata = ad.read_h5ad(par["input"])
+
+print("Preprocess data", flush=True)
+
+if par["n_hvg"]:
+    print(f"Select top {par["n_hvg"]} high variable genes", flush=True)
+    idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][:par["n_hvg"]]
+    adata = adata[:, idx].copy()
+
+vocab = GeneVocab.from_file(par["model_vocab"])
+
+adata.var["id_in_vocabulary"] = [ 1 if feature in vocab else -1 for feature in adata.var["feature_name"]]
+feature_ids_in_vocab = np.array(adata.var["id_in_vocabulary"])
+adata = adata[:, adata.var["id_in_vocab"] >= 0]
+
+print("Load model config", flush=True)
+
+with open(par["model_config"], "r") as f:
+    model_configs = json.load(f)
+
+embsize = model_configs["embsize"]
+nhead = model_configs["nheads"]
+d_hid = model_configs["d_hid"]
+nlayers = model_configs["nlayers"]
+n_layers_cls = model_configs["n_layers_cls"]
+
+print("Tokenize input data", flush=True)
+all_counts = (
+    adata.layers["normalized"].A
+    if scipy.sparse.issparse(adata.layers["normalized"])
+    else adata.layers["normalized"]
+)
+genes = adata.var["feature_name"].tolist()
+
+celltypes_labels = adata.obs["label"].tolist()  # make sure count from 0
+num_types = len(set(celltypes_labels))
+celltypes_labels = np.array(celltypes_labels)
+
+batch_ids = adata.obs["batch"].tolist()
+num_batch_types = len(set(batch_ids))
+batch_ids = np.array(batch_ids)
+
+(
+    train_data,
+    valid_data,
+    train_celltype_labels,
+    valid_celltype_labels,
+    train_batch_labels,
+    valid_batch_labels,
+) = train_test_split(
+    all_counts, celltypes_labels, batch_ids, test_size=0.1, shuffle=True
+)
+
+vocab.set_default_index(vocab["<pad>"])
+gene_ids = np.array(vocab(genes), dtype=int)
\ No newline at end of file

From 0d1216cfc465ff9c13305114986c1976a214d805 Mon Sep 17 00:00:00 2001
From: Kai Waldrant
Date: Tue, 30 Apr 2024 17:14:05 +0200
Subject: [PATCH 04/13] Update cross check genes

---
 .../batch_integration/methods/scgpt/script.py | 17 +++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/src/tasks/batch_integration/methods/scgpt/script.py b/src/tasks/batch_integration/methods/scgpt/script.py
index 0c0d2fcaf1..55835a892e 100644
--- a/src/tasks/batch_integration/methods/scgpt/script.py
+++ b/src/tasks/batch_integration/methods/scgpt/script.py
@@ -15,6 +15,8 @@
     "model": "resources_test/batch_integration/scgpt/pretrained_model",
     "model_config": "resources_test/batch_integration/scgpt/pretrained_model/config.json",
     "model_vocab": "resources_test/batch_integration/scgpt/pretrained_model/vocab.json",
+    "pad_token": "<pad>",
+    "n_bins": 51,
     "output": "output.h5ad"
 }
@@ -37,12 +39,23 @@
     idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][:par["n_hvg"]]
     adata = adata[:, idx].copy()
 
+print("Cross check genes", flush=True)
+
+pad_token = par["pad_token"]
+special_tokens = [pad_token, "<cls>", "<eoc>"]
+
 vocab = GeneVocab.from_file(par["model_vocab"])
+[vocab.append_token(s) for s in special_tokens if s not in vocab]
 
-adata.var["id_in_vocabulary"] = [ 1 if feature in vocab else -1 for feature in adata.var["feature_name"]]
-feature_ids_in_vocab = np.array(adata.var["id_in_vocabulary"])
+adata.var["id_in_vocab"] = [ 1 if feature in vocab else -1 for feature in adata.var["feature_name"]]
+feature_ids_in_vocab = np.array(adata.var["id_in_vocab"])
 adata = adata[:, adata.var["id_in_vocab"] >= 0]
+
+print("Binning data", flush=True)
+
+
+
 print("Load model config", flush=True)
 
 with open(par["model_config"], "r") as f:

From 0dd4fce6cc5386932c84f3c3ac0fa7aced909c70 Mon Sep 17 00:00:00 2001
From: Kai Waldrant
Date: Tue, 30 Apr 2024 18:01:32 +0200
Subject: [PATCH 05/13] Add binning

---
 .../batch_integration/methods/scgpt/script.py | 35 ++++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/src/tasks/batch_integration/methods/scgpt/script.py b/src/tasks/batch_integration/methods/scgpt/script.py
index 55835a892e..920e0210f6 100644
--- a/src/tasks/batch_integration/methods/scgpt/script.py
+++ b/src/tasks/batch_integration/methods/scgpt/script.py
@@ -28,6 +28,17 @@
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+def _digitize(x: np.ndarray, bins: np.ndarray) -> np.ndarray:
+    assert x.ndim == 1 and bins.ndim == 1
+    left_digits = np.digitize(x, bins)
+    right_digits = np.digitize(x, bins, right=True)
+    rands = np.random.rand(len(x))  # uniform random numbers
+    digits = rands * (right_digits - left_digits) + left_digits
+    digits = np.ceil(digits)
+    smallest_dtype = np.min_scalar_type(digits.max().astype(np.uint))  # Already checked for non-negative values
+    digits = digits.astype(smallest_dtype)
+    return digits
+
 print("Load data", flush=True)
 adata = ad.read_h5ad(par["input"])
 
@@ -39,6 +50,7 @@
     idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][:par["n_hvg"]]
     adata = adata[:, idx].copy()
 
+
 print("Cross check genes", flush=True)
 
 pad_token = par["pad_token"]
 special_tokens = [pad_token, "<cls>", "<eoc>"]
@@ -54,7 +66,28 @@
 
 print("Binning data", flush=True)
 
-
+layer_data = adata.layers["normalized"]
+
+binned_rows = []
+bin_edges = []
+for row_number in range(adata.layers["normalized"].indptr.size-1):
+    row_start_index, row_end_index = layer_data.indptr[row_number], layer_data.indptr[row_number+1]
+    non_zero_row = layer_data.data[row_start_index:row_end_index]
+    if non_zero_row.max() == 0:
+        binned_rows.append(np.zeros_like(non_zero_row, dtype=np.int8))
+        bin_edges.append(np.array([0] * par["n_bins"]))
+        continue
+    bins = np.quantile(non_zero_row, np.linspace(0, 1, par["n_bins"] - 1))
+    non_zero_digits = _digitize(non_zero_row, bins)
+    assert non_zero_digits.min() >= 1
+    assert non_zero_digits.max() <= par["n_bins"] - 1
par["n_bins"] - 1 + binned_rows.append(non_zero_digits) + bin_edges.append(np.concatenate([[0], bins])) + +adata.layers["binned"] = scipy.sparse.csc_matrix((np.concatenate(binned_rows, casting="same_kind"), + layer_data.indices, layer_data.indptr), shape=layer_data.shape) + +adata.obsm["bin_edges"] = np.stack(bin_edges) print("Load model config", flush=True) From 8dc7fb462c2acb81a2dc52e12b3543024ed74f06 Mon Sep 17 00:00:00 2001 From: Kai Waldrant Date: Fri, 3 May 2024 14:41:48 +0200 Subject: [PATCH 06/13] Add parameters --- .../methods/scgpt/config.vsh.yaml | 15 +++++++++++++++ .../batch_integration/methods/scgpt/script.py | 2 ++ 2 files changed, 17 insertions(+) diff --git a/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml b/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml index cdaa5b6904..306d22bfb6 100644 --- a/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml +++ b/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml @@ -34,6 +34,21 @@ functionality: type: integer default: 2000 description: Number of highly variable genes to use. + - name: --pad_token + type: string + default: "" + description: Padding token. + - name: --pad_value + type: integer + default: -2 + description: Padding value. + - name: --max_seq_len + type: integer + description: Maximum sequence length. + - name: --n_bins + type: integer + default: 51 + description: Number of bins. resources: - type: python_script path: script.py diff --git a/src/tasks/batch_integration/methods/scgpt/script.py b/src/tasks/batch_integration/methods/scgpt/script.py index 920e0210f6..c9fd4e4933 100644 --- a/src/tasks/batch_integration/methods/scgpt/script.py +++ b/src/tasks/batch_integration/methods/scgpt/script.py @@ -16,6 +16,8 @@ "model_config": "resources_test/batch_integration/scgpt/pretrained_model/config.json", "model_vocab": "resources_test/batch_integration/scgpt/pretrained_model/vocab.json", "pad_token": "", + "max_seq_len": None, + "pad_value": -2, "n_bins": 51, "output": "output.h5ad" } From 1ba5dd5d8accd9f0020e570585324bbcb7278dd1 Mon Sep 17 00:00:00 2001 From: Kai Waldrant Date: Fri, 3 May 2024 16:11:47 +0200 Subject: [PATCH 07/13] Add embedding --- .../batch_integration/methods/scgpt/script.py | 157 ++++++++++++++---- 1 file changed, 128 insertions(+), 29 deletions(-) diff --git a/src/tasks/batch_integration/methods/scgpt/script.py b/src/tasks/batch_integration/methods/scgpt/script.py index c9fd4e4933..6ffcad6288 100644 --- a/src/tasks/batch_integration/methods/scgpt/script.py +++ b/src/tasks/batch_integration/methods/scgpt/script.py @@ -4,6 +4,7 @@ import scipy from sklearn.model_selection import train_test_split from scgpt.tokenizer.gene_tokenizer import GeneVocab +from scgpt.tokenizer import tokenize_and_pad_batch from scgpt.model import TransformerModel from scgpt.utils.util import load_pretrained import torch @@ -12,14 +13,15 @@ par = { "input": "resources_test/batch_integration/pancreas/dataset.h5ad", - "model": "resources_test/batch_integration/scgpt/pretrained_model", - "model_config": "resources_test/batch_integration/scgpt/pretrained_model/config.json", - "model_vocab": "resources_test/batch_integration/scgpt/pretrained_model/vocab.json", + "model": "resources_test/scGPT_human/best_model.pt", + "model_config": "resources_test/scGPT_human/args.json", + "model_vocab": "resources_test/scGPT_human/vocab.json", "pad_token": "", "max_seq_len": None, "pad_value": -2, "n_bins": 51, - "output": "output.h5ad" + "output": "output.h5ad", + "n_hvg": 2000, } meta = { @@ -48,7 +50,7 @@ def 
 print("Preprocess data", flush=True)
 
 if par["n_hvg"]:
-    print(f"Select top {par["n_hvg"]} high variable genes", flush=True)
+    print(f"Select top {par['n_hvg']} high variable genes", flush=True)
     idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][:par["n_hvg"]]
     adata = adata[:, idx].copy()
 
+
 print("Cross check genes", flush=True)
 
 pad_token = par["pad_token"]
@@ -91,6 +93,43 @@
 adata.obsm["bin_edges"] = np.stack(bin_edges)
 
+
+
+print("Tokenize input data", flush=True)
+
+all_counts = (
+    adata.layers["normalized"].A
+    if scipy.sparse.issparse(adata.layers["normalized"])
+    else adata.layers["normalized"]
+)
+genes = adata.var["feature_name"].tolist()
+
+vocab.set_default_index(vocab["<pad>"])
+ntokens = len(vocab)
+gene_ids = np.array(vocab(genes), dtype=int)
+
+if not par["max_seq_len"]:
+    max_seq_len = adata.var.shape[0] + 1
+else:
+    max_seq_len = par["max_seq_len"]
+
+tokenized_data = tokenize_and_pad_batch(
+    all_counts,
+    gene_ids,
+    max_len=max_seq_len,
+    vocab=vocab,
+    pad_token=pad_token,
+    pad_value=par["pad_value"],
+    append_cls=True,  # append <cls> token at the beginning,
+    include_zero_gene=False,
+    return_pt=True,
+    mod_type=None,
+    vocab_mod=None
+    )
+
+all_gene_ids, all_values = tokenized_data["genes"], tokenized_data["values"]
+padding_mask = all_gene_ids.eq(vocab[pad_token])
+
 print("Load model config", flush=True)
 
 with open(par["model_config"], "r") as f:
@@ -102,32 +141,92 @@
     model_configs = json.load(f)
 
 embsize = model_configs["embsize"]
 nhead = model_configs["nheads"]
 d_hid = model_configs["d_hid"]
 nlayers = model_configs["nlayers"]
 n_layers_cls = model_configs["n_layers_cls"]
 
-print("Tokenize input data", flush=True)
-all_counts = (
-    adata.layers["normalized"].A
-    if scipy.sparse.issparse(adata.layers["normalized"])
-    else adata.layers["normalized"]
+batch_ids = adata.obs["batch"].tolist()
+num_batch_types = len(set(batch_ids))
+
+model = TransformerModel(
+    ntokens,
+    d_model=embsize,
+    nhead=nhead,
+    d_hid=d_hid,
+    nlayers=nlayers,
+    vocab=vocab,
+    dropout=0.5,  # scGPT default, only relevant for fine-tuning applications
+    pad_token=pad_token,
+    pad_value=par["pad_value"],
+    nlayers_cls=3,  # only applicable for decoder-based operations
+    n_cls=1,  # only applicable for decoder-based operations
+    do_mvc=False,  # only applicable for decoder-based operations
+    ecs_threshold=0.8,  # only applicable for decoder-based operations
+    do_dab=False,  # only applicable for decoder-based operations
+    use_batch_labels=False,  # only applicable for decoder-based operations
+    num_batch_labels=num_batch_types,
+    domain_spec_batchnorm=True,
+    input_emb_style="continuous",  # scGPT default
+    explicit_zero_prob=False,  #TODO: Parametrize when GPU-based machine types are supported
+    use_fast_transformer=True if device == "cuda" else False,  #TODO: Parametrize when GPU-based machine types are supported
+    # fast_transformer_backend="flash",  #TODO: Parametrize when GPU-based machine types are supported
+    pre_norm=False  #TODO: Parametrize when GPU-based machine types are supported
+    )
+
+load_pretrained(
+    model,
+    torch.load(par["model"], map_location=device),
+    verbose=False
+    )
+
+model.to(device)
+model.eval()
+
+
+cell_embeddings = model.encode_batch(
+    torch.from_numpy(all_gene_ids),
+    torch.from_numpy(all_values).float(),
+    src_key_padding_mask=torch.from_numpy(padding_mask),
+    batch_size=par["batch_size"],
+    batch_labels=torch.from_numpy(batch_ids).long() if par["dsbn"] else None,
+    output_to_cpu=True,
+    time_step=0,
+    return_np=True
 )
-genes = adata.var["feature_name"].tolist()
 
-celltypes_labels = adata.obs["label"].tolist()  # make sure count from 0
adata.obs["label"].tolist() # make sure count from 0 -num_types = len(set(celltypes_labels)) -celltypes_labels = np.array(celltypes_labels) +cell_embeddings = cell_embeddings / np.linalg.norm( + cell_embeddings, axis=1, keepdims=True +) -batch_ids = adata.obs["batch"].tolist() -num_batch_types = len(set(batch_ids)) -batch_ids = np.array(batch_ids) - -( - train_data, - valid_data, - train_celltype_labels, - valid_celltype_labels, - train_batch_labels, - valid_batch_labels, -) = train_test_split( - all_counts, celltypes_labels, batch_ids, test_size=0.1, shuffle=True +print("Store outputs", flush=True) +output = ad.AnnData( + obs=adata.obs[[]], + var=adata.var[[]], + obsm={ + "X_emb": cell_embeddings, + }, + uns={ + "dataset_id": adata.uns["dataset_id"], + "normalization_id": adata.uns["normalization_id"], + "method_id": meta["functionality_name"], + }, ) -vocab.set_default_index(vocab[""]) -gene_ids = np.array(vocab(genes), dtype=int) \ No newline at end of file +print("Write output to file", flush=True) +output.write_h5ad(par["output"], compression="gzip") + +# celltypes_labels = adata.obs["label"].tolist() # make sure count from 0 +# num_types = len(set(celltypes_labels)) +# celltypes_labels = np.array(celltypes_labels) + +# batch_ids = adata.obs["batch"].tolist() +# num_batch_types = len(set(batch_ids)) +# batch_ids = np.array(batch_ids) + +# ( +# train_data, +# valid_data, +# train_celltype_labels, +# valid_celltype_labels, +# train_batch_labels, +# valid_batch_labels, +# ) = train_test_split( +# all_counts, celltypes_labels, batch_ids, test_size=0.1, shuffle=True +# ) + From 347f8db7f9a4a5701ebc6b23f62f89dc3016a171 Mon Sep 17 00:00:00 2001 From: Kai Waldrant Date: Mon, 6 May 2024 10:50:34 +0200 Subject: [PATCH 08/13] Update script --- .../batch_integration/methods/scgpt/config.vsh.yaml | 4 ++++ src/tasks/batch_integration/methods/scgpt/script.py | 13 +++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml b/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml index 306d22bfb6..7c6713dc1b 100644 --- a/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml +++ b/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml @@ -49,6 +49,10 @@ functionality: type: integer default: 51 description: Number of bins. + - name: --batch_size + type: integer + default: 64 + description: Batch size. 
   resources:
     - type: python_script
       path: script.py

diff --git a/src/tasks/batch_integration/methods/scgpt/script.py b/src/tasks/batch_integration/methods/scgpt/script.py
index 6ffcad6288..1ed1e44a84 100644
--- a/src/tasks/batch_integration/methods/scgpt/script.py
+++ b/src/tasks/batch_integration/methods/scgpt/script.py
@@ -22,6 +22,7 @@
     "n_bins": 51,
     "output": "output.h5ad",
     "n_hvg": 2000,
+    "batch_size": 64,
 }
 
 meta = {
@@ -129,6 +130,7 @@ def _digitize(x: np.ndarray, bins: np.ndarray) -> np.ndarray:
 
 all_gene_ids, all_values = tokenized_data["genes"], tokenized_data["values"]
 padding_mask = all_gene_ids.eq(vocab[pad_token])
+padding_mask = padding_mask.numpy()
 
 print("Load model config", flush=True)
 
@@ -141,7 +143,10 @@ def _digitize(x: np.ndarray, bins: np.ndarray) -> np.ndarray:
 nlayers = model_configs["nlayers"]
 n_layers_cls = model_configs["n_layers_cls"]
 
-batch_ids = adata.obs["batch"].tolist()
+batch_id_cats = adata.obs["batch"].astype("category")
+batch_id_labels = batch_id_cats.cat.codes.values
+batch_ids = batch_id_labels.tolist()
+batch_ids = np.array(batch_ids)
 num_batch_types = len(set(batch_ids))
 
 model = TransformerModel(
@@ -180,9 +185,9 @@ def _digitize(x: np.ndarray, bins: np.ndarray) -> np.ndarray:
 
 cell_embeddings = model.encode_batch(
-    torch.from_numpy(all_gene_ids),
-    torch.from_numpy(all_values).float(),
-    src_key_padding_mask=torch.from_numpy(padding_mask),
+    torch.from_numpy(np.array(all_gene_ids)),
+    torch.from_numpy(np.array(all_values)).float(),
+    src_key_padding_mask=torch.from_numpy(np.array(padding_mask)),
     batch_size=par["batch_size"],
     batch_labels=torch.from_numpy(batch_ids).long() if par["dsbn"] else None,
     output_to_cpu=True,
     time_step=0,
     return_np=True

From b65e304dcc9419b240d49c49841bdc8819e467bb Mon Sep 17 00:00:00 2001
From: Kai Waldrant
Date: Mon, 6 May 2024 10:51:35 +0200
Subject: [PATCH 09/13] Add to workflow

---
 .../batch_integration/workflows/run_benchmark/config.vsh.yaml | 1 +
 src/tasks/batch_integration/workflows/run_benchmark/main.nf   | 1 +
 2 files changed, 2 insertions(+)

diff --git a/src/tasks/batch_integration/workflows/run_benchmark/config.vsh.yaml b/src/tasks/batch_integration/workflows/run_benchmark/config.vsh.yaml
index b430734e22..a7602d1896 100644
--- a/src/tasks/batch_integration/workflows/run_benchmark/config.vsh.yaml
+++ b/src/tasks/batch_integration/workflows/run_benchmark/config.vsh.yaml
@@ -69,6 +69,7 @@ functionality:
         - name: batch_integration/methods/scanorama_feature
         - name: batch_integration/methods/scanvi
        - name: batch_integration/methods/scvi
+        - name: batch_integration/methods/scgpt
         - name: batch_integration/control_methods/no_integration_batch
         - name: batch_integration/control_methods/random_embed_cell
         - name: batch_integration/control_methods/random_embed_cell_jitter

diff --git a/src/tasks/batch_integration/workflows/run_benchmark/main.nf b/src/tasks/batch_integration/workflows/run_benchmark/main.nf
index 5543ac91cd..f01d4b6411 100644
--- a/src/tasks/batch_integration/workflows/run_benchmark/main.nf
+++ b/src/tasks/batch_integration/workflows/run_benchmark/main.nf
@@ -27,6 +27,7 @@ workflow run_wf {
     scanorama_feature,
     scanvi,
     scvi,
+    scgpt,
     no_integration_batch,
     random_embed_cell,
     random_embed_cell_jitter,

From 8cc8cd5684e1a3b8568c0d4bda9a6c7bd6f73132 Mon Sep 17 00:00:00 2001
From: Kai Waldrant
Date: Mon, 6 May 2024 11:37:22 +0200
Subject: [PATCH 10/13] change to csr_matrix

---
 src/tasks/batch_integration/methods/scgpt/script.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/tasks/batch_integration/methods/scgpt/script.py b/src/tasks/batch_integration/methods/scgpt/script.py
index 1ed1e44a84..786cde6ab9 100644
--- a/src/tasks/batch_integration/methods/scgpt/script.py
+++ b/src/tasks/batch_integration/methods/scgpt/script.py
@@ -89,7 +89,7 @@ def _digitize(x: np.ndarray, bins: np.ndarray) -> np.ndarray:
     binned_rows.append(non_zero_digits)
     bin_edges.append(np.concatenate([[0], bins]))
 
-adata.layers["binned"] = scipy.sparse.csc_matrix((np.concatenate(binned_rows, casting="same_kind"),
+adata.layers["binned"] = scipy.sparse.csr_matrix((np.concatenate(binned_rows, casting="same_kind"),
                                                   layer_data.indices, layer_data.indptr), shape=layer_data.shape)
 
 adata.obsm["bin_edges"] = np.stack(bin_edges)

From 56c3700e6de0079982c8b94d3525a423586a63d6 Mon Sep 17 00:00:00 2001
From: Kai Waldrant
Date: Mon, 6 May 2024 11:45:44 +0200
Subject: [PATCH 11/13] Remove dsbn par

---
 src/tasks/batch_integration/methods/scgpt/script.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/tasks/batch_integration/methods/scgpt/script.py b/src/tasks/batch_integration/methods/scgpt/script.py
index 786cde6ab9..b5ca5d11df 100644
--- a/src/tasks/batch_integration/methods/scgpt/script.py
+++ b/src/tasks/batch_integration/methods/scgpt/script.py
@@ -189,7 +189,7 @@ def _digitize(x: np.ndarray, bins: np.ndarray) -> np.ndarray:
     torch.from_numpy(np.array(all_values)).float(),
     src_key_padding_mask=torch.from_numpy(np.array(padding_mask)),
     batch_size=par["batch_size"],
-    batch_labels=torch.from_numpy(batch_ids).long() if par["dsbn"] else None,
+    batch_labels=torch.from_numpy(batch_ids).long(),
     output_to_cpu=True,
     time_step=0,
     return_np=True

From da55084ce01409b4ffc56fe6ae28fdc10bcc0643 Mon Sep 17 00:00:00 2001
From: Kai Waldrant
Date: Mon, 6 May 2024 11:53:40 +0200
Subject: [PATCH 12/13] Fix typo

---
 .../methods/{scgpt => scgpt_embedding}/config.vsh.yaml |  0
 .../methods/{scgpt => scgpt_embedding}/script.py       |  0
 .../batch_integration/methods/scgpt_embedding/test.sh  | 10 ++++++++++
 .../workflows/run_benchmark/config.vsh.yaml            |  2 +-
 .../batch_integration/workflows/run_benchmark/main.nf  |  2 +-
 5 files changed, 12 insertions(+), 2 deletions(-)
 rename src/tasks/batch_integration/methods/{scgpt => scgpt_embedding}/config.vsh.yaml (100%)
 rename src/tasks/batch_integration/methods/{scgpt => scgpt_embedding}/script.py (100%)
 create mode 100644 src/tasks/batch_integration/methods/scgpt_embedding/test.sh

diff --git a/src/tasks/batch_integration/methods/scgpt/config.vsh.yaml b/src/tasks/batch_integration/methods/scgpt_embedding/config.vsh.yaml
similarity index 100%
rename from src/tasks/batch_integration/methods/scgpt/config.vsh.yaml
rename to src/tasks/batch_integration/methods/scgpt_embedding/config.vsh.yaml
diff --git a/src/tasks/batch_integration/methods/scgpt/script.py b/src/tasks/batch_integration/methods/scgpt_embedding/script.py
similarity index 100%
rename from src/tasks/batch_integration/methods/scgpt/script.py
rename to src/tasks/batch_integration/methods/scgpt_embedding/script.py
diff --git a/src/tasks/batch_integration/methods/scgpt_embedding/test.sh b/src/tasks/batch_integration/methods/scgpt_embedding/test.sh
new file mode 100644
index 0000000000..8475fe7a47
--- /dev/null
+++ b/src/tasks/batch_integration/methods/scgpt_embedding/test.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+set -e
+
+viash run src/tasks/batch_integration/methods/scgpt/config.vsh.yaml -- \
+  --input "resources_test/batch_integration/pancreas/dataset.h5ad" \
+  --model "resources_test/scGPT_human/best_model.pt" \
+  --model_config "resources_test/scGPT_human/args.json" \
+  --model_vocab "resources_test/scGPT_human/vocab.json" \
--model_vocab "resources_test/scGPT_human/vocab.json" \ + --output "output/temp/scgpt/pancreas.h5ad" \ \ No newline at end of file diff --git a/src/tasks/batch_integration/workflows/run_benchmark/config.vsh.yaml b/src/tasks/batch_integration/workflows/run_benchmark/config.vsh.yaml index a7602d1896..f03df7416d 100644 --- a/src/tasks/batch_integration/workflows/run_benchmark/config.vsh.yaml +++ b/src/tasks/batch_integration/workflows/run_benchmark/config.vsh.yaml @@ -69,7 +69,7 @@ functionality: - name: batch_integration/methods/scanorama_feature - name: batch_integration/methods/scanvi - name: batch_integration/methods/scvi - - name: batch_integration/methods/scgpt + - name: batch_integration/methods/scgpt_embedding - name: batch_integration/control_methods/no_integration_batch - name: batch_integration/control_methods/random_embed_cell - name: batch_integration/control_methods/random_embed_cell_jitter diff --git a/src/tasks/batch_integration/workflows/run_benchmark/main.nf b/src/tasks/batch_integration/workflows/run_benchmark/main.nf index f01d4b6411..5f2ef8906b 100644 --- a/src/tasks/batch_integration/workflows/run_benchmark/main.nf +++ b/src/tasks/batch_integration/workflows/run_benchmark/main.nf @@ -27,7 +27,7 @@ workflow run_wf { scanorama_feature, scanvi, scvi, - scgpt, + scgpt_embedding, no_integration_batch, random_embed_cell, random_embed_cell_jitter, From 63b912cc4dfc12e6119ceff54b8cd10476f4db94 Mon Sep 17 00:00:00 2001 From: Kai Waldrant Date: Wed, 8 May 2024 10:37:43 +0200 Subject: [PATCH 13/13] Add model location to workflow --- src/tasks/batch_integration/workflows/run_benchmark/main.nf | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tasks/batch_integration/workflows/run_benchmark/main.nf b/src/tasks/batch_integration/workflows/run_benchmark/main.nf index 5f2ef8906b..bc7a815069 100644 --- a/src/tasks/batch_integration/workflows/run_benchmark/main.nf +++ b/src/tasks/batch_integration/workflows/run_benchmark/main.nf @@ -27,7 +27,9 @@ workflow run_wf { scanorama_feature, scanvi, scvi, - scgpt_embedding, + scgpt_embedding.run( + args : [model: "s3://openproblems-data/resources/foundation_models/scgpt/scGPT_human/best_model.pt", model_config: "s3://openproblems-data/resources/foundation_models/scgpt/scGPT_human/args.json", model_vocab: "s3://openproblems-data/resources/foundation_models/scgpt/scGPT_human/vocab.json"] + ), no_integration_batch, random_embed_cell, random_embed_cell_jitter,