Skip to content

Commit

Permalink
Fix for limiting bytes tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
pziecina-nv committed Nov 10, 2023
1 parent 36c4bec commit 25a8845
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ limitations under the License.
## 0.4.1 (2023-11-09)

- New: Place where workspaces with temporary Triton model repositories and communication file sockets can be configured by `$PYTRITON_HOME` environment variable
- Fix: recover handling `KeyboardInterrupt` in `triton.serve()`
- Fix: Recover handling `KeyboardInterrupt` in `triton.serve()`
- Fix: Remove limit for handling bytes dtype tensors
- Build scripts update
- Added support for arm64 platform builds

Expand Down
2 changes: 1 addition & 1 deletion pytriton/proxy/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _deserialize_bytes_tensor(encoded_tensor, dtype, order: Literal["C", "F"] =
return np.array(strs, dtype=dtype, order=order)


_MAX_DTYPE_DESCR = 8
_MAX_DTYPE_DESCR = 16 # up to 16 chars in dtype descr; |S2147483647 (2^31-1) with margin
_PARTIAL_HEADER_FORMAT = f"<{_MAX_DTYPE_DESCR}scH"


Expand Down
21 changes: 16 additions & 5 deletions tests/unit/test_communication_tensor_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,23 @@ def test_tensor_store_connection_timeout(tmp_path):
],
2,
),
# 2GB bytes array
(
[
np.array(b"a" * (2**31 - 1), dtype=bytes),
],
1,
),
),
)
def test_tensor_store_get_put_equal(tensor_store, tensors, n_times):
for _ in range(n_times):
tensors_ids = tensor_store.put(tensors)
assert len(tensors) == len(tensors_ids)
for tensor, tensor_id in zip(tensors, tensors_ids):
tensor_retrieved = tensor_store.get(tensor_id)
np.testing.assert_equal(tensor, tensor_retrieved)
try:
tensors_ids = tensor_store.put(tensors)
assert len(tensors) == len(tensors_ids)
for tensor, tensor_id in zip(tensors, tensors_ids):
tensor_retrieved = tensor_store.get(tensor_id)
np.testing.assert_equal(tensor, tensor_retrieved)
finally:
for tensor_id in tensors_ids:
tensor_store.release_block(tensor_id)

0 comments on commit 25a8845

Please sign in to comment.