diff --git a/src/python/library/CMakeLists.txt b/src/python/library/CMakeLists.txt index e552ec2f7..f9a1c5748 100644 --- a/src/python/library/CMakeLists.txt +++ b/src/python/library/CMakeLists.txt @@ -96,7 +96,6 @@ add_custom_target( if (NOT WIN32) # Can generate linux specific wheel file on linux systems only. set(LINUX_WHEEL_DEPENDS - cshm ${WHEEL_DEPENDS} ) diff --git a/src/python/library/build_wheel.py b/src/python/library/build_wheel.py index d32e7732a..73f727d0d 100755 --- a/src/python/library/build_wheel.py +++ b/src/python/library/build_wheel.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -174,10 +174,6 @@ def sed(pattern, replace, source, dest=None): "tritonclient/utils/shared_memory", os.path.join(FLAGS.whl_dir, "tritonclient/utils/shared_memory"), ) - shutil.copyfile( - "tritonclient/utils/libcshm.so", - os.path.join(FLAGS.whl_dir, "tritonclient/utils/shared_memory/libcshm.so"), - ) cpdir( "tritonclient/utils/cuda_shared_memory", os.path.join(FLAGS.whl_dir, "tritonclient/utils/cuda_shared_memory"), diff --git a/src/python/library/setup.py b/src/python/library/setup.py index e31f5ddc0..634b8b57e 100755 --- a/src/python/library/setup.py +++ b/src/python/library/setup.py @@ -76,8 +76,6 @@ def req_file(filename, folder="requirements"): extras_require["all"] = list(chain(extras_require.values())) platform_package_data = [] -if PLATFORM_FLAG != "any": - platform_package_data += ["libcshm.so"] data_files = [ ("", ["LICENSE.txt"]), diff --git a/src/python/library/tests/test_shared_memory.py b/src/python/library/tests/test_shared_memory.py new file mode 100644 index 000000000..36c64f090 --- /dev/null +++ b/src/python/library/tests/test_shared_memory.py @@ -0,0 +1,183 @@ +# 
Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import unittest + +import numpy +import tritonclient.utils as utils +import tritonclient.utils.shared_memory as shm + + +class SharedMemoryTest(unittest.TestCase): + """ + Testing shared memory utilities + """ + + def setUp(self): + self.shm_handles = [] + + def tearDown(self): + for shm_handle in self.shm_handles: + shm.destroy_shared_memory_region(shm_handle) + + def test_lifecycle(self): + cpu_tensor = numpy.ones([4, 4], dtype=numpy.float32) + byte_size = 64 + self.shm_handles.append( + shm.create_shared_memory_region("shm_name", "shm_key", byte_size) + ) + + self.assertEqual(len(shm.mapped_shared_memory_regions()), 1) + + # Set data from Numpy array + shm.set_shared_memory_region(self.shm_handles[0], [cpu_tensor]) + shm_tensor = shm.get_contents_as_numpy( + self.shm_handles[0], numpy.float32, [4, 4] + ) + + self.assertTrue(numpy.allclose(cpu_tensor, shm_tensor)) + + shm.destroy_shared_memory_region(self.shm_handles.pop(0)) + + def test_invalid_create_shm(self): + # Raises error since tried to create invalid system shared memory region + with self.assertRaisesRegex( + shm.SharedMemoryException, "unable to create the shared memory region" + ): + self.shm_handles.append( + shm.create_shared_memory_region("dummy_data", "/dummy_data", -1) + ) + + def test_set_region_offset(self): + large_tensor = numpy.ones([4, 4], dtype=numpy.float32) + large_size = 64 + self.shm_handles.append( + shm.create_shared_memory_region("shm_name", "shm_key", large_size) + ) + shm.set_shared_memory_region(self.shm_handles[0], [large_tensor]) + small_tensor = numpy.zeros([2, 4], dtype=numpy.float32) + small_size = 32 + shm.set_shared_memory_region( + self.shm_handles[0], [small_tensor], offset=large_size - small_size + ) + shm_tensor = shm.get_contents_as_numpy( + self.shm_handles[0], numpy.float32, [2, 4], offset=large_size - small_size + ) + + self.assertTrue(numpy.allclose(small_tensor, shm_tensor)) + + def test_set_region_oversize(self): + large_tensor = numpy.ones([4, 4], 
dtype=numpy.float32) + small_size = 32 + self.shm_handles.append( + shm.create_shared_memory_region("shm_name", "shm_key", small_size) + ) + with self.assertRaisesRegex( + shm.SharedMemoryException, "unable to set the shared memory region" + ): + shm.set_shared_memory_region(self.shm_handles[0], [large_tensor]) + + def test_duplicate_key(self): + # By default, a handle to the existing region is returned; a warning is + # emitted if the requested size exceeds the region's actual size + self.shm_handles.append( + shm.create_shared_memory_region("shm_name", "shm_key", 32) + ) + with self.assertRaisesRegex( + shm.SharedMemoryException, + "unable to create the shared memory region", + ): + self.shm_handles.append( + shm.create_shared_memory_region( + "shm_name", "shm_key", 32, create_only=True + ) + ) + + # Get a handle to the same shared memory region, requesting a larger size, + # and check that the actual region size is still enforced on write + self.shm_handles.append( + shm.create_shared_memory_region("shm_name", "shm_key", 64) + ) + + self.assertEqual(len(shm.mapped_shared_memory_regions()), 1) + + large_tensor = numpy.ones([4, 4], dtype=numpy.float32) + with self.assertRaisesRegex( + shm.SharedMemoryException, "unable to set the shared memory region" + ): + shm.set_shared_memory_region(self.shm_handles[-1], [large_tensor]) + + def test_destroy_duplicate(self): + # destruction of duplicate shared memory region will occur when the last + # managed handle is destroyed + self.assertEqual(len(shm.mapped_shared_memory_regions()), 0) + self.shm_handles.append( + shm.create_shared_memory_region("shm_name", "shm_key", 64) + ) + self.shm_handles.append( + shm.create_shared_memory_region("shm_name", "shm_key", 32) + ) + self.shm_handles.append( + shm.create_shared_memory_region("shm_name", "shm_key", 32) + ) + self.assertEqual(len(shm.mapped_shared_memory_regions()), 1) + + shm.destroy_shared_memory_region(self.shm_handles.pop(0)) + shm.destroy_shared_memory_region(self.shm_handles.pop(0)) + 
self.assertEqual(len(shm.mapped_shared_memory_regions()), 1) + + shm.destroy_shared_memory_region(self.shm_handles.pop(0)) + self.assertEqual(len(shm.mapped_shared_memory_regions()), 0) + + def test_numpy_bytes(self): + int_tensor = numpy.arange(start=0, stop=16, dtype=numpy.int32) + bytes_tensor = numpy.array( + [str(x).encode("utf-8") for x in int_tensor.flatten()], dtype=object + ) + bytes_tensor = bytes_tensor.reshape(int_tensor.shape) + bytes_tensor_serialized = utils.serialize_byte_tensor(bytes_tensor) + byte_size = utils.serialized_byte_size(bytes_tensor_serialized) + + self.shm_handles.append( + shm.create_shared_memory_region("shm_name", "shm_key", byte_size) + ) + + # Set data from Numpy array + shm.set_shared_memory_region(self.shm_handles[0], [bytes_tensor_serialized]) + + shm_tensor = shm.get_contents_as_numpy( + self.shm_handles[0], + numpy.object_, + [ + 16, + ], + ) + + self.assertTrue(numpy.array_equal(bytes_tensor, shm_tensor)) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/python/library/tritonclient/utils/CMakeLists.txt b/src/python/library/tritonclient/utils/CMakeLists.txt index 7de1acf96..94952efc7 100644 --- a/src/python/library/tritonclient/utils/CMakeLists.txt +++ b/src/python/library/tritonclient/utils/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -28,20 +28,6 @@ configure_file(__init__.py __init__.py COPYONLY) configure_file(_dlpack.py _dlpack.py COPYONLY) configure_file(_shared_memory_tensor.py _shared_memory_tensor.py COPYONLY) -if(NOT WIN32) - file(COPY shared_memory DESTINATION .) 
- - # - # libcshm.so - # - add_library(cshm SHARED shared_memory/shared_memory.cc) - if(${TRITON_ENABLE_GPU}) - target_compile_definitions(cshm PUBLIC TRITON_ENABLE_GPU=1) - target_link_libraries(cshm PUBLIC CUDA::cudart) - endif() # TRITON_ENABLE_GPU - target_link_libraries(cshm PRIVATE rt) -endif() # WIN32 - if(NOT WIN32) configure_file(shared_memory/__init__.py shared_memory/__init__.py COPYONLY) configure_file(cuda_shared_memory/__init__.py cuda_shared_memory/__init__.py COPYONLY) diff --git a/src/python/library/tritonclient/utils/shared_memory/__init__.py b/src/python/library/tritonclient/utils/shared_memory/__init__.py index fd65191b9..12904445e 100755 --- a/src/python/library/tritonclient/utils/shared_memory/__init__.py +++ b/src/python/library/tritonclient/utils/shared_memory/__init__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -28,70 +28,27 @@ import os import struct -from ctypes import * +import warnings +from multiprocessing import shared_memory as mpshm import numpy as np -import pkg_resources - - -class _utf8(object): - @classmethod - def from_param(cls, value): - if value is None: - return None - elif isinstance(value, bytes): - return value - else: - return value.encode("utf8") - - -_cshm_lib = "cshm" if os.name == "nt" else "libcshm.so" -_cshm_path = pkg_resources.resource_filename( - "tritonclient.utils.shared_memory", _cshm_lib -) -_cshm = cdll.LoadLibrary(_cshm_path) - -_cshm_shared_memory_region_create = _cshm.SharedMemoryRegionCreate -_cshm_shared_memory_region_create.restype = c_int -_cshm_shared_memory_region_create.argtypes = [_utf8, _utf8, c_uint64, POINTER(c_void_p)] -_cshm_shared_memory_region_set = _cshm.SharedMemoryRegionSet -_cshm_shared_memory_region_set.restype = c_int -_cshm_shared_memory_region_set.argtypes = [c_void_p, c_uint64, c_uint64, c_void_p] -_cshm_get_shared_memory_handle_info = _cshm.GetSharedMemoryHandleInfo -_cshm_get_shared_memory_handle_info.restype = c_int -_cshm_get_shared_memory_handle_info.argtypes = [ - c_void_p, - POINTER(c_char_p), - POINTER(c_char_p), - POINTER(c_int), - POINTER(c_uint64), - POINTER(c_uint64), -] -_cshm_shared_memory_region_destroy = _cshm.SharedMemoryRegionDestroy -_cshm_shared_memory_region_destroy.restype = c_int -_cshm_shared_memory_region_destroy.argtypes = [c_void_p] - -mapped_shm_regions = [] - - -def _raise_if_error(errno): - """ - Raise SharedMemoryException if 'err' is non-success. - Otherwise return nothing. 
- """ - if errno.value != 0: - ex = SharedMemoryException(errno) - raise ex - return + +_key_mapping = {} -def _raise_error(msg): - ex = SharedMemoryException(msg) - raise ex +class SharedMemoryRegion: + def __init__( + self, + triton_shm_name: str, + shm_key: str, + ) -> None: + self._triton_shm_name = triton_shm_name + self._shm_key = shm_key + self._mpsm_handle = None -def create_shared_memory_region(triton_shm_name, shm_key, byte_size): - """Creates a system shared memory region with the specified name and size. +def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only=False): + """Return a handle of the system shared memory region with the specified name and size. Parameters ---------- @@ -101,10 +58,16 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size): The unique key of the shared memory object. byte_size : int The size in bytes of the shared memory region to be created. + create_only : bool + Whether a shared memory region must be created. If False and + a shared memory region of the same name exists, a handle to that + shared memory region will be returned and user must be aware that + the previously allocated shared memory size can be different from + the size requested. Returns ------- - shm_handle : c_void_p + shm_handle : SharedMemoryRegion The handle for the system shared memory region. Raises @@ -112,16 +75,39 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size): SharedMemoryException If unable to create the shared memory region. 
""" - - shm_handle = c_void_p() - _raise_if_error( - c_int( - _cshm_shared_memory_region_create( - triton_shm_name, shm_key, byte_size, byref(shm_handle) + shm_handle = SharedMemoryRegion(triton_shm_name, shm_key) + # Check whether the region exists before creating it + if not create_only: + try: + shm_handle._mpsm_handle = mpshm.SharedMemory(shm_key) + if shm_key not in _key_mapping: + _key_mapping[shm_key] = { + "needs_unlink": False, + "active_handle_count": 0, + } + _key_mapping[shm_key]["active_handle_count"] += 1 + except FileNotFoundError: + # File not found means the shared memory region has not been created, + # suppress the exception and attempt to create the region. + pass + if shm_handle._mpsm_handle is None: + try: + shm_handle._mpsm_handle = mpshm.SharedMemory( + shm_key, create=True, size=byte_size ) + except Exception as ex: + raise SharedMemoryException( + "unable to create the shared memory region" + ) from ex + if shm_key not in _key_mapping: + _key_mapping[shm_key] = {"needs_unlink": False, "active_handle_count": 0} + _key_mapping[shm_key]["needs_unlink"] = True + _key_mapping[shm_key]["active_handle_count"] += 1 + + if byte_size > shm_handle._mpsm_handle.size: + warnings.warn( + f"reusing shared memory region with key '{shm_key}', region size is {shm_handle._mpsm_handle.size} instead of requested {byte_size}" ) - ) - mapped_shm_regions.append(shm_key) return shm_handle @@ -131,7 +117,7 @@ def set_shared_memory_region(shm_handle, input_values, offset=0): Parameters ---------- - shm_handle : c_void_p + shm_handle : SharedMemoryRegion The handle for the system shared memory region. input_values : list The list of numpy arrays to be copied into the shared memory region. 
@@ -146,41 +132,35 @@ def set_shared_memory_region(shm_handle, input_values, offset=0): """ if not isinstance(input_values, (list, tuple)): - _raise_error("input_values must be specified as a list/tuple of numpy arrays") + raise SharedMemoryException( + "input_values must be specified as a list/tuple of numpy arrays" + ) for input_value in input_values: if not isinstance(input_value, np.ndarray): - _raise_error("each element of input_values must be a numpy array") + raise SharedMemoryException( + "each element of input_values must be a numpy array" + ) - offset_current = offset - for input_value in input_values: - input_value = np.ascontiguousarray(input_value).flatten() - if input_value.dtype == np.object_: - input_value = input_value.item() - byte_size = np.dtype(np.byte).itemsize * len(input_value) - _raise_if_error( - c_int( - _cshm_shared_memory_region_set( - shm_handle, - c_uint64(offset_current), - c_uint64(byte_size), - cast(input_value, c_void_p), - ) + try: + for input_value in input_values: + # A numpy array of object type is "syntactic sugar" for the API; it should + # be handled by accessing its item and treating it as a Python object + if input_value.dtype == np.object_: + byte_size = len(input_value.item()) + shm_handle._mpsm_handle.buf[offset : offset + byte_size] = ( + input_value.item() ) - ) - else: - byte_size = input_value.size * input_value.itemsize - _raise_if_error( - c_int( - _cshm_shared_memory_region_set( - shm_handle, - c_uint64(offset_current), - c_uint64(byte_size), - input_value.ctypes.data_as(c_void_p), - ) + offset += byte_size + else: + shm_tensor_view = np.ndarray( + input_value.shape, + input_value.dtype, + buffer=shm_handle._mpsm_handle.buf[offset:], ) - ) - offset_current += byte_size - return + shm_tensor_view[:] = input_value[:] + offset += input_value.nbytes + except Exception as ex: + raise SharedMemoryException("unable to set the shared memory region") from ex def get_contents_as_numpy(shm_handle, datatype, shape, offset=0): @@ -189,7 
+169,7 @@ def get_contents_as_numpy(shm_handle, datatype, shape, offset=0): Parameters ---------- - shm_handle : c_void_p + shm_handle : SharedMemoryRegion The handle for the system shared memory region. datatype : np.dtype The datatype of the array to be returned. @@ -205,42 +185,13 @@ def get_contents_as_numpy(shm_handle, datatype, shape, offset=0): The numpy array generated using the contents of the specified shared memory region. """ - shm_fd = c_int() - region_offset = c_uint64() - byte_size = c_uint64() - shm_addr = c_char_p() - shm_key = c_char_p() - _raise_if_error( - c_int( - _cshm_get_shared_memory_handle_info( - shm_handle, - byref(shm_addr), - byref(shm_key), - byref(shm_fd), - byref(region_offset), - byref(byte_size), - ) - ) - ) - start_pos = region_offset.value + offset if (datatype != np.object_) and (datatype != np.bytes_): - requested_byte_size = np.prod(shape) * np.dtype(datatype).itemsize - cval_len = start_pos + requested_byte_size - if byte_size.value < cval_len: - _raise_error( - "The size of the shared memory region is insufficient to provide numpy array with requested size" - ) - if cval_len == 0: - result = np.empty(shape, dtype=datatype) - else: - val_buf = cast(shm_addr, POINTER(c_byte * cval_len))[0] - val = np.frombuffer(val_buf, dtype=datatype, offset=start_pos) - - # Reshape the result to the appropriate shape. - result = np.reshape(val, shape) + result = np.ndarray( + shape, datatype, buffer=shm_handle._mpsm_handle.buf[offset:] + ) else: - str_offset = start_pos - val_buf = cast(shm_addr, POINTER(c_byte * byte_size.value))[0] + str_offset = offset + val_buf = shm_handle._mpsm_handle.buf ii = 0 strs = list() while (ii % np.prod(shape) != 0) or (ii == 0): @@ -268,15 +219,16 @@ def mapped_shared_memory_regions(): The list of mapped system shared memory regions. 
""" - return mapped_shm_regions + return list(_key_mapping.keys()) def destroy_shared_memory_region(shm_handle): - """Unlink a system shared memory region with the specified handle. + """Release the handle, unlink a system shared memory region with the specified handle + if it is the last managed handle. Parameters ---------- - shm_handle : c_void_p + shm_handle : SharedMemoryRegion The handle for the system shared memory region. Raises @@ -284,57 +236,22 @@ def destroy_shared_memory_region(shm_handle): SharedMemoryException If unable to unlink the shared memory region. """ - shm_fd = c_int() - offset = c_uint64() - byte_size = c_uint64() - shm_addr = c_char_p() - shm_key = c_char_p() - _raise_if_error( - c_int( - _cshm_get_shared_memory_handle_info( - shm_handle, - byref(shm_addr), - byref(shm_key), - byref(shm_fd), - byref(offset), - byref(byte_size), - ) - ) - ) # It is safer to remove the shared memory key from the list before # deleting the shared memory region because if the deletion should # fail, a re-attempt could result in a segfault. Secondarily, if we # fail to delete a region, we should not report it back to the user # as a valid memory region. - mapped_shm_regions.remove(shm_key.value.decode("utf-8")) - _raise_if_error(c_int(_cshm_shared_memory_region_destroy(shm_handle))) - return + shm_handle._mpsm_handle.close() + _key_mapping[shm_handle._shm_key]["active_handle_count"] -= 1 + if _key_mapping[shm_handle._shm_key]["active_handle_count"] == 0: + try: + if _key_mapping[shm_handle._shm_key]["needs_unlink"]: + shm_handle._mpsm_handle.unlink() + finally: + _key_mapping.pop(shm_handle._shm_key) class SharedMemoryException(Exception): - """Exception indicating non-Success status. - - Parameters - ---------- - err : c_void_p - Pointer to an Error that should be used to initialize the exception. 
- - """ + """Exception type for shared memory related error.""" - def __init__(self, err): - self.err_code_map = { - -2: "unable to get shared memory descriptor", - -3: "unable to initialize the size", - -4: "unable to read/mmap the shared memory region", - -5: "unable to unlink the shared memory region", - -6: "unable to munmap the shared memory region", - } - self._msg = None - if type(err) == str: - self._msg = err - elif err.value != 0 and err.value in self.err_code_map: - self._msg = self.err_code_map[err.value] - - def __str__(self): - msg = super().__str__() if self._msg is None else self._msg - return msg + pass diff --git a/src/python/library/tritonclient/utils/shared_memory/shared_memory.cc b/src/python/library/tritonclient/utils/shared_memory/shared_memory.cc deleted file mode 100644 index 80720da26..000000000 --- a/src/python/library/tritonclient/utils/shared_memory/shared_memory.cc +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions -// are met: -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// * Neither the name of NVIDIA CORPORATION nor the names of its -// contributors may be used to endorse or promote products derived -// from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR -// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -#include "shared_memory.h" - -#include -#include -#include -#include - -#include -#include -#include - -#include "shared_memory_handle.h" - -//============================================================================== -// SharedMemoryControlContext - -namespace { - -void* -SharedMemoryHandleCreate( - std::string triton_shm_name, void* shm_addr, std::string shm_key, - int shm_fd, size_t offset, size_t byte_size) -{ - SharedMemoryHandle* handle = new SharedMemoryHandle(); - handle->triton_shm_name_ = triton_shm_name; - handle->base_addr_ = shm_addr; - handle->shm_key_ = shm_key; - handle->shm_fd_ = shm_fd; - handle->offset_ = offset; - handle->byte_size_ = byte_size; - return reinterpret_cast(handle); -} - -int -SharedMemoryRegionMap( - int shm_fd, size_t offset, size_t byte_size, void** shm_addr) -{ - // map shared memory to process address space - *shm_addr = mmap(NULL, byte_size, PROT_WRITE, MAP_SHARED, shm_fd, offset); - if (*shm_addr == MAP_FAILED) { - return -1; - } - - // close shared memory descriptor, return 0 if success else return -1 - return close(shm_fd); -} - -} // namespace - -int -SharedMemoryRegionCreate( - const char* triton_shm_name, const char* shm_key, size_t byte_size, - void** shm_handle) -{ - // get shared memory region descriptor - int shm_fd = shm_open(shm_key, O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); - if (shm_fd == -1) { - return -2; - } - - // extend shared memory object as by default it's initialized 
with size 0 - int res = ftruncate(shm_fd, byte_size); - if (res == -1) { - return -3; - } - - // get base address of shared memory region - void* shm_addr = nullptr; - int err = SharedMemoryRegionMap(shm_fd, 0, byte_size, &shm_addr); - if (err == -1) { - return -4; - } - - // create a handle for the shared memory region - *shm_handle = SharedMemoryHandleCreate( - std::string(triton_shm_name), shm_addr, std::string(shm_key), shm_fd, 0, - byte_size); - return 0; -} - -int -SharedMemoryRegionSet( - void* shm_handle, size_t offset, size_t byte_size, const void* data) -{ - void* shm_addr = - reinterpret_cast(shm_handle)->base_addr_; - char* shm_addr_offset = reinterpret_cast(shm_addr); - std::memcpy(shm_addr_offset + offset, data, byte_size); - return 0; -} - -int -GetSharedMemoryHandleInfo( - void* shm_handle, char** shm_addr, const char** shm_key, int* shm_fd, - size_t* offset, size_t* byte_size) -{ - SharedMemoryHandle* handle = - reinterpret_cast(shm_handle); - *shm_addr = reinterpret_cast(handle->base_addr_); - *shm_key = handle->shm_key_.c_str(); - *shm_fd = handle->shm_fd_; - *offset = handle->offset_; - *byte_size = handle->byte_size_; - return 0; -} - -int -SharedMemoryRegionDestroy(void* shm_handle) -{ - std::unique_ptr handle( - reinterpret_cast(shm_handle)); - void* shm_addr = reinterpret_cast(handle->base_addr_); - int status = munmap(shm_addr, handle->byte_size_); - if (status == -1) { - return -6; - } - - int shm_fd = shm_unlink(handle->shm_key_.c_str()); - if (shm_fd == -1) { - return -5; - } - return 0; -} - -//============================================================================== diff --git a/src/python/library/tritonclient/utils/shared_memory/shared_memory.h b/src/python/library/tritonclient/utils/shared_memory/shared_memory.h deleted file mode 100644 index 9d3e9519e..000000000 --- a/src/python/library/tritonclient/utils/shared_memory/shared_memory.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) 2019-2020, NVIDIA CORPORATION. 
All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions -// are met: -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// * Neither the name of NVIDIA CORPORATION nor the names of its -// contributors may be used to endorse or promote products derived -// from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-#pragma once - -#include -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -//============================================================================== -// SharedMemoryControlContext -int SharedMemoryRegionCreate( - const char* triton_shm_name, const char* shm_key, size_t byte_size, - void** shm_handle); -int SharedMemoryRegionSet( - void* shm_handle, size_t offset, size_t byte_size, const void* data); -int GetSharedMemoryHandleInfo( - void* shm_handle, char** shm_addr, const char** shm_key, int* shm_fd, - size_t* offset, size_t* byte_size); -int SharedMemoryRegionDestroy(void* shm_handle); - -//============================================================================== - -#ifdef __cplusplus -} -#endif diff --git a/src/python/library/tritonclient/utils/shared_memory/shared_memory_handle.h b/src/python/library/tritonclient/utils/shared_memory/shared_memory_handle.h deleted file mode 100644 index b929ed305..000000000 --- a/src/python/library/tritonclient/utils/shared_memory/shared_memory_handle.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions -// are met: -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// * Neither the name of NVIDIA CORPORATION nor the names of its -// contributors may be used to endorse or promote products derived -// from this software without specific prior written permission. 
-// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -#pragma once - -#ifdef TRITON_ENABLE_GPU -#include -#endif // TRITON_ENABLE_GPU - -struct SharedMemoryHandle { - std::string triton_shm_name_; - std::string shm_key_; -#ifdef TRITON_ENABLE_GPU - cudaIpcMemHandle_t cuda_shm_handle_; - int device_id_; -#endif // TRITON_ENABLE_GPU - void* base_addr_; - int shm_fd_; - size_t offset_; - size_t byte_size_; -};