diff --git a/fed/_private/global_context.py b/fed/_private/global_context.py
index 0a7c522..f3d1830 100644
--- a/fed/_private/global_context.py
+++ b/fed/_private/global_context.py
@@ -16,6 +16,7 @@
 from typing import Callable
 
 from fed.cleanup import CleanupManager
+from fed.exceptions import FedRemoteError
 
 
 class GlobalContext:
@@ -25,6 +26,7 @@ def __init__(
         current_party: str,
         sending_failure_handler: Callable[[Exception], None],
         exit_on_sending_failure=False,
+        continue_waiting_for_data_sending_on_error=False,
     ) -> None:
         self._job_name = job_name
         self._seq_count = 0
@@ -35,6 +37,10 @@ def __init__(
         self._cleanup_manager = CleanupManager(
             current_party, self.acquire_shutdown_flag
         )
+        self._last_received_error: FedRemoteError = None
+        self._continue_waiting_for_data_sending_on_error = (
+            continue_waiting_for_data_sending_on_error
+        )
 
     def next_seq_id(self) -> int:
         self._seq_count += 1
@@ -52,6 +58,15 @@ def get_sending_failure_handler(self) -> Callable[[], None]:
     def get_exit_on_sending_failure(self) -> bool:
         return self._exit_on_sending_failure
 
+    def get_last_received_error(self) -> FedRemoteError:
+        return self._last_received_error
+
+    def set_last_received_error(self, err):
+        self._last_received_error = err
+
+    def get_continue_waiting_for_data_sending_on_error(self) -> bool:
+        return self._continue_waiting_for_data_sending_on_error
+
     def acquire_shutdown_flag(self) -> bool:
         """
         Acquiring a lock and set the flag to make sure
@@ -78,12 +93,18 @@ def acquire_shutdown_flag(self) -> bool:
 def init_global_context(
     current_party: str,
     job_name: str,
+    exit_on_sending_failure: bool,
+    continue_waiting_for_data_sending_on_error: bool,
     sending_failure_handler: Callable[[Exception], None] = None,
 ) -> None:
     global _global_context
     if _global_context is None:
         _global_context = GlobalContext(
-            job_name, current_party, sending_failure_handler
+            job_name,
+            current_party,
+            exit_on_sending_failure=exit_on_sending_failure,
+            continue_waiting_for_data_sending_on_error=continue_waiting_for_data_sending_on_error,
+            sending_failure_handler=sending_failure_handler,
         )
 
 
@@ -92,8 +113,8 @@ def get_global_context():
     return _global_context
 
 
-def clear_global_context(graceful=True):
+def clear_global_context(wait_for_sending=False):
     global _global_context
     if _global_context is not None:
-        _global_context.get_cleanup_manager().stop(graceful=graceful)
+        _global_context.get_cleanup_manager().stop(wait_for_sending=wait_for_sending)
         _global_context = None
diff --git a/fed/_private/message_queue.py b/fed/_private/message_queue.py
index e4dabe9..3207365 100644
--- a/fed/_private/message_queue.py
+++ b/fed/_private/message_queue.py
@@ -26,7 +26,7 @@
 
 
 class MessageQueueManager:
-    def __init__(self, msg_handler, failure_handler=None, thread_name=''):
+    def __init__(self, msg_handler, failure_handler=None, thread_name=""):
         assert callable(msg_handler), "msg_handler must be a callable function"
         # `deque()` is thread safe on `popleft` and `append` operations.
         # See https://docs.python.org/3/library/collections.html#deque-objects
@@ -73,16 +73,13 @@ def _notify_to_exit(self, immediately=False):
         else:
             self.append(STOP_SYMBOL)
 
-    def stop(self, immediately=False):
+    def stop(self, wait_for_sending=True):
         """
         Stop the message queue.
 
         Args:
-            immediately (bool): A flag indicating whether to stop the queue
-                immediately or not. Default is True.
-                If True: insert the STOP_SYMBOL at the begin of the queue.
-            If False: insert the STOP_SYMBOL at the end of the queue, which means
-                stop the for loop until all messages in queue are all sent.
+            wait_for_sending (bool): Whether to join the sending thread so that
+                all queued messages get sent before stopping. Defaults to True.
         """
         if threading.current_thread() == self._thread:
             logger.error(
@@ -97,9 +94,12 @@
         # Therefore, currently, not support forcelly kill thread
         if self.is_started():
             logger.debug(f"Killing thread[{self._thread_name}].")
-            self._notify_to_exit(immediately=immediately)
-            self._thread.join()
-            logger.info(f"The message polling thread[{self._thread_name}] was exited.")
+            self._notify_to_exit(immediately=not wait_for_sending)
+            if wait_for_sending:
+                self._thread.join()
+            logger.info(
+                f"The message polling thread[{self._thread_name}] has exited."
+            )
 
     def is_started(self):
         return self._thread is not None and self._thread.is_alive()
diff --git a/fed/api.py b/fed/api.py
index fa9ff1a..b3c7909 100644
--- a/fed/api.py
+++ b/fed/api.py
@@ -16,12 +16,11 @@
 import inspect
 import logging
 import signal
+import sys
 from typing import Any, Callable, Dict, List, Union
 
 import cloudpickle
 import ray
-from ray.exceptions import RayError
-import sys
 
 import fed._private.compatible_utils as compatible_utils
 import fed.config as fed_config
@@ -70,7 +69,7 @@ def init(
     party: str = None,
     config: Dict = {},
     tls_config: Dict = None,
-    logging_level: str = 'info',
+    logging_level: str = "info",
     sender_proxy_cls: SenderProxy = None,
     receiver_proxy_cls: ReceiverProxy = None,
     receiver_sender_proxy_cls: SenderReceiverProxy = None,
@@ -125,6 +124,7 @@ def init(
                 "exit_on_sending_failure": True,
                 "expose_error_trace": True,
                 "use_global_proxy": True,
+                "continue_waiting_for_data_sending_on_error": False,
             },
             "barrier_on_initializing": True,
         }
@@ -182,16 +182,23 @@ def init(
         job_name = constants.RAYFED_DEFAULT_JOB_NAME
 
     fed_utils.validate_addresses(addresses)
+
+    cross_silo_comm_dict = config.get("cross_silo_comm", {})
+    cross_silo_comm_config = CrossSiloMessageConfig.from_dict(cross_silo_comm_dict)
+
     init_global_context(
         current_party=party,
         job_name=job_name,
+        exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure,
+        continue_waiting_for_data_sending_on_error=cross_silo_comm_config.continue_waiting_for_data_sending_on_error,
         sending_failure_handler=sending_failure_handler,
     )
+
     tls_config = {} if tls_config is None else tls_config
     if tls_config:
         assert (
-            'cert' in tls_config and 'key' in tls_config
-        ), 'Cert or key are not in tls_config.'
+            "cert" in tls_config and "key" in tls_config
+        ), "Cert or key are not in tls_config."
 
     # A Ray private accessing, should be replaced in public API.
     compatible_utils._init_internal_kv(job_name)
@@ -201,15 +208,15 @@ def init(
         constants.KEY_OF_CURRENT_PARTY_NAME: party,
         constants.KEY_OF_TLS_CONFIG: tls_config,
     }
+    compatible_utils.kv.put(
+        constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)
+    )
 
-    cross_silo_comm_dict = config.get("cross_silo_comm", {})
     job_config = {
         constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT: cross_silo_comm_dict,
     }
-    compatible_utils.kv.put(
-        constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)
-    )
     compatible_utils.kv.put(constants.KEY_OF_JOB_CONFIG, cloudpickle.dumps(job_config))
+
     # Set logger.
     # Note(NKcqx): This should be called after internal_kv has party value, i.e.
     # after `ray.init` and
@@ -222,8 +229,7 @@ def init(
         job_name=job_name,
     )
 
-    logger.info(f'Started rayfed with {cluster_config}')
-    cross_silo_comm_config = CrossSiloMessageConfig.from_dict(cross_silo_comm_dict)
+    logger.info(f"Started rayfed with {cluster_config}")
     signal.signal(signal.SIGINT, _signal_handler)
     get_global_context().get_cleanup_manager().start(
         exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure,
@@ -305,7 +311,7 @@ def _shutdown(intended=True):
     Args:
         intended: (Optional) Whether this is a intended shutdown. If not
-            a "failure handler" will be triggered and sys.exit will be called then.
+            a "failure handler" will be triggered and data sending will not be awaited.
     """
 
     if get_global_context() is None:
         return
@@ -313,34 +319,45 @@ def _shutdown(intended=True):
     if intended:
-        logger.info('Shutdowning rayfed intendedly...')
+        logger.info("Shutting down rayfed intentionally...")
     else:
-        logger.warn('Shutdowning rayfed unintendedly...')
+        logger.warn("Shutting down rayfed unintentionally...")
 
     global_context = get_global_context()
     last_sending_error = global_context.get_cleanup_manager().get_last_sending_error()
+    last_received_error = global_context.get_last_received_error()
     if last_sending_error is not None:
-        logging.error(f'Cross-silo sending error occured. {last_sending_error}')
+        logger.error(f"Cross-silo sending error occurred. {last_sending_error}")
+
+    wait_for_sending = True
+    if (
+        last_sending_error is not None or last_received_error is not None
+    ) and not global_context.get_continue_waiting_for_data_sending_on_error():
+        wait_for_sending = False
+    logger.info(f'{"Waiting" if wait_for_sending else "Not waiting"} for data sending.')
 
     if not intended:
         # Execute failure_handler fisrtly.
         failure_handler = global_context.get_sending_failure_handler()
         if failure_handler is not None:
-            logger.info('Executing failure handler...')
+            logger.info(f"Executing failure handler {failure_handler}...")
             failure_handler(last_sending_error)
 
+        exit_on_sending_failure = global_context.get_exit_on_sending_failure()
+
         # Clean context.
         compatible_utils._clear_internal_kv()
-        clear_global_context(graceful=intended)
-        logger.info('Shutdowned rayfed.')
+        clear_global_context(wait_for_sending=wait_for_sending)
+        logger.info("Shut down rayfed.")
 
-        # Exit with error.
-        logger.critical('Exit now due to the previous error.')
-        sys.exit(1)
+        if exit_on_sending_failure:
+            # Exit with error.
+            logger.critical("Exit now due to the previous error.")
+            sys.exit(1)
     else:
         # Clean context.
         compatible_utils._clear_internal_kv()
-        clear_global_context(graceful=intended)
-        logger.info('Shutdowned rayfed.')
+        clear_global_context(wait_for_sending=wait_for_sending)
+        logger.info("Shut down rayfed.")
 
 
 def _get_addresses(job_name: str = None):
@@ -586,6 +603,8 @@ def get(
             "Encounter RemoteError happend in other parties"
             f", error message: {e.cause}"
         )
+        if get_global_context() is not None:
+            get_global_context().set_last_received_error(e)
         raise e
diff --git a/fed/cleanup.py b/fed/cleanup.py
index 03aab89..d7d6f17 100644
--- a/fed/cleanup.py
+++ b/fed/cleanup.py
@@ -45,7 +45,7 @@ class CleanupManager:
     def __init__(self, current_party, acquire_shutdown_flag) -> None:
         self._sending_data_q = MessageQueueManager(
             lambda msg: self._process_data_sending_task_return(msg),
-            thread_name='DataSendingQueueThread',
+            thread_name="DataSendingQueueThread",
         )
 
         self._sending_error_q = MessageQueueManager(
@@ -64,32 +64,16 @@ def start(self, exit_on_sending_failure=False, expose_error_trace=False):
         self._expose_error_trace = expose_error_trace
 
         self._sending_data_q.start()
-        logger.debug('Start check sending thread.')
+        logger.debug("Start check sending thread.")
         self._sending_error_q.start()
-        logger.debug('Start check error sending thread.')
+        logger.debug("Start check error sending thread.")
 
-        def _main_thread_monitor():
-            main_thread = threading.main_thread()
-            main_thread.join()
-            logging.debug('Stoping sending queue ...')
-            self.stop(graceful=True)
-
-        self._monitor_thread = threading.Thread(target=_main_thread_monitor)
-        self._monitor_thread.start()
-        logger.info('Start check sending monitor thread.')
-
-    def stop(self, graceful=True):
+    def stop(self, wait_for_sending=False):
         # NOTE(NKcqx): MUST firstly stop the data queue, because it
         # may still throw errors during the termination which need to
         # be sent to the error queue.
-        if graceful:
-            self._sending_data_q.stop(immediately=False)
-            self._sending_error_q.stop(immediately=False)
-        else:
-            # Stop data queue immediately, but stop error queue not immediately always
-            # to sure that error can be sent to peers.
-            self._sending_data_q.stop(immediately=True)
-            self._sending_error_q.stop(immediately=False)
+        self._sending_data_q.stop(wait_for_sending=wait_for_sending)
+        self._sending_error_q.stop(wait_for_sending=wait_for_sending)
 
     def push_to_sending(
         self,
@@ -168,9 +152,9 @@ def _process_data_sending_task_return(self, message):
             res = ray.get(obj_ref)
         except Exception as e:
             logger.warn(
-                f'Failed to send {obj_ref} to {dest_party}, error: {e},'
-                f'upstream_seq_id: {upstream_seq_id}, '
-                f'downstream_seq_id: {downstream_seq_id}.'
+                f"Failed to send {obj_ref} to {dest_party}, error: {e}, "
+                f"upstream_seq_id: {upstream_seq_id}, "
+                f"downstream_seq_id: {downstream_seq_id}."
             )
             self._last_sending_error = e
             if isinstance(e, RayError):
diff --git a/fed/config.py b/fed/config.py
index a8799ee..404b5ad 100644
--- a/fed/config.py
+++ b/fed/config.py
@@ -101,6 +101,10 @@ class CrossSiloMessageConfig:
         exit_on_sending_failure: whether exit when failure on cross-silo sending.
             If True, a SIGINT will be signaled to self if failed to sending
             cross-silo data and exit then.
+        continue_waiting_for_data_sending_on_error:
+            Whether to continue waiting for data sending if an error occurs,
+            including data-sending errors and errors received from the peer.
+            If True, wait until all data has been sent.
         messages_max_size_in_bytes: The maximum length in bytes of
             cross-silo messages. If None, the default value of 500 MB
             is specified.
@@ -122,6 +126,7 @@ class CrossSiloMessageConfig:
     timeout_in_ms: int = 60000
     messages_max_size_in_bytes: int = None
     exit_on_sending_failure: Optional[bool] = False
+    continue_waiting_for_data_sending_on_error: Optional[bool] = False
    serializing_allowed_list: Optional[Dict[str, str]] = None
     send_resource_label: Optional[Dict[str, str]] = None
     recv_resource_label: Optional[Dict[str, str]] = None
diff --git a/fed/tests/test_cross_silo_error.py b/fed/tests/test_cross_silo_error.py
index 88d7f87..c544e44 100644
--- a/fed/tests/test_cross_silo_error.py
+++ b/fed/tests/test_cross_silo_error.py
@@ -14,13 +14,13 @@
 
 import multiprocessing
 import sys
-from unittest.mock import Mock
 
 import pytest
 import ray
 
 import fed
 import fed._private.compatible_utils as compatible_utils
+from fed._private.global_context import get_global_context
 from fed.exceptions import FedRemoteError
 
 
@@ -34,6 +34,11 @@ def error_func():
         raise MyError("Test normal task Error")
 
 
+@fed.remote
+def normal_func(a):
+    return a
+
+
 @fed.remote
 class My:
     def __init__(self) -> None:
@@ -227,5 +232,81 @@ def test_cross_silo_alice_send_error_and_shutdown_once():
     assert p_bob.exitcode == 0
 
 
+def run5(party: str):
+    compatible_utils.init_ray(address='local')
+    addresses = {
+        'alice': '127.0.0.1:11012',
+        'bob': '127.0.0.1:11011',
+    }
+
+    fed.init(
+        addresses=addresses,
+        party=party,
+        logging_level='debug',
+        config={
+            'cross_silo_comm': {
+                'timeout_ms': 20 * 1000,
+                'expose_error_trace': False,
+                'continue_waiting_for_data_sending_on_error': True,
+            },
+        },
+    )
+
+    assert get_global_context().get_continue_waiting_for_data_sending_on_error()
+
+    fed.shutdown()
+    ray.shutdown()
+
+
+def test_continue_waiting_for_data_sending_on_error():
+    p_alice = multiprocessing.Process(target=run5, args=('alice',))
+    p_alice.start()
+    p_alice.join()
+    assert p_alice.exitcode == 0
+
+
+def run6(party: str):
+    compatible_utils.init_ray(address='local')
+    addresses = {
+        'alice': '127.0.0.1:11012',
+        'bob': '127.0.0.1:11011',
+    }
+
+    fed.init(
+        addresses=addresses,
+        party=party,
+        logging_level='debug',
+        config={
+            'cross_silo_comm': {
+                'timeout_ms': 20 * 1000,
+                'expose_error_trace': False,
+                'exit_on_sending_failure': True,
+            },
+        },
+    )
+
+    try:
+        # Alice runs into an error and broadcasts it to bob, then exits.
+        a = error_func.party('alice').remote()
+        b = normal_func.party('bob').remote(a)
+
+        # Bob gets the error.
+        fed.get(b)
+    finally:
+        fed.shutdown()
+        ray.shutdown()
+
+
+def test_no_wait_for_data_sending_on_error():
+    p_alice = multiprocessing.Process(target=run6, args=('alice',))
+    p_bob = multiprocessing.Process(target=run6, args=('bob',))
+    p_alice.start()
+    p_bob.start()
+    p_alice.join()
+    p_bob.join()
+    assert p_alice.exitcode == 1
+    assert p_bob.exitcode == 1
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-sv", __file__]))
diff --git a/fed/tests/test_listening_address.py b/fed/tests/test_listening_address.py
index 5f9891c..d2746b7 100644
--- a/fed/tests/test_listening_address.py
+++ b/fed/tests/test_listening_address.py
@@ -29,13 +29,13 @@ def _run(party):
     # NOTE(NKcqx): Firstly try to bind IPv6 because the grpc server will do so.
     # Otherwise this UT will fail because socket bind $occupied_port
     # on IPv4 address while grpc server listened on the Ipv6 address.
-    try:
-        s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
-        # Pre-occuping the port using local address
-        s.bind(("::1", occupied_port))
-    except OSError:
-        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        s.bind(("127.0.0.1", occupied_port))
+    s_ipv6 = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+    s_ipv6.bind(("::1", occupied_port))
+    s_ipv4 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    s_ipv4.bind(("127.0.0.1", occupied_port))
+    import time
+
+    time.sleep(5)
 
     addresses = {'alice': f'127.0.0.1:{occupied_port}'}
 
@@ -46,10 +46,8 @@ def _run(party):
         party=party,
     )
 
-    import time
-
-    time.sleep(5)
-    s.close()
+    s_ipv6.close()
+    s_ipv4.close()
 
     fed.shutdown()
     ray.shutdown()
diff --git a/fed/tests/test_transport_proxy.py b/fed/tests/test_transport_proxy.py
index de37e86..534522d 100644
--- a/fed/tests/test_transport_proxy.py
+++ b/fed/tests/test_transport_proxy.py
@@ -31,7 +31,7 @@
 from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy, GrpcSenderProxy
 
 if compatible_utils._compare_version_strings(
-    fed_utils.get_package_version('protobuf'), '4.0.0'
+    fed_utils.get_package_version("protobuf"), "4.0.0"
 ):
     from fed.grpc.pb4 import fed_pb2 as fed_pb2
     from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc
@@ -45,10 +45,10 @@ def test_n_to_1_transport():
     sending data to the target receiver proxy, and there also have N receivers
     to `get_data` from receiver proxy at that time.
     """
-    compatible_utils.init_ray(address='local')
-    test_job_name = 'test_n_to_1_transport'
-    party = 'test_party'
-    global_context.init_global_context(party, test_job_name)
+    compatible_utils.init_ray(address="local")
+    test_job_name = "test_n_to_1_transport"
+    party = "test_party"
+    global_context.init_global_context(party, test_job_name, False, False)
     global_context.get_global_context().get_cleanup_manager().start()
     cluster_config = {
         constants.KEY_OF_CLUSTER_ADDRESSES: "",
@@ -63,18 +63,18 @@ def test_n_to_1_transport():
 
     NUM_DATA = 10
     SERVER_ADDRESS = "127.0.0.1:12344"
-    addresses = {'test_party': SERVER_ADDRESS}
+    addresses = {"test_party": SERVER_ADDRESS}
     _start_receiver_proxy(
         addresses,
         party,
-        logging_level='info',
+        logging_level="info",
         proxy_cls=GrpcReceiverProxy,
         proxy_config={},
     )
     _start_sender_proxy(
         addresses,
         party,
-        logging_level='info',
+        logging_level="info",
         proxy_cls=GrpcSenderProxy,
         proxy_config={},
     )
@@ -137,7 +137,7 @@ async def _test_run_grpc_server(
         ),
         server,
     )
-    server.add_insecure_port(f'[::]:{port}')
+    server.add_insecure_port(f"[::]:{port}")
     await server.start()
     await server.wait_for_termination()
 
@@ -158,7 +158,7 @@ def __init__(
 
     async def run_grpc_server(self):
         return await _test_run_grpc_server(
-            self._listen_addr[self._listen_addr.index(':') + 1 :],
+            self._listen_addr[self._listen_addr.index(":") + 1 :],
             None,
             None,
             self._party,
@@ -193,20 +193,20 @@ def _test_start_receiver_proxy(
 
 
 def test_send_grpc_with_meta():
-    compatible_utils.init_ray(address='local')
+    compatible_utils.init_ray(address="local")
     cluster_config = {
         constants.KEY_OF_CLUSTER_ADDRESSES: "",
         constants.KEY_OF_CURRENT_PARTY_NAME: "",
         constants.KEY_OF_TLS_CONFIG: "",
     }
     metadata = {"key": "value"}
-    config = {'http_header': metadata}
+    config = {"http_header": metadata}
     job_config = {
         constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT: config,
     }
-    test_job_name = 'test_send_grpc_with_meta'
-    party_name = 'test_party'
-    global_context.init_global_context(party_name, test_job_name)
+    test_job_name = "test_send_grpc_with_meta"
+    party_name = "test_party"
+    global_context.init_global_context(party_name, test_job_name, False, False)
     compatible_utils._init_internal_kv(test_job_name)
     compatible_utils.kv.put(
         constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config)
@@ -226,7 +226,7 @@ def test_send_grpc_with_meta():
     _start_sender_proxy(
         addresses,
         party_name,
-        logging_level='info',
+        logging_level="info",
         proxy_cls=GrpcSenderProxy,
         proxy_config=config,
     )
diff --git a/fed/tests/test_transport_proxy_tls.py b/fed/tests/test_transport_proxy_tls.py
index e4af7ce..dea2948 100644
--- a/fed/tests/test_transport_proxy_tls.py
+++ b/fed/tests/test_transport_proxy_tls.py
@@ -34,8 +34,8 @@ def test_n_to_1_transport():
     sending data to the target receiver proxy, and there also have N receivers
     to `get_data` from receiver proxy at that time.
     """
-    compatible_utils.init_ray(address='local')
-    test_job_name = 'test_n_to_1_transport'
+    compatible_utils.init_ray(address="local")
+    test_job_name = "test_n_to_1_transport"
     cert_dir = os.path.join(
         os.path.dirname(os.path.abspath(__file__)), "/tmp/rayfed/test-certs/"
     )
@@ -44,13 +44,13 @@ def test_n_to_1_transport():
         "cert": os.path.join(cert_dir, "server.crt"),
         "key": os.path.join(cert_dir, "server.key"),
     }
-    party = 'test_party'
+    party = "test_party"
     cluster_config = {
         constants.KEY_OF_CLUSTER_ADDRESSES: "",
         constants.KEY_OF_CURRENT_PARTY_NAME: "",
         constants.KEY_OF_TLS_CONFIG: tls_config,
     }
-    global_context.init_global_context(party, test_job_name)
+    global_context.init_global_context(party, test_job_name, False, False)
     global_context.get_global_context().get_cleanup_manager().start()
     compatible_utils._init_internal_kv(test_job_name)
     compatible_utils.kv.put(
@@ -59,11 +59,11 @@ def test_n_to_1_transport():
 
     NUM_DATA = 10
     SERVER_ADDRESS = "127.0.0.1:65422"
-    addresses = {'test_party': SERVER_ADDRESS}
+    addresses = {"test_party": SERVER_ADDRESS}
     _start_receiver_proxy(
         addresses,
         party,
-        logging_level='info',
+        logging_level="info",
         tls_config=tls_config,
         proxy_cls=GrpcReceiverProxy,
         proxy_config={},
     )
@@ -71,7 +71,7 @@ def test_n_to_1_transport():
     _start_sender_proxy(
         addresses,
         party,
-        logging_level='info',
+        logging_level="info",
         tls_config=tls_config,
         proxy_cls=GrpcSenderProxy,
         proxy_config={},
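
---

Reviewer note: below is a minimal usage sketch (not part of the diff) of the `continue_waiting_for_data_sending_on_error` option this change introduces. The config keys match those added to `fed.init` above; the party names and addresses are illustrative only, mirroring the values used in the new tests.

```python
import fed

# Hypothetical two-party setup.
fed.init(
    addresses={"alice": "127.0.0.1:11012", "bob": "127.0.0.1:11011"},
    party="alice",
    config={
        "cross_silo_comm": {
            # Existing behavior: on a sending failure, signal SIGINT to self
            # and exit after cleanup (see _signal_handler/_shutdown in fed/api.py).
            "exit_on_sending_failure": True,
            # New flag: at shutdown, keep draining the sending queue even after
            # a sending error or a FedRemoteError received from a peer, instead
            # of notifying the queue thread to exit without joining it.
            "continue_waiting_for_data_sending_on_error": True,
        },
    },
)
```

With the flag left at its default `False`, `_shutdown` passes `wait_for_sending=False` to `clear_global_context` whenever a sending or received error has been recorded, so queued cross-silo messages are dropped rather than drained before the process exits.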