Skip to content

Commit

Permalink
mend
Browse files Browse the repository at this point in the history
  • Loading branch information
amva13 committed Jan 21, 2025
1 parent b237dd6 commit e67fcd0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 34 deletions.
1 change: 1 addition & 0 deletions .github/workflows/conda-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
echo "Creating Conda Environment from environment.yml"
conda env create -f environment.yml
conda activate tdc-conda-env
python run_tests.py tdc.test.test_model_server.TestModelServer.testGeneformerPerturb
python run_tests.py
yapf --style=google -r -d tdc
conda deactivate
52 changes: 18 additions & 34 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,40 +73,24 @@ def testGeneformerPerturb(self):
import torch
geneformer = tdc_hf_interface("Geneformer")
model = geneformer.load()
input_tensor = torch.tensor(cells)
out = []
try:
ctr = 0 # stop after some passes to avoid failure
for batch in input_tensor:
# build an attention mask
attention_mask = torch.tensor(
[[x[0] != 0, x[1] != 0] for x in batch])
outputs = model(batch,
attention_mask=attention_mask,
output_hidden_states=True)
layer_to_quant = quant_layers(model) + (
-1
) # TODO note this can be parametrized to either 0 (extract last embedding layer) or -1 (second-to-last which is more generalized)
embs_i = outputs.hidden_states[layer_to_quant]
# There are "cls", "cell", and "gene" embeddings. We capture only "gene", which is cell-type specific. For "cell", you would average across the unmasked gene embeddings per cell.
embs = embs_i
out.append(embs)
if ctr == 2:
break
ctr += 1
except Exception as e:
raise Exception(e)

assert out, "FAILURE: Geneformer output is false-like. Value = {}".format(
out)
assert len(
out
) == 3, "length not matching ctr+1: {} vs {}. output was \n {}".format(
len(out), ctr + 1, out)
print(
"Geneformer ran sucessfully. Find batch embedding example here:\n {}"
.format(
out[0])) # TODO: test for cell-type-specific embedding count
mdim = max(len(cell) for b in cells for cell in b)
batch = cells[0]
for idx, cell in enumerate(batch):
if len(cell) < mdim:
for _ in range(mdim - len(cell)):
cell = np.append(cell, 0)
batch[idx] = cell
input_tensor = torch.tensor([batch])
attention_mask = torch.tensor([[t != 0 for t in cell] for cell in batch
])
outputs = model(input_tensor,
attention_mask=attention_mask,
output_hidden_states=True)
num_out_in_batch = len(outputs.hidden_states[-1])
input_batch_size = input_tensor.shape[1]
num_gene_out_in_batch = len(outputs.hidden_states[-1][0])
assert num_out_in_batch == input_batch_size, f"FAILURE: length doesn't match batch size {num_out_in_batch} vs {input_batch_size}"
assert num_gene_out_in_batch == mdim, f"FAILURE: out length {num_gene_out_in_batch} doesn't match gene length {mdim}"

def testGeneformerTokenizer(self):

Expand Down

0 comments on commit e67fcd0

Please sign in to comment.