Skip to content

Commit

Permalink
scgpt in model server
Browse files Browse the repository at this point in the history
  • Loading branch information
amva13 committed Jan 6, 2025
1 parent f0e8810 commit f6108df
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 10 deletions.
2 changes: 2 additions & 0 deletions tdc/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,7 @@ def get_task2category():
"evebio_pharmone_v1_summary_result_table": "tab",
"evebio_pharmone_v1_target_doc": "tab",
"evebio_pharmone_v1_target_table": "tab",
"cellxgene_sample_small": "h5ad",
}

name2id = {
Expand Down Expand Up @@ -1162,6 +1163,7 @@ def get_task2category():
"evebio_pharmone_v1_summary_result_table": 10741542,
"evebio_pharmone_v1_target_doc": 10741536,
"evebio_pharmone_v1_target_table": 10741537,
"cellxgene_sample_small": 10806522,
}

oracle2type = {
Expand Down
16 changes: 7 additions & 9 deletions tdc/model_server/tdc_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
'CYP3A4_Veith-AttentiveFP',
]

model_hub = ["Geneformer"]
model_hub = ["Geneformer", "scGPT"]


class tdc_hf_interface:
Expand Down Expand Up @@ -56,14 +56,12 @@ def load(self):
if self.model_name not in model_hub:
raise Exception("this model is not in the TDC model hub GH repo.")
elif self.model_name == "Geneformer":
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
# tokenizer = AutoTokenizer.from_pretrained("ctheodoris/Geneformer")
model = AutoModelForMaskedLM.from_pretrained(
"ctheodoris/Geneformer")
# pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer)
# pipe = pipeline("fill-mask", model="ctheodoris/Geneformer")
# return pipe
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained("tdc/Geneformer")
return model
elif self.model_name == "scGPT":
from transformers import AutoModel
model = AutoModel.from_pretrained("tdc/scGPT")
return model
raise Exception("Not implemented yet!")

Expand Down
61 changes: 61 additions & 0 deletions tdc/model_server/tokenizers/scgpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import numpy as np
from typing import List, Tuple


def tokenize_batch(
    data: np.ndarray,
    gene_ids: np.ndarray,
    return_pt: bool = True,
    append_cls: bool = True,
    include_zero_gene: bool = False,
    cls_id: str = "<cls>",
) -> List[Tuple]:
    """
    Tokenize a batch of data. Returns a list of tuple (gene_id, count).

    Args:
        data (array-like): A batch of data, with shape (batch_size, n_features).
            n_features equals the number of all genes.
        gene_ids (array-like): A batch of gene ids, with shape (n_features,).
        return_pt (bool): Whether to return torch tensors of gene_ids and counts,
            default to True.
        append_cls (bool): Whether to prepend a ``cls_id`` token (with count 0)
            to every cell's tokens, default to True.
        include_zero_gene (bool): Whether to keep genes with zero expression;
            by default only non-zero entries are kept.
        cls_id (str): Identifier used for the prepended classification token.

    Returns:
        list: A list of tuple (gene_names, counts) of non zero gene expressions.

    Raises:
        ValueError: If data's feature dimension does not match len(gene_ids).
    """
    if data.shape[1] != len(gene_ids):
        raise ValueError(
            f"Number of features in data ({data.shape[1]}) does not match "
            f"number of gene_ids ({len(gene_ids)}).")

    if return_pt:
        # Hoisted out of the per-cell loop; torch stays an optional
        # dependency when return_pt=False.
        import torch

    tokenized_data = []
    for row in data:
        if include_zero_gene:
            values = row
            genes = gene_ids
        else:
            nonzero_idx = np.nonzero(row)[0]
            values = row[nonzero_idx]
            genes = gene_ids[nonzero_idx]
        if append_cls:
            # Cast to object dtype first: np.insert keeps the array's dtype,
            # so inserting "<cls>" into a fixed-width string array (e.g. <U2)
            # would silently truncate it to "<c".
            genes = np.insert(genes.astype(object), 0, cls_id)
            values = np.insert(values, 0, 0)
        if return_pt:
            # Counts are cast to int64 token counts (fractional values truncate).
            values = torch.from_numpy(values).float().to(torch.int64)
        tokenized_data.append((genes, values))
    return tokenized_data


class scGPTTokenizer:
    """Class-level facade over ``tokenize_batch`` for scGPT inputs."""

    def __init__(self):
        # Stateless; kept so callers may instantiate the tokenizer as in
        # the existing tests (tokenizer = scGPTTokenizer()).
        pass

    @classmethod
    def tokenize_cell_vectors(cls, data, gene_names, **kwargs):
        """
        Tokenizing single-cell gene expression vectors formatted as anndata types.

        Args:
            data: expression matrix, shape (n_cells, n_genes) — e.g. ``adata.X.toarray()``.
            gene_names: gene identifiers, shape (n_genes,).
            **kwargs: forwarded to ``tokenize_batch`` (``return_pt``,
                ``append_cls``, ``include_zero_gene``, ``cls_id``); previously
                these options were unreachable through this wrapper.

        Returns:
            list of (gene_names, counts) tuples, one per cell.
        """
        return tokenize_batch(data, gene_names, **kwargs)
9 changes: 8 additions & 1 deletion tdc/multi_pred/anndata_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,19 @@

class DataLoader(DL):

def __init__(self, name, path, print_stats, dataset_names):
def __init__(self,
name,
path,
print_stats=False,
dataset_names=None,
no_convert=False):
super(DataLoader, self).__init__(name, path, print_stats, dataset_names)
self.adata = self.df # this is in AnnData format
cmap = ConfigMap()
self.cmap = cmap
self.config = cmap.get(name)
if no_convert:
return
if self.config is None:
# default to converting adata to dataframe as is
self.df = AnnDataToDataFrame.anndata_to_df(self.adata)
Expand Down
18 changes: 18 additions & 0 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,24 @@ def setUp(self):
print(os.getcwd())
self.resource = cellxgene_census.CensusResource()

def testscGPT(self):
    """Smoke test for the scGPT model-server path: load a small AnnData
    sample, tokenize it, and run one forward pass through the HF model."""
    from tdc.multi_pred.anndata_dataset import DataLoader
    from tdc import tdc_hf_interface
    from tdc.model_server.tokenizers.scgpt import scGPTTokenizer
    # no_convert=True keeps the loaded object as raw AnnData instead of
    # converting it to a DataFrame (see anndata_dataset.DataLoader).
    adata = DataLoader("cellxgene_sample_small",
                       "./data",
                       dataset_names=["cellxgene_sample_small"],
                       no_convert=True).adata
    # Downloads/loads tdc/scGPT from the Hugging Face hub — requires network.
    scgpt = tdc_hf_interface("scGPT")
    model = scgpt.load()  # this line can cause segmentation fault
    tokenizer = scGPTTokenizer()
    gene_ids = adata.var["feature_name"].to_numpy()  # Convert to numpy array
    # tokenize_cell_vectors returns per-cell (genes, counts) tuples;
    # [0][1] below is the first cell's count tensor.
    tokenized_data = tokenizer.tokenize_cell_vectors(
        adata.X.toarray(), gene_ids)
    first_embed = model(tokenized_data[0][1]).last_hidden_state
    # NOTE(review): this assertion assumes the first cell has a token per
    # gene (tokenizer drops zero-count genes and prepends <cls> by
    # default, which would change the length) — confirm against the data.
    self.assertEqual(first_embed.shape[0], len(gene_ids))

def testGeneformerTokenizer(self):

adata = self.resource.get_anndata(
Expand Down

0 comments on commit f6108df

Please sign in to comment.