From bd1ae42b8856f4430252ff84b12586485e4449da Mon Sep 17 00:00:00 2001 From: Alejandro Velez-Arce Date: Wed, 23 Oct 2024 19:49:47 -0400 Subject: [PATCH] debug reshape --- tdc/test/test_model_server.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index 08f8adf5..9fedb183 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -103,15 +103,15 @@ def testGeneformerTokenizer(self): geneformer = tdc_hf_interface("Geneformer") model = geneformer.load() input_tensor = torch.tensor(cells) - input_tensor = torch.squeeze(input_tensor) - x = input_tensor.shape[0] - y = input_tensor.shape[1] - input_tensor = input_tensor.reshape(x, y) + input_tensor_squeezed = torch.squeeze(input_tensor) + x = input_tensor_squeezed.shape[0] + y = input_tensor_squeezed.shape[1] out = None # try-except block try: - out = model(input_tensor) + input_tensor_squeezed = input_tensor_squeezed.reshape(x, y) + out = model(input_tensor_squeezed) except Exception as e: - raise Exception("shape is", input_tensor.shape, "exception was: {}".format(e), "values are\n", input_tensor) + raise Exception("shape is", input_tensor.shape, "exception was: {}".format(e), "input_tensor_squeezed is\n", input_tensor, "\n\ninput_tensor normal is: {}".format(input_tensor)) assert out, "FAILURE: Geneformer output is false-like. Value = {}".format(out) assert out.shape[0] == input_tensor.shape[0], "FAILURE: Geneformer output and input tensor input don't have the same length. {} vs {}".format(out.shape[0], input_tensor.shape[0]) assert out.shape[0] == len(cells), "FAILURE: Geneformer output and tokenized cells don't have the same length. {} vs {}".format(out.shape[0], len(cells))