diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index c2240ee9..146865c7 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -103,13 +103,20 @@ def testGeneformerTokenizer(self): geneformer = tdc_hf_interface("Geneformer") model = geneformer.load() # tokenized_data = tokenizer.create_dataset(cells, metadata) + print("using very few genes for these test cases so expecting empties... let's pad/remove just for the test case...") + for idx in range(len(cells)): + x = cells[idx] + if len(x)<2: + for _ in range(2-len(x)): + x.append(1) + cells[idx] = x input_tensor = torch.tensor(cells) - input_tensor = torch.squeeze(input_tensor) + # input_tensor = torch.squeeze(input_tensor) try: - input_tensor.squeeze(2) # last dim is zero + # input_tensor.squeeze(2) # last dim is zero out = model(input_tensor) except Exception as e: - raise Exception("tensor shape is", input_tensor.shape, "exception was:", e) + raise Exception("tensor shape is", input_tensor.shape, "exception was:", e, "\n cells was\n", cells) # raise Exception(e) # input_tensor = torch.tensor(cells) # input_tensor_squeezed = torch.squeeze(input_tensor)