diff --git a/tests/test_model_call.py b/tests/test_model_call.py index 7f6f450..314b9d1 100644 --- a/tests/test_model_call.py +++ b/tests/test_model_call.py @@ -37,11 +37,7 @@ def test_with_input_name(protocol_and_port): client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol) - sample = np.random.rand(1, 100).astype(np.float32) - result = client({client.input_name_list[0]: sample}) - print(f"Result: {np.isclose(result, sample).all()}") - sample = np.random.rand(100, 100).astype(np.float32) - result = client({client.default_model_spec.input_name[0]: sample}) + result = client({client.default_model_spec.model_input[0].name: sample}) print(f"Result: {np.isclose(result, sample).all()}")