Skip to content

Commit

Permalink
Merge pull request #345 from mims-harvard/init_version
Browse files Browse the repository at this point in the history
new version and dependency removal also geneformer no scperturb fix
  • Loading branch information
amva13 authored Jan 21, 2025
2 parents eedaf0c + 6fa1e1d commit f7a8e9c
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 20 deletions.
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
accelerate==0.33.0
biopython>=1.78,<2.0
dataclasses>=0.6,<1.0
datasets<2.20.0
evaluate==0.4.2
Expand Down
8 changes: 3 additions & 5 deletions tdc/feature_generators/protein_feature_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
Goal is to make it easier to integrate custom datasets not yet in TDC format.
"""

from Bio import Entrez, SeqIO
import mygene
from pandas import DataFrame
import requests

Expand Down Expand Up @@ -37,6 +35,7 @@ def get_ncrna_sequence(cls, ncrna_id):
Returns:
str: The nucleotide sequence of the non-coding RNA, or a message if not found.
"""
from Bio import Entrez, SeqIO
# Provide your email to NCBI to let them know who you are
Entrez.email = "[email protected]"

Expand Down Expand Up @@ -88,6 +87,7 @@ def get_protein_sequence(cls, gene_name: str) -> str:
Returns:
str: Protein amino acid sequence.
"""
import mygene
assert isinstance(gene_name, str), (type(gene_name), gene_name)
mg = mygene.MyGeneInfo()
# Query MyGene.info for the given gene name
Expand All @@ -111,6 +111,7 @@ def get_protein_sequence(cls, gene_name: str) -> str:

@classmethod
def get_type_of_gene(cls, gene_name: str) -> str:
import mygene
assert isinstance(gene_name, str), (type(gene_name), gene_name)
mg = mygene.MyGeneInfo()
# Query MyGene.info for the given gene name
Expand Down Expand Up @@ -157,9 +158,6 @@ def helper_type(gene_name):
if gene_column not in dataset.columns:
raise ValueError(f"{gene_column} does not exist in the DataFrame.")

# Ensure the DataFrame index is aligned
# gene_df = gene_df.reset_index(drop=True)

# Retrieve protein sequences for each gene and store them in a new column
new_col = dataset[gene_column].apply(helper).tolist()
assert len(new_col) == len(dataset[gene_column]), (new_col,
Expand Down
6 changes: 3 additions & 3 deletions tdc/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,9 +1145,9 @@ def get_task2category():
"pinnacle_output8": 10431074,
"pinnacle_output9": 10431075,
"pinnacle_output10": 10431081,
"geneformer_gene_median_dictionary": 10836445,
"geneformer_gene_name_id_dict": 10836444,
"geneformer_token_dictionary": 10836446,
"geneformer_gene_median_dictionary": 10806947,
"geneformer_gene_name_id_dict": 10806948,
"geneformer_token_dictionary": 10806949,
"evebio_pharmone_v1_assay_doc": 10741530,
"evebio_pharmone_v1_assay_table": 10741541,
"evebio_pharmone_v1_bundle_doc": 10741540,
Expand Down
9 changes: 6 additions & 3 deletions tdc/model_server/tokenizers/geneformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
path=None,
custom_attr_name_dict=None,
nproc=1,
max_input_size=4096,
):
path = path or "./data"
download_wrapper("geneformer_gene_median_dictionary", path,
Expand All @@ -39,6 +40,8 @@ def __init__(
self.genelist_dict = dict(
zip(self.gene_keys, [True] * len(self.gene_keys)))

self.max_input_size = max_input_size

@classmethod
def rank_genes(cls, gene_vector, gene_tokens):
"""
Expand Down Expand Up @@ -73,7 +76,7 @@ def tokenize_cell_vectors(self,
])
coding_miRNA_ids = adata.var[ensembl_id][coding_miRNA_loc]
coding_miRNA_tokens = np.array(
[self.gene_token_dict[i] for i in coding_miRNA_ids])
[self.gene_token_dict.get(i, 0) for i in coding_miRNA_ids])

try:
_ = adata.obs["filter_pass"]
Expand Down Expand Up @@ -102,8 +105,8 @@ def tokenize_cell_vectors(self,
X_norm = sp.csr_matrix(X_norm)

tokenized_cells.append([
self.rank_genes(X_norm[i].data,
coding_miRNA_tokens[X_norm[i].indices])
self.rank_genes(X_norm[i].data, coding_miRNA_tokens[
X_norm[i].indices])[:self.max_input_size]
for i in range(X_norm.shape[0])
])

Expand Down
6 changes: 3 additions & 3 deletions tdc/multi_pred/anndata_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ def __init__(self,
path,
print_stats=False,
dataset_names=None,
no_convert=False):
no_convert=True):
super(DataLoader, self).__init__(name, path, print_stats, dataset_names)
self.adata = self.df # this is in AnnData format
if no_convert:
return
cmap = ConfigMap()
self.cmap = cmap
self.config = cmap.get(name)
if no_convert:
return
if self.config is None:
# default to converting adata to dataframe as is
self.df = AnnDataToDataFrame.anndata_to_df(self.adata)
Expand Down
2 changes: 2 additions & 0 deletions tdc/multi_pred/perturboutcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def __init__(self, name, path="./data", print_stats=False):
self.is_gene = True
else:
self.is_gene = False
self.is_combo = False
return

if name == 'scperturb_gene_NormanWeissman2019':
self.is_combo = True
Expand Down
46 changes: 46 additions & 0 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,52 @@ def testscGPT(self):
attention_mask=mask)
print(f"scgpt ran successfully. here is an output {first_embed}")

def testGeneformerPerturb(self):
from tdc.multi_pred.perturboutcome import PerturbOutcome
dataset = "scperturb_drug_AissaBenevolenskaya2021"
data = PerturbOutcome(dataset)
adata = data.adata
tokenizer = GeneformerTokenizer(max_input_size=3)
adata.var["feature_id"] = adata.var.index.map(
lambda x: tokenizer.gene_name_id_dict.get(x, 0))
x = tokenizer.tokenize_cell_vectors(adata,
ensembl_id="feature_id",
ncounts="ncounts")
cells, _ = x
assert cells, "FAILURE: cells false-like. Value is = {}".format(cells)
assert len(cells) > 0, "FAILURE: length of cells <= 0 {}".format(cells)
from tdc import tdc_hf_interface
import torch
geneformer = tdc_hf_interface("Geneformer")
model = geneformer.load()
mdim = max(len(cell) for b in cells for cell in b)
batch = cells[0]
for idx, cell in enumerate(batch):
if len(cell) < mdim:
for _ in range(mdim - len(cell)):
cell = np.append(cell, 0)
batch[idx] = cell
input_tensor = torch.tensor(batch)
assert input_tensor.shape[0] == 512, "unexpected batch size"
assert input_tensor.shape[1] == mdim, f"unexpected gene length {mdim}"
attention_mask = torch.tensor([[t != 0 for t in cell] for cell in batch
])
assert input_tensor.shape[0] == attention_mask.shape[0]
assert input_tensor.shape[1] == attention_mask.shape[1]
try:
outputs = model(input_tensor,
attention_mask=attention_mask,
output_hidden_states=True)
except Exception as e:
raise Exception(
f"sizes: {input_tensor.shape[0]}, {input_tensor.shape[1]}\n {e}"
)
num_out_in_batch = len(outputs.hidden_states[-1])
input_batch_size = input_tensor.shape[0]
num_gene_out_in_batch = len(outputs.hidden_states[-1][0])
assert num_out_in_batch == input_batch_size, f"FAILURE: length doesn't match batch size {num_out_in_batch} vs {input_batch_size}"
assert num_gene_out_in_batch == mdim, f"FAILURE: out length {num_gene_out_in_batch} doesn't match gene length {mdim}"

def testGeneformerTokenizer(self):

adata = self.resource.get_anndata(
Expand Down
6 changes: 2 additions & 4 deletions tdc/utils/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,10 @@ def pd_load(name, path):
df = pd.read_pickle(os.path.join(path, name + "/" + name + ".pkl"))
elif name2type[name] == "h5ad":
import anndata
print_sys("loading anndata object...")
adata = anndata.read_h5ad(
os.path.join(path, name + "." + name2type[name]))
# df = pd.DataFrame(adata.X.toarray(), columns=adata.var_names, index=adata.obs_names)
# TODO: multi-index would help include var information in columns
# multi_index = pd.MultiIndex.from_frame(adata.var.reset_index())
# df.columns = multi_index
print_sys("loader anndata object!")
return adata
elif name2type[name] == "json":
# df = pd.read_json(os.path.join(path, name + "." + name2type[name]))
Expand Down
2 changes: 1 addition & 1 deletion tdc/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
#
__version__ = "1.1.6" # pragma: no cover
__version__ = "1.1.12" # pragma: no cover

0 comments on commit f7a8e9c

Please sign in to comment.