diff --git a/fed/api.py b/fed/api.py index 692bec6..66d4aba 100644 --- a/fed/api.py +++ b/fed/api.py @@ -16,10 +16,11 @@ import inspect import logging import signal -from typing import Any, Dict, List, Union, Callable +from typing import Any, Callable, Dict, List, Union import cloudpickle import ray +from ray.exceptions import RayError import fed._private.compatible_utils as compatible_utils import fed.config as fed_config @@ -27,26 +28,25 @@ from fed._private import constants from fed._private.fed_actor import FedActorHandle from fed._private.fed_call_holder import FedCallHolder -from fed.exceptions import FedRemoteError from fed._private.global_context import ( - init_global_context, + clear_global_context, get_global_context, - clear_global_context + init_global_context, ) +from fed.config import CrossSiloMessageConfig +from fed.exceptions import FedRemoteError +from fed.fed_object import FedObject from fed.proxy.barriers import ( - ping_others, - recv, - send, _start_receiver_proxy, _start_sender_proxy, _start_sender_receiver_proxy, + ping_others, + recv, + send, set_proxy_actor_name, ) -from fed.proxy.base_proxy import SenderProxy, ReceiverProxy, SenderReceiverProxy -from fed.config import CrossSiloMessageConfig -from fed.fed_object import FedObject +from fed.proxy.base_proxy import ReceiverProxy, SenderProxy, SenderReceiverProxy from fed.utils import is_ray_object_refs, setup_logger -from ray.exceptions import RayError logger = logging.getLogger(__name__) @@ -59,7 +59,8 @@ def _signal_handler(signum, frame): logger.warning( "Stop signal received (e.g. via SIGINT/Ctrl+C), " "try to shutdown fed. Press CTRL+C " - "(or send SIGINT/SIGKILL/SIGTERM) to skip.") + "(or send SIGINT/SIGKILL/SIGTERM) to skip." + ) _shutdown(intended=False) @@ -162,8 +163,9 @@ def init( assert party in addresses, f"Party {party} is not in the addresses {addresses}." 
fed_utils.validate_addresses(addresses) - init_global_context(current_party=party, job_name=job_name, - failure_handler=failure_handler) + init_global_context( + current_party=party, job_name=job_name, failure_handler=failure_handler + ) tls_config = {} if tls_config is None else tls_config if tls_config: assert ( @@ -196,7 +198,7 @@ def init( logging_format=constants.RAYFED_LOG_FMT, date_format=constants.RAYFED_DATE_FMT, party_val=_get_party(job_name), - job_name=job_name + job_name=job_name, ) logger.info(f'Started rayfed with {cluster_config}') @@ -204,12 +206,13 @@ def init( signal.signal(signal.SIGINT, _signal_handler) get_global_context().get_cleanup_manager().start( exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure, - expose_error_trace=cross_silo_comm_config.expose_error_trace + expose_error_trace=cross_silo_comm_config.expose_error_trace, ) if receiver_sender_proxy_cls is not None: set_proxy_actor_name( - job_name, cross_silo_comm_dict.get("use_global_proxy", True), True) + job_name, cross_silo_comm_dict.get("use_global_proxy", True), True + ) _start_sender_receiver_proxy( addresses=addresses, party=party, @@ -231,7 +234,8 @@ def init( receiver_proxy_cls = GrpcReceiverProxy set_proxy_actor_name( - job_name, cross_silo_comm_dict.get("use_global_proxy", True)) + job_name, cross_silo_comm_dict.get("use_global_proxy", True) + ) _start_receiver_proxy( addresses=addresses, party=party, @@ -244,8 +248,7 @@ def init( if sender_proxy_cls is None: logger.debug( - "No sender proxy class specified, use `GrpcSenderProxy` by " - "default." + "No sender proxy class specified, use `GrpcSenderProxy` by default." ) from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy @@ -281,12 +284,12 @@ def _shutdown(intended=True): intended: (Optional) Whether this is a intended exit. If not a "failure handler" will be triggered. 
""" - if (get_global_context() is not None): + if get_global_context() is not None: # Job has inited, can be shutdown failure_handler = get_global_context().get_failure_handler() compatible_utils._clear_internal_kv() clear_global_context() - if (not intended and failure_handler is not None): + if not intended and failure_handler is not None: failure_handler() logger.info('Shutdowned rayfed.') @@ -472,10 +475,12 @@ def get( values = values[0] return values except RayError as e: - if isinstance(e.cause, FedRemoteError): - logger.warning("Encounter RemoteError happend in other parties" - f", prepare to exit, error message: {e.cause}") - if (get_global_context().acquire_shutdown_flag()): + if isinstance(e.cause, FedRemoteError): + logger.warning( + "Encounter RemoteError happend in other parties" + f", prepare to exit, error message: {e.cause}" + ) + if get_global_context().acquire_shutdown_flag(): _shutdown(intended=False) raise e diff --git a/fed/config.py b/fed/config.py index c447efb..230d979 100644 --- a/fed/config.py +++ b/fed/config.py @@ -2,13 +2,14 @@ are mutable. """ -import fed._private.compatible_utils as compatible_utils -import fed._private.constants as fed_constants -import cloudpickle import json - -from typing import Dict, List, Optional from dataclasses import dataclass, fields +from typing import Dict, List, Optional + +import cloudpickle + +import fed._private.compatible_utils as compatible_utils +import fed._private.constants as fed_constants class ClusterConfig: @@ -48,24 +49,26 @@ def cross_silo_comm_config_dict(self) -> Dict: _job_config = None -def get_cluster_config(job_name: str = None): +def get_cluster_config(job_name: str = None) -> ClusterConfig: """This function is not thread safe to use.""" global _cluster_config if _cluster_config is None: - assert job_name is not None, \ - "Initializing internal kv need to provide job_name." + assert ( + job_name is not None + ), "Initializing internal kv need to provide job_name." 
compatible_utils._init_internal_kv(job_name) raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_CLUSTER_CONFIG) _cluster_config = ClusterConfig(raw_dict) return _cluster_config -def get_job_config(job_name: str = None): +def get_job_config(job_name: str = None) -> JobConfig: """This config still acts like cluster config for now""" global _job_config if _job_config is None: - assert job_name is not None, \ - "Initializing internal kv need to provide job_name." + assert ( + job_name is not None + ), "Initializing internal kv need to provide job_name." compatible_utils._init_internal_kv(job_name) raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_JOB_CONFIG) _job_config = JobConfig(raw_dict) diff --git a/fed/proxy/barriers.py b/fed/proxy/barriers.py index aef0812..aa207ae 100644 --- a/fed/proxy/barriers.py +++ b/fed/proxy/barriers.py @@ -354,7 +354,7 @@ def __init__( self._addresses = addresses self._party = party self._tls_config = tls_config - job_config = fed_config.get_job_config() + job_config = fed_config.get_job_config(job_name=job_name) cross_silo_comm_config = job_config.cross_silo_comm_config_dict self._proxy_instance = proxy_cls( addresses, party, tls_config, cross_silo_comm_config @@ -397,7 +397,7 @@ def send( except Exception as e: logger.error(f'Failed to {send_log_msg}, error: {e}') return False - logger.debug(f"Succeeded to send {send_log_msg}. Response is {response}") + logger.debug(f"Succeeded to {send_log_msg}. Response is {response}") return True # True indicates it's sent successfully. def _get_stats(self):