Skip to content

Commit

Permalink
use tokenizer dataset function
Browse files Browse the repository at this point in the history
  • Loading branch information
amva13 committed Oct 24, 2024
1 parent bd1ae42 commit 7f21f3b
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,25 @@ def testGeneformerTokenizer(self):
assert x[0]

# test Geneformer can serve the request
cells = x[0],
cells, metadata = x
assert cells, "FAILURE: cells false-like. Value is = {}".format(cells)
assert len(cells) > 0, "FAILURE: length of cells <= 0 {}".format(cells)
from tdc import tdc_hf_interface
import torch
# import torch
geneformer = tdc_hf_interface("Geneformer")
model = geneformer.load()
input_tensor = torch.tensor(cells)
input_tensor_squeezed = torch.squeeze(input_tensor)
x = input_tensor_squeezed.shape[0]
y = input_tensor_squeezed.shape[1]
out = None # try-except block
try:
input_tensor_squeezed = input_tensor_squeezed.reshape(x, y)
out = model(input_tensor_squeezed)
except Exception as e:
raise Exception("shape is", input_tensor.shape, "exception was: {}".format(e), "input_tensor_squeezed is\n", input_tensor, "\n\ninput_tensor normal is: {}".format(input_tensor))
tokenized_data = tokenizer.create_dataset(cells, metadata)
out = model(tokenized_data)
# input_tensor = torch.tensor(cells)
# input_tensor_squeezed = torch.squeeze(input_tensor)
# x = input_tensor_squeezed.shape[0]
# y = input_tensor_squeezed.shape[1]
# out = None # try-except block
# try:
# input_tensor_squeezed = input_tensor_squeezed.reshape(x, y)
# out = model(input_tensor_squeezed)
# except Exception as e:
# raise Exception("tensor shape is", input_tensor.shape, "exception was: {}".format(e), "input_tensor_squeezed is\n", input_tensor, "\n\ninput_tensor normal is: {}".format(input_tensor))
assert out, "FAILURE: Geneformer output is false-like. Value = {}".format(out)
assert out.shape[0] == input_tensor.shape[0], "FAILURE: Geneformer output and input tensor input don't have the same length. {} vs {}".format(out.shape[0], input_tensor.shape[0])
assert out.shape[0] == len(cells), "FAILURE: Geneformer output and tokenized cells don't have the same length. {} vs {}".format(out.shape[0], len(cells))
Expand Down

0 comments on commit 7f21f3b

Please sign in to comment.