Fix shm segments tracking
pziecina-nv committed Sep 26, 2023
1 parent bde8337 commit db1aae5
Showing 2 changed files with 79 additions and 54 deletions.
82 changes: 55 additions & 27 deletions pytriton/proxy/communication.py
@@ -464,7 +464,8 @@ def _get_debug_status(self):
def close(self):
multiprocessing.util.debug(f"Closing server {self._id}")
with self._segments_lock:
for segment in self._segments:
while self._segments:
segment = self._segments.pop()
multiprocessing.util.debug(f"Closing and delete segment {segment.shared_memory.name}")
segment.shared_memory.close()
segment.shared_memory.unlink()
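
Note on the hunk above: the change replaces iteration over `self._segments` with a drain loop, so each segment is removed from the tracking list as it is released and a failure partway through does not leave already-unlinked segments behind. A minimal sketch of that pattern, not taken from the commit (names are illustrative):

# Illustrative sketch, not part of the commit: drain-while-closing keeps the
# tracking container consistent even if releasing one element raises.
from multiprocessing import shared_memory

def close_all(segments: list) -> None:
    while segments:
        shm = segments.pop()  # drop from tracking as it is released
        shm.close()           # detach this process from the segment
        shm.unlink()          # remove the backing segment (owner only)

# Usage:
segs = [shared_memory.SharedMemory(create=True, size=1024) for _ in range(2)]
close_all(segs)
assert not segs  # everything released is also gone from the tracking list
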
@@ -598,6 +599,23 @@ class TensorStore:
"""Tensor store for storing and retrieving numpy arrays in/from shared memory."""

_SOCKET_EXISTANCE_CHECK_INTERVAL_S = 0.1
_instances = {}

def __new__(cls, *args, **kwargs):
"""Create TensorStore object. If object with given address already exists, return it."""
if args:
address = args[0]
elif "address" in kwargs:
address = kwargs["address"]
else:
raise TypeError("TensorStore() missing 1 required positional argument: 'address'")

address = address.as_posix() if isinstance(address, pathlib.Path) else address

if address not in cls._instances:
cls._instances[address] = super().__new__(cls)

return cls._instances[address]

def __init__(self, address: Union[str, pathlib.Path], auth_key: Optional[bytes] = None):
"""Initialize TensorStore object.
@@ -606,22 +624,23 @@ def __init__(self, address: Union[str, pathlib.Path], auth_key: Optional[bytes]
address: address of data store
auth_key: authentication key required to setup connection. If not provided, current process authkey will be used
"""
_update_logger()
address = address.as_posix() if isinstance(address, pathlib.Path) else address
self._remote_blocks_store_manager = BlocksStoreManager(address, authkey=auth_key, ctx=_SpawnContext())
self._remote_blocks_store = None
self._manager_start_stop_filelock = _FileLock(f"{address}.lock")
if not hasattr(self, "_remote_blocks_store_manager"):
_update_logger()
address = address.as_posix() if isinstance(address, pathlib.Path) else address
self._remote_blocks_store_manager = BlocksStoreManager(address, authkey=auth_key, ctx=_SpawnContext())
self._remote_blocks_store = None
self._manager_start_stop_filelock = _FileLock(f"{address}.lock")

# container for keeping map between tensor_id and numpy array weak ref
self._handled_blocks: Dict[str, weakref.ReferenceType] = {}
self._handled_blocks_lock = threading.RLock()
# container for keeping map between tensor_id and numpy array weak ref
self._handled_blocks: Dict[str, weakref.ReferenceType] = {}
self._handled_blocks_lock = threading.RLock()

self._shm_segments: Dict[str, multiprocessing.shared_memory.SharedMemory] = {}
self._shm_segments_lock = threading.RLock()
self._shm_segments: Dict[str, multiprocessing.shared_memory.SharedMemory] = {}
self._shm_segments_lock = threading.RLock()

self.serialize = serialize_numpy_with_struct_header
self.deserialize = deserialize_numpy_with_struct_header
self._calc_serialized_tensor_size = calc_serialized_size_of_numpy_with_struct_header
self.serialize = serialize_numpy_with_struct_header
self.deserialize = deserialize_numpy_with_struct_header
self._calc_serialized_tensor_size = calc_serialized_size_of_numpy_with_struct_header

@property
def address(self) -> str:
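
Taken together, the `__new__` and `__init__` hunks above make `TensorStore` a per-address singleton: `__new__` caches one instance per normalized address, and the `hasattr` guard in `__init__` keeps a repeated `TensorStore(address)` call from re-creating the manager and resetting already-tracked segments (Python still calls `__init__` on the cached object that `__new__` returns). A condensed, self-contained sketch of that pattern, with a hypothetical class rather than the real `TensorStore`:

# Illustrative sketch, not part of the commit: per-key singleton whose repeated
# construction returns, and does not re-initialize, the cached instance.
import pathlib
from typing import Union

class PerAddressStore:
    _instances = {}

    def __new__(cls, address: Union[str, pathlib.Path]):
        key = address.as_posix() if isinstance(address, pathlib.Path) else address
        if key not in cls._instances:
            cls._instances[key] = super().__new__(cls)
        return cls._instances[key]

    def __init__(self, address: Union[str, pathlib.Path]):
        if not hasattr(self, "_address"):  # __init__ runs on every construction; initialize only once
            self._address = str(address)

# Usage: both lookups of the same address yield the same object.
a = PerAddressStore("/tmp/store.sock")
b = PerAddressStore(pathlib.Path("/tmp/store.sock"))
assert a is b
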
@@ -631,6 +650,9 @@ def address(self) -> str:
def start(self):
"""Start remote block store."""
with self._manager_start_stop_filelock:
if self._remote_blocks_store is not None:
raise RuntimeError("Remote block store is already started/connected")

self._remote_blocks_store_manager.start()
self._remote_blocks_store = self._remote_blocks_store_manager.blocks() # pytype: disable=attribute-error

@@ -642,11 +664,15 @@ def start(self):

def connect(self, timeout_s: Optional[float] = None):
"""Connect to remote block store."""
address = pathlib.Path(self._remote_blocks_store_manager.address)
if self._remote_blocks_store is None:
address = pathlib.Path(self._remote_blocks_store_manager.address)

self._wait_for_address(address, timeout_s)
self._remote_blocks_store_manager.connect()
self._remote_blocks_store = self._remote_blocks_store_manager.blocks() # pytype: disable=attribute-error
self._wait_for_address(address, timeout_s)
self._remote_blocks_store_manager.connect()
self._remote_blocks_store = self._remote_blocks_store_manager.blocks() # pytype: disable=attribute-error
_LOGGER.debug(f"Connected to remote block store at {address})")
else:
_LOGGER.debug(f"Already connectd to remote block store at {self.address}")

def _wait_for_address(self, address, timeout_s: Optional[float] = None):
should_stop_at = time.time() + timeout_s if timeout_s is not None else None
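
With a shared instance, `start()` and `connect()` above also become guarded: `start()` now raises if a block store is already attached, and `connect()` turns into a no-op once connected. The `_wait_for_address` helper whose first lines appear above polls for the server's socket path until an optional deadline; a hedged sketch of that kind of helper (illustrative names, the 0.1 s interval mirrors `_SOCKET_EXISTANCE_CHECK_INTERVAL_S`):

# Illustrative sketch, not part of the commit: wait for a socket file to appear,
# with an optional timeout.
import pathlib
import time
from typing import Optional

_CHECK_INTERVAL_S = 0.1

def wait_for_socket(address: pathlib.Path, timeout_s: Optional[float] = None) -> None:
    deadline = time.time() + timeout_s if timeout_s is not None else None
    while not address.exists():
        if deadline is not None and time.time() >= deadline:
            raise TimeoutError(f"{address} did not appear within {timeout_s}s")
        time.sleep(_CHECK_INTERVAL_S)
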
@@ -786,7 +812,8 @@ def close(self):
"""Free resources used by TensorStore object."""
from multiprocessing.resource_tracker import register, unregister

_LOGGER.debug("TensorStore is being closed")
started_server = hasattr(self._remote_blocks_store_manager, "shutdown")
_LOGGER.debug(f"TensorStore is being closed (started_server={started_server})")

gc.collect()
with self._handled_blocks_lock:
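
The `close()` hunk above first records whether this process is the one that started the block-store server. `multiprocessing.managers.BaseManager` only gains a `shutdown` attribute after `start()` is called, so `hasattr(self._remote_blocks_store_manager, "shutdown")` distinguishes the owning process from processes that merely called `connect()`. A small sketch of that behaviour (illustrative; standard-library `BaseManager`, hypothetical address and authkey):

# Illustrative sketch, not part of the commit: BaseManager exposes `shutdown`
# only after start(), so hasattr() can tell the server-owning process from a client.
from multiprocessing.managers import BaseManager

class _Manager(BaseManager):
    pass

if __name__ == "__main__":
    owner = _Manager(address=("127.0.0.1", 0), authkey=b"secret")
    print(hasattr(owner, "shutdown"))  # False: nothing started yet

    owner.start()                      # this process now owns the server
    print(hasattr(owner, "shutdown"))  # True
    owner.shutdown()
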
@@ -795,18 +822,19 @@
self.release_block(tensor_id)

with self._shm_segments_lock:
for shm in self._shm_segments.values():
_LOGGER.debug(f"Unregistering shared memory {shm.name}")
while self._shm_segments:
_, shm = self._shm_segments.popitem()
_LOGGER.debug(f"Closing shared memory {shm.name}")
try:
shm.close()
register(shm._name, "shared_memory") # pytype: disable=attribute-error
unregister(shm._name, "shared_memory") # pytype: disable=attribute-error
except Exception as e:
_LOGGER.warning(f"Failed to unregister shared memory {shm.name}: {e}")

self._shm_segments = {}
_LOGGER.warning(f"Failed to close shared memory {shm.name}: {e}")
finally:
if not started_server:
register(shm._name, "shared_memory") # pytype: disable=attribute-error
unregister(shm._name, "shared_memory") # pytype: disable=attribute-error

if hasattr(self._remote_blocks_store_manager, "shutdown"):
if started_server:
if self._remote_blocks_store is not None:
_LOGGER.debug(f"Releasing all resources on remote process at {self.address}")
try:
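
In the rewritten cleanup above, a process that only attached to the shared-memory segments closes them and then unregisters them from `multiprocessing.resource_tracker`, so the tracker does not try to unlink segments the server still owns when this process exits; the server path (`started_server`) skips the unregister and relies on the block-store manager to unlink its own segments. The register-immediately-before-unregister pairing appears intended to make the unregister a clean no-op even if the name was never tracked here; that reading is an interpretation, not stated in the commit. A hedged sketch of the attach-then-unregister pattern:

# Illustrative sketch, not part of the commit: a process that merely attaches to
# an existing segment closes it and unregisters it, leaving the owner to unlink.
from multiprocessing import shared_memory
from multiprocessing.resource_tracker import register, unregister

owner = shared_memory.SharedMemory(create=True, size=1024)   # "server" side
attached = shared_memory.SharedMemory(name=owner.name)       # "client" side

attached.close()
register(attached._name, "shared_memory")    # ensure the tracker knows the name...
unregister(attached._name, "shared_memory")  # ...then drop it so it is not unlinked at exit

owner.close()
owner.unlink()  # only the owning side removes the backing segment
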
51 changes: 24 additions & 27 deletions tests/unit/test_proxy_inference_handler.py
@@ -71,34 +71,31 @@ def _infer_gen_fn(*_, **__):

def _get_meta_requests_payload(_data_store_socket):
tensor_store = TensorStore(_data_store_socket)
try:
LOGGER.debug(f"Connecting to tensor store {_data_store_socket} ...")
tensor_store.connect() # to already started tensor store
requests = [
Request({"input1": np.ones((128, 4), dtype="float32"), "input2": np.ones((128, 4), dtype="float32")}),
Request({"input1": np.ones((128, 4), dtype="float32"), "input2": np.ones((128, 4), dtype="float32")}),
]
input_arrays_with_coords = [
(request_idx, input_name, tensor)
for request_idx, request in enumerate(requests)
for input_name, tensor in request.items()
LOGGER.debug(f"Connecting to tensor store {_data_store_socket} ...")
tensor_store.connect() # to already started tensor store
requests = [
Request({"input1": np.ones((128, 4), dtype="float32"), "input2": np.ones((128, 4), dtype="float32")}),
Request({"input1": np.ones((128, 4), dtype="float32"), "input2": np.ones((128, 4), dtype="float32")}),
]
input_arrays_with_coords = [
(request_idx, input_name, tensor)
for request_idx, request in enumerate(requests)
for input_name, tensor in request.items()
]
LOGGER.debug("Putting tensors to tensor store ...")
tensor_ids = tensor_store.put([tensor for _, _, tensor in input_arrays_with_coords])
requests_with_ids = [{}] * len(requests)
for (request_idx, input_name, _), tensor_id in zip(input_arrays_with_coords, tensor_ids):
requests_with_ids[request_idx][input_name] = tensor_id

meta_requests = InferenceHandlerRequests(
requests=[
MetaRequestResponse(idx, data=request_with_ids, parameters=request.parameters)
for idx, (request, request_with_ids) in enumerate(zip(requests, requests_with_ids))
]
LOGGER.debug("Putting tensors to tensor store ...")
tensor_ids = tensor_store.put([tensor for _, _, tensor in input_arrays_with_coords])
requests_with_ids = [{}] * len(requests)
for (request_idx, input_name, _), tensor_id in zip(input_arrays_with_coords, tensor_ids):
requests_with_ids[request_idx][input_name] = tensor_id

meta_requests = InferenceHandlerRequests(
requests=[
MetaRequestResponse(idx, data=request_with_ids, parameters=request.parameters)
for idx, (request, request_with_ids) in enumerate(zip(requests, requests_with_ids))
]
)
LOGGER.debug(f"Return meta requests: {meta_requests}")
return meta_requests.as_bytes()
finally:
tensor_store.close()
)
LOGGER.debug(f"Return meta requests: {meta_requests}")
return meta_requests.as_bytes()
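
The test helper above no longer closes the store in a `finally` block: with the per-address `__new__` cache added in communication.py, `TensorStore(_data_store_socket)` returns the same instance that the side of the test which started the store is still using, so closing it here would tear down shared state mid-test. That reading follows from the singleton change; the commit does not state it explicitly. A tiny sketch of the consequence (assuming `TensorStore` is importable from `pytriton.proxy.communication`, as the changed file's path suggests; the socket path is hypothetical):

# Illustrative sketch, not part of the commit: repeated lookups of one address
# share a single store, so close() anywhere closes it everywhere in this process.
from pytriton.proxy.communication import TensorStore

store_a = TensorStore("/tmp/data_store.sock")
store_b = TensorStore("/tmp/data_store.sock")
assert store_a is store_b  # same object; closing store_a would also affect store_b
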


@pytest.mark.parametrize(
