diff --git a/requirements.txt b/requirements.txt index 8004b31a..73b1262a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ accelerate==0.33.0 -biopython>=1.78,<2.0 dataclasses>=0.6,<1.0 datasets<2.20.0 evaluate==0.4.2 diff --git a/tdc/feature_generators/protein_feature_generator.py b/tdc/feature_generators/protein_feature_generator.py index bdc4fada..214cd7bf 100644 --- a/tdc/feature_generators/protein_feature_generator.py +++ b/tdc/feature_generators/protein_feature_generator.py @@ -3,8 +3,6 @@ Goal is to make it easier to integrate custom datasets not yet in TDC format. """ -from Bio import Entrez, SeqIO -import mygene from pandas import DataFrame import requests @@ -37,6 +35,7 @@ def get_ncrna_sequence(cls, ncrna_id): Returns: str: The nucleotide sequence of the non-coding RNA, or a message if not found. """ + from Bio import Entrez, SeqIO # Provide your email to NCBI to let them know who you are Entrez.email = "alejandro_velez-arce@hms.harvard.edu" @@ -88,6 +87,7 @@ def get_protein_sequence(cls, gene_name: str) -> str: Returns: str: Protein amino acid sequence. """ + import mygene assert isinstance(gene_name, str), (type(gene_name), gene_name) mg = mygene.MyGeneInfo() # Query MyGene.info for the given gene name @@ -111,6 +111,7 @@ def get_protein_sequence(cls, gene_name: str) -> str: @classmethod def get_type_of_gene(cls, gene_name: str) -> str: + import mygene assert isinstance(gene_name, str), (type(gene_name), gene_name) mg = mygene.MyGeneInfo() # Query MyGene.info for the given gene name @@ -157,9 +158,6 @@ def helper_type(gene_name): if gene_column not in dataset.columns: raise ValueError(f"{gene_column} does not exist in the DataFrame.") - # Ensure the DataFrame index is aligned - # gene_df = gene_df.reset_index(drop=True) - # Retrieve protein sequences for each gene and store them in a new column new_col = dataset[gene_column].apply(helper).tolist() assert len(new_col) == len(dataset[gene_column]), (new_col, diff --git a/tdc/metadata.py b/tdc/metadata.py index 07fac5ad..f4779c40 100644 --- a/tdc/metadata.py +++ b/tdc/metadata.py @@ -1145,9 +1145,9 @@ def get_task2category(): "pinnacle_output8": 10431074, "pinnacle_output9": 10431075, "pinnacle_output10": 10431081, - "geneformer_gene_median_dictionary": 10836445, - "geneformer_gene_name_id_dict": 10836444, - "geneformer_token_dictionary": 10836446, + "geneformer_gene_median_dictionary": 10806947, + "geneformer_gene_name_id_dict": 10806948, + "geneformer_token_dictionary": 10806949, "evebio_pharmone_v1_assay_doc": 10741530, "evebio_pharmone_v1_assay_table": 10741541, "evebio_pharmone_v1_bundle_doc": 10741540, diff --git a/tdc/model_server/tokenizers/geneformer.py b/tdc/model_server/tokenizers/geneformer.py index 48225682..97cd18ac 100644 --- a/tdc/model_server/tokenizers/geneformer.py +++ b/tdc/model_server/tokenizers/geneformer.py @@ -16,6 +16,7 @@ def __init__( path=None, custom_attr_name_dict=None, nproc=1, + max_input_size=4096, ): path = path or "./data" download_wrapper("geneformer_gene_median_dictionary", path, @@ -39,6 +40,8 @@ def __init__( self.genelist_dict = dict( zip(self.gene_keys, [True] * len(self.gene_keys))) + self.max_input_size = max_input_size + @classmethod def rank_genes(cls, gene_vector, gene_tokens): """ @@ -73,7 +76,7 @@ def tokenize_cell_vectors(self, ]) coding_miRNA_ids = adata.var[ensembl_id][coding_miRNA_loc] coding_miRNA_tokens = np.array( - [self.gene_token_dict[i] for i in coding_miRNA_ids]) + [self.gene_token_dict.get(i, 0) for i in coding_miRNA_ids]) try: _ = adata.obs["filter_pass"] @@ -102,8 +105,8 @@ def tokenize_cell_vectors(self, X_norm = sp.csr_matrix(X_norm) tokenized_cells.append([ - self.rank_genes(X_norm[i].data, - coding_miRNA_tokens[X_norm[i].indices]) + self.rank_genes(X_norm[i].data, coding_miRNA_tokens[ + X_norm[i].indices])[:self.max_input_size] for i in range(X_norm.shape[0]) ]) diff --git a/tdc/multi_pred/anndata_dataset.py b/tdc/multi_pred/anndata_dataset.py index 65788ad2..6c574b50 100644 --- a/tdc/multi_pred/anndata_dataset.py +++ b/tdc/multi_pred/anndata_dataset.py @@ -10,14 +10,14 @@ def __init__(self, path, print_stats=False, dataset_names=None, - no_convert=False): + no_convert=True): super(DataLoader, self).__init__(name, path, print_stats, dataset_names) self.adata = self.df # this is in AnnData format + if no_convert: + return cmap = ConfigMap() self.cmap = cmap self.config = cmap.get(name) - if no_convert: - return if self.config is None: # default to converting adata to dataframe as is self.df = AnnDataToDataFrame.anndata_to_df(self.adata) diff --git a/tdc/multi_pred/perturboutcome.py b/tdc/multi_pred/perturboutcome.py index 6c272471..66630602 100644 --- a/tdc/multi_pred/perturboutcome.py +++ b/tdc/multi_pred/perturboutcome.py @@ -104,6 +104,8 @@ def __init__(self, name, path="./data", print_stats=False): self.is_gene = True else: self.is_gene = False + self.is_combo = False + return if name == 'scperturb_gene_NormanWeissman2019': self.is_combo = True diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index 6f5a5084..77dc46cd 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -55,6 +55,52 @@ def testscGPT(self): attention_mask=mask) print(f"scgpt ran successfully. here is an output {first_embed}") + def testGeneformerPerturb(self): + from tdc.multi_pred.perturboutcome import PerturbOutcome + dataset = "scperturb_drug_AissaBenevolenskaya2021" + data = PerturbOutcome(dataset) + adata = data.adata + tokenizer = GeneformerTokenizer(max_input_size=3) + adata.var["feature_id"] = adata.var.index.map( + lambda x: tokenizer.gene_name_id_dict.get(x, 0)) + x = tokenizer.tokenize_cell_vectors(adata, + ensembl_id="feature_id", + ncounts="ncounts") + cells, _ = x + assert cells, "FAILURE: cells false-like. Value is = {}".format(cells) + assert len(cells) > 0, "FAILURE: length of cells <= 0 {}".format(cells) + from tdc import tdc_hf_interface + import torch + geneformer = tdc_hf_interface("Geneformer") + model = geneformer.load() + mdim = max(len(cell) for b in cells for cell in b) + batch = cells[0] + for idx, cell in enumerate(batch): + if len(cell) < mdim: + for _ in range(mdim - len(cell)): + cell = np.append(cell, 0) + batch[idx] = cell + input_tensor = torch.tensor(batch) + assert input_tensor.shape[0] == 512, "unexpected batch size" + assert input_tensor.shape[1] == mdim, f"unexpected gene length {mdim}" + attention_mask = torch.tensor([[t != 0 for t in cell] for cell in batch + ]) + assert input_tensor.shape[0] == attention_mask.shape[0] + assert input_tensor.shape[1] == attention_mask.shape[1] + try: + outputs = model(input_tensor, + attention_mask=attention_mask, + output_hidden_states=True) + except Exception as e: + raise Exception( + f"sizes: {input_tensor.shape[0]}, {input_tensor.shape[1]}\n {e}" + ) + num_out_in_batch = len(outputs.hidden_states[-1]) + input_batch_size = input_tensor.shape[0] + num_gene_out_in_batch = len(outputs.hidden_states[-1][0]) + assert num_out_in_batch == input_batch_size, f"FAILURE: length doesn't match batch size {num_out_in_batch} vs {input_batch_size}" + assert num_gene_out_in_batch == mdim, f"FAILURE: out length {num_gene_out_in_batch} doesn't match gene length {mdim}" + def testGeneformerTokenizer(self): adata = self.resource.get_anndata( diff --git a/tdc/utils/load.py b/tdc/utils/load.py index 7acd58f4..b12811e9 100644 --- a/tdc/utils/load.py +++ b/tdc/utils/load.py @@ -304,12 +304,10 @@ def pd_load(name, path): df = pd.read_pickle(os.path.join(path, name + "/" + name + ".pkl")) elif name2type[name] == "h5ad": import anndata + print_sys("loading anndata object...") adata = anndata.read_h5ad( os.path.join(path, name + "." + name2type[name])) - # df = pd.DataFrame(adata.X.toarray(), columns=adata.var_names, index=adata.obs_names) - # TODO: multi-index would help include var information in columns - # multi_index = pd.MultiIndex.from_frame(adata.var.reset_index()) - # df.columns = multi_index + print_sys("loader anndata object!") return adata elif name2type[name] == "json": # df = pd.read_json(os.path.join(path, name + "." + name2type[name])) diff --git a/tdc/version.py b/tdc/version.py index d20a266d..b11adef6 100644 --- a/tdc/version.py +++ b/tdc/version.py @@ -19,4 +19,4 @@ # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' # -__version__ = "1.1.6" # pragma: no cover +__version__ = "1.1.12" # pragma: no cover