diff --git a/pytriton/proxy/communication.py b/pytriton/proxy/communication.py
index 674a077..50e6cdb 100644
--- a/pytriton/proxy/communication.py
+++ b/pytriton/proxy/communication.py
@@ -464,7 +464,8 @@ def _get_debug_status(self):
     def close(self):
         multiprocessing.util.debug(f"Closing server {self._id}")
         with self._segments_lock:
-            for segment in self._segments:
+            while self._segments:
+                segment = self._segments.pop()
                 multiprocessing.util.debug(f"Closing and delete segment {segment.shared_memory.name}")
                 segment.shared_memory.close()
                 segment.shared_memory.unlink()
@@ -598,6 +599,23 @@ class TensorStore:
     """Tensor store for storing and retrieving numpy arrays in/from shared memory."""
 
     _SOCKET_EXISTANCE_CHECK_INTERVAL_S = 0.1
+    _instances = {}
+
+    def __new__(cls, *args, **kwargs):
+        """Create TensorStore object. If an object with the given address already exists, return it."""
+        if args:
+            address = args[0]
+        elif "address" in kwargs:
+            address = kwargs["address"]
+        else:
+            raise TypeError("TensorStore() missing 1 required positional argument: 'address'")
+
+        address = address.as_posix() if isinstance(address, pathlib.Path) else address
+
+        if address not in cls._instances:
+            cls._instances[address] = super().__new__(cls)
+
+        return cls._instances[address]
 
     def __init__(self, address: Union[str, pathlib.Path], auth_key: Optional[bytes] = None):
         """Initialize TensorStore object.
@@ -606,22 +624,23 @@ def __init__(self, address: Union[str, pathlib.Path], auth_key: Optional[bytes]
             address: address of data store
             auth_key: authentication key required to setup connection. If not provided, current process authkey will be used
         """
-        _update_logger()
-        address = address.as_posix() if isinstance(address, pathlib.Path) else address
-        self._remote_blocks_store_manager = BlocksStoreManager(address, authkey=auth_key, ctx=_SpawnContext())
-        self._remote_blocks_store = None
-        self._manager_start_stop_filelock = _FileLock(f"{address}.lock")
+        if not hasattr(self, "_remote_blocks_store_manager"):
+            _update_logger()
+            address = address.as_posix() if isinstance(address, pathlib.Path) else address
+            self._remote_blocks_store_manager = BlocksStoreManager(address, authkey=auth_key, ctx=_SpawnContext())
+            self._remote_blocks_store = None
+            self._manager_start_stop_filelock = _FileLock(f"{address}.lock")
 
-        # container for keeping map between tensor_id and numpy array weak ref
-        self._handled_blocks: Dict[str, weakref.ReferenceType] = {}
-        self._handled_blocks_lock = threading.RLock()
+            # container for keeping map between tensor_id and numpy array weak ref
+            self._handled_blocks: Dict[str, weakref.ReferenceType] = {}
+            self._handled_blocks_lock = threading.RLock()
 
-        self._shm_segments: Dict[str, multiprocessing.shared_memory.SharedMemory] = {}
-        self._shm_segments_lock = threading.RLock()
+            self._shm_segments: Dict[str, multiprocessing.shared_memory.SharedMemory] = {}
+            self._shm_segments_lock = threading.RLock()
 
-        self.serialize = serialize_numpy_with_struct_header
-        self.deserialize = deserialize_numpy_with_struct_header
-        self._calc_serialized_tensor_size = calc_serialized_size_of_numpy_with_struct_header
+            self.serialize = serialize_numpy_with_struct_header
+            self.deserialize = deserialize_numpy_with_struct_header
+            self._calc_serialized_tensor_size = calc_serialized_size_of_numpy_with_struct_header
 
     @property
     def address(self) -> str:
@@ -631,6 +650,9 @@ def address(self) -> str:
     def start(self):
        """Start remote block store."""
         with self._manager_start_stop_filelock:
+            if self._remote_blocks_store is not None:
+                raise RuntimeError("Remote block store is already started/connected")
+
             self._remote_blocks_store_manager.start()
             self._remote_blocks_store = self._remote_blocks_store_manager.blocks()  # pytype: disable=attribute-error
 
@@ -642,11 +664,15 @@ def start(self):
 
     def connect(self, timeout_s: Optional[float] = None):
         """Connect to remote block store."""
-        address = pathlib.Path(self._remote_blocks_store_manager.address)
+        if self._remote_blocks_store is None:
+            address = pathlib.Path(self._remote_blocks_store_manager.address)
 
-        self._wait_for_address(address, timeout_s)
-        self._remote_blocks_store_manager.connect()
-        self._remote_blocks_store = self._remote_blocks_store_manager.blocks()  # pytype: disable=attribute-error
+            self._wait_for_address(address, timeout_s)
+            self._remote_blocks_store_manager.connect()
+            self._remote_blocks_store = self._remote_blocks_store_manager.blocks()  # pytype: disable=attribute-error
+            _LOGGER.debug(f"Connected to remote block store at {address}")
+        else:
+            _LOGGER.debug(f"Already connected to remote block store at {self.address}")
 
     def _wait_for_address(self, address, timeout_s: Optional[float] = None):
         should_stop_at = time.time() + timeout_s if timeout_s is not None else None
