Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yinggeh committed Aug 6, 2024
1 parent 6dc2a0b commit 48c9b25
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions qa/L0_input_validation/input_validation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,12 @@ def test_client_input_shape_validation(self):
inputs[0].set_data_from_numpy(input0_data)
inputs[1].set_data_from_numpy(input1_data)

# 1. Test wrong shapes with correct element counts
# Compromised input shapes
inputs[0].set_shape([2, 8])
inputs[1].set_shape([2, 8])

# If element count is correct but shape is wrong, core will return an error.
with self.assertRaises(InferenceServerException) as e:
triton_client.infer(model_name=model_name, inputs=inputs)
err_str = str(e.exception)
Expand All @@ -158,10 +160,12 @@ def test_client_input_shape_validation(self):
err_str,
)

# 2. Test wrong shapes with wrong element counts
# Compromised input shapes
inputs[0].set_shape([1, 8])
inputs[1].set_shape([1, 8])

# If element count is wrong, client returns an error.
with self.assertRaises(InferenceServerException) as e:
triton_client.infer(model_name=model_name, inputs=inputs)
err_str = str(e.exception)
Expand Down Expand Up @@ -208,15 +212,6 @@ def identity_inference(triton_client, np_array, binary_data):
else:
triton_client = tritongrpcclient.InferenceServerClient("localhost:8001")

# Example using BYTES input tensor with utf-8 encoded string that
# has an embedded null character.
null_chars_array = np.array(
["he\x00llo".encode("utf-8") for i in range(16)], dtype=np.object_
)
null_char_data = null_chars_array.reshape([1, 16])
identity_inference(triton_client, null_char_data, True) # Using binary data
identity_inference(triton_client, null_char_data, False) # Using JSON data

# Example using BYTES input tensor with 16 elements, where each
# element is a 4-byte binary blob with value 0x00010203. Can use
# dtype=np.bytes_ in this case.
Expand Down

0 comments on commit 48c9b25

Please sign in to comment.