fix: Fix deadlock when child gets kill -9 sig (#340)
* fix: Fix deadlock when child gets kill -9 sig

* chore: Better cleanup of resources

* chore: Changed the place of processes.clear()

* fix: Added cancel_join_thread for emergency shutdown
hh-space-invader authored Sep 24, 2024
1 parent 65c2efd commit 40a0374
Showing 1 changed file with 28 additions and 8 deletions.
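
For background on the failure mode, here is a minimal standalone sketch (not code from this commit): multiprocessing.Queue writes items to its pipe from a background feeder thread, and queue.join_thread() blocks until that thread has flushed everything. If the process on the other end was killed with kill -9, the pipe may never drain, so join_thread(), and with it process shutdown, can hang forever. cancel_join_thread() tells the queue to discard buffered data instead of waiting, which is the escape hatch this commit reserves for emergency shutdown:

import multiprocessing as mp

def main() -> None:
    q = mp.Queue()
    for i in range(10_000):  # enough items to keep the feeder thread busy
        q.put(i)

    consumer_is_dead = True  # stand-in for "the child was killed with SIGKILL"
    q.close()
    if consumer_is_dead:
        q.cancel_join_thread()  # emergency path: drop buffered items, never block
    else:
        q.join_thread()  # normal path: wait until the feeder flushes every item

if __name__ == "__main__":
    main()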
fastembed/parallel_processor.py
@@ -73,7 +73,9 @@ def input_queue_iterable() -> Iterable[Any]:
         # See:
         # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#pipes-and-queues
         # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#programming-guidelines
         input_queue.close()
         output_queue.close()
+        input_queue.join_thread()
+        output_queue.join_thread()
 
         with num_active_workers.get_lock():
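
The two join_thread() calls added above follow the multiprocessing programming guidelines linked in the comments: a process that puts items on a queue should close() it and then join the feeder thread, so every buffered item reaches the pipe before the process exits. A minimal sketch of that normal-path cleanup (illustrative names, not fastembed code):

import multiprocessing as mp

def worker(out_q: mp.Queue) -> None:
    for i in range(3):
        out_q.put(i * i)
    # Close the queue, then join its feeder thread, so every buffered
    # item reaches the pipe before this process exits.
    out_q.close()
    out_q.join_thread()

if __name__ == "__main__":
    q = mp.Queue()
    p = mp.Process(target=worker, args=(q,))
    p.start()
    print([q.get() for _ in range(3)])  # [0, 1, 4]
    p.join()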
@@ -83,17 +85,15 @@ def input_queue_iterable() -> Iterable[Any]:
 
 
 class ParallelWorkerPool:
-    def __init__(
-        self, num_workers: int, worker: Type[Worker], start_method: Optional[str] = None
-    ):
+    def __init__(self, num_workers: int, worker: Type[Worker], start_method: Optional[str] = None):
         self.worker_class = worker
         self.num_workers = num_workers
         self.input_queue: Optional[Queue] = None
         self.output_queue: Optional[Queue] = None
         self.ctx: BaseContext = get_context(start_method)
         self.processes: List[BaseProcess] = []
         self.queue_size = self.num_workers * max_internal_batch_size
 
+        self.emergency_shutdown = False
         self.num_active_workers: Optional[BaseValue] = None
 
     def start(self, **kwargs: Any) -> None:
@@ -120,9 +120,7 @@ def start(self, **kwargs: Any) -> None:
             process.start()
             self.processes.append(process)
 
-    def ordered_map(
-        self, stream: Iterable[Any], *args: Any, **kwargs: Any
-    ) -> Iterable[Any]:
+    def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
         buffer = defaultdict(Any)
         next_expected = 0
@@ -144,6 +142,7 @@ def semi_ordered_map(
             pushed = 0
             read = 0
             for idx, item in enumerate(stream):
+                self.check_worker_health()
                 if pushed - read < self.queue_size:
                     try:
                         out_item = self.output_queue.get_nowait()
@@ -170,6 +169,7 @@ def semi_ordered_map(
                 self.input_queue.put(QueueSignals.stop)
 
             while read < pushed:
+                self.check_worker_health()
                 out_item = self.output_queue.get(timeout=processing_timeout)
                 if out_item == QueueSignals.error:
                     self.join_or_terminate()
@@ -179,8 +179,27 @@ def semi_ordered_map(
         finally:
             assert self.input_queue is not None, "Input queue is None"
             assert self.output_queue is not None, "Output queue is None"
+            self.join()
             self.input_queue.close()
             self.output_queue.close()
+            if self.emergency_shutdown:
+                self.input_queue.cancel_join_thread()
+                self.output_queue.cancel_join_thread()
+            else:
+                self.input_queue.join_thread()
+                self.output_queue.join_thread()
+
+    def check_worker_health(self) -> None:
+        """
+        Checks if any worker process has terminated unexpectedly
+        """
+        for process in self.processes:
+            if not process.is_alive() and process.exitcode != 0:
+                self.emergency_shutdown = True
+                self.join_or_terminate()
+                raise RuntimeError(
+                    f"Worker PID: {process.pid} terminated unexpectedly with code {process.exitcode}"
+                )
 
     def join_or_terminate(self, timeout: Optional[int] = 1) -> None:
         """
@@ -210,4 +229,5 @@ def __del__(self) -> None:
         https://eli.thegreenplace.net/2009/06/12/safely-using-destructors-in-python/.
         """
         for process in self.processes:
-            process.terminate()
+            if process.is_alive():
+                process.terminate()