@@ -786,7 +812,8 @@ def close(self):
         """Free resources used by TensorStore object."""
         from multiprocessing.resource_tracker import register, unregister
 
-        _LOGGER.debug("TensorStore is being closed")
+        started_server = hasattr(self._remote_blocks_store_manager, "shutdown")
+        _LOGGER.debug(f"TensorStore is being closed (started_server={started_server})")
         gc.collect()
 
         with self._handled_blocks_lock:
@@ -795,18 +822,19 @@ def close(self):
                 self.release_block(tensor_id)
 
         with self._shm_segments_lock:
-            for shm in self._shm_segments.values():
-                _LOGGER.debug(f"Unregistering shared memory {shm.name}")
+            while self._shm_segments:
+                _, shm = self._shm_segments.popitem()
+                _LOGGER.debug(f"Closing shared memory {shm.name}")
                 try:
                     shm.close()
-                    register(shm._name, "shared_memory")  # pytype: disable=attribute-error
-                    unregister(shm._name, "shared_memory")  # pytype: disable=attribute-error
                 except Exception as e:
-                    _LOGGER.warning(f"Failed to unregister shared memory {shm.name}: {e}")
-
-            self._shm_segments = {}
+                    _LOGGER.warning(f"Failed to close shared memory {shm.name}: {e}")
+                finally:
+                    if not started_server:
+                        register(shm._name, "shared_memory")  # pytype: disable=attribute-error
+                        unregister(shm._name, "shared_memory")  # pytype: disable=attribute-error
 
-        if hasattr(self._remote_blocks_store_manager, "shutdown"):
+        if started_server:
             if self._remote_blocks_store is not None:
                 _LOGGER.debug(f"Releasing all resources on remote process at {self.address}")
                 try:
diff --git a/tests/unit/test_proxy_inference_handler.py b/tests/unit/test_proxy_inference_handler.py
index e9a0568..97966c8 100644
--- a/tests/unit/test_proxy_inference_handler.py
+++ b/tests/unit/test_proxy_inference_handler.py
@@ -71,34 +71,31 @@ def _infer_gen_fn(*_, **__):
 
 def _get_meta_requests_payload(_data_store_socket):
     tensor_store = TensorStore(_data_store_socket)
-    try:
-        LOGGER.debug(f"Connecting to tensor store {_data_store_socket} ...")
-        tensor_store.connect()  # to already started tensor store
-        requests = [
-            Request({"input1": np.ones((128, 4), dtype="float32"), "input2": np.ones((128, 4), dtype="float32")}),
-            Request({"input1": np.ones((128, 4), dtype="float32"), "input2": np.ones((128, 4), dtype="float32")}),
-        ]
-        input_arrays_with_coords = [
-            (request_idx, input_name, tensor)
-            for request_idx, request in enumerate(requests)
-            for input_name, tensor in request.items()
+    LOGGER.debug(f"Connecting to tensor store {_data_store_socket} ...")
+    tensor_store.connect()  # to already started tensor store
+    requests = [
+        Request({"input1": np.ones((128, 4), dtype="float32"), "input2": np.ones((128, 4), dtype="float32")}),
+        Request({"input1": np.ones((128, 4), dtype="float32"), "input2": np.ones((128, 4), dtype="float32")}),
+    ]
+    input_arrays_with_coords = [
+        (request_idx, input_name, tensor)
+        for request_idx, request in enumerate(requests)
+        for input_name, tensor in request.items()
+    ]
+    LOGGER.debug("Putting tensors to tensor store ...")
+    tensor_ids = tensor_store.put([tensor for _, _, tensor in input_arrays_with_coords])
+    requests_with_ids = [{} for _ in requests]
+    for (request_idx, input_name, _), tensor_id in zip(input_arrays_with_coords, tensor_ids):
+        requests_with_ids[request_idx][input_name] = tensor_id
+
+    meta_requests = InferenceHandlerRequests(
+        requests=[
+            MetaRequestResponse(idx, data=request_with_ids, parameters=request.parameters)
+            for idx, (request, request_with_ids) in enumerate(zip(requests, requests_with_ids))
         ]
-        LOGGER.debug("Putting tensors to tensor store ...")
-        tensor_ids = tensor_store.put([tensor for _, _, tensor in input_arrays_with_coords])
-        requests_with_ids = [{}] * len(requests)
-        for (request_idx, input_name, _), tensor_id in zip(input_arrays_with_coords, tensor_ids):
-            requests_with_ids[request_idx][input_name] = tensor_id
-
-        meta_requests = InferenceHandlerRequests(
-            requests=[
-                MetaRequestResponse(idx, data=request_with_ids, parameters=request.parameters)
-                for idx, (request, request_with_ids) in enumerate(zip(requests, requests_with_ids))
-            ]
-        )
-        LOGGER.debug(f"Return meta requests: {meta_requests}")
-        return meta_requests.as_bytes()
-    finally:
-        tensor_store.close()
+    )
+    LOGGER.debug(f"Return meta requests: {meta_requests}")
+    return meta_requests.as_bytes()
 
 
 @pytest.mark.parametrize(