Skip to content

Commit

Permalink
fix: correct get exeception code and some other minor fixes. (ray-pro…
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouaihui authored Dec 13, 2023
1 parent 1baee4d commit 7b73381
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 39 deletions.
57 changes: 31 additions & 26 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,37 @@
import inspect
import logging
import signal
from typing import Any, Dict, List, Union, Callable
from typing import Any, Callable, Dict, List, Union

import cloudpickle
import ray
from ray.exceptions import RayError

import fed._private.compatible_utils as compatible_utils
import fed.config as fed_config
import fed.utils as fed_utils
from fed._private import constants
from fed._private.fed_actor import FedActorHandle
from fed._private.fed_call_holder import FedCallHolder
from fed.exceptions import FedRemoteError
from fed._private.global_context import (
init_global_context,
clear_global_context,
get_global_context,
clear_global_context
init_global_context,
)
from fed.config import CrossSiloMessageConfig
from fed.exceptions import FedRemoteError
from fed.fed_object import FedObject
from fed.proxy.barriers import (
ping_others,
recv,
send,
_start_receiver_proxy,
_start_sender_proxy,
_start_sender_receiver_proxy,
ping_others,
recv,
send,
set_proxy_actor_name,
)
from fed.proxy.base_proxy import SenderProxy, ReceiverProxy, SenderReceiverProxy
from fed.config import CrossSiloMessageConfig
from fed.fed_object import FedObject
from fed.proxy.base_proxy import ReceiverProxy, SenderProxy, SenderReceiverProxy
from fed.utils import is_ray_object_refs, setup_logger
from ray.exceptions import RayError

logger = logging.getLogger(__name__)

Expand All @@ -59,7 +59,8 @@ def _signal_handler(signum, frame):
logger.warning(
"Stop signal received (e.g. via SIGINT/Ctrl+C), "
"try to shutdown fed. Press CTRL+C "
"(or send SIGINT/SIGKILL/SIGTERM) to skip.")
"(or send SIGINT/SIGKILL/SIGTERM) to skip."
)
_shutdown(intended=False)


Expand Down Expand Up @@ -162,8 +163,9 @@ def init(
assert party in addresses, f"Party {party} is not in the addresses {addresses}."

fed_utils.validate_addresses(addresses)
init_global_context(current_party=party, job_name=job_name,
failure_handler=failure_handler)
init_global_context(
current_party=party, job_name=job_name, failure_handler=failure_handler
)
tls_config = {} if tls_config is None else tls_config
if tls_config:
assert (
Expand Down Expand Up @@ -196,20 +198,21 @@ def init(
logging_format=constants.RAYFED_LOG_FMT,
date_format=constants.RAYFED_DATE_FMT,
party_val=_get_party(job_name),
job_name=job_name
job_name=job_name,
)

logger.info(f'Started rayfed with {cluster_config}')
cross_silo_comm_config = CrossSiloMessageConfig.from_dict(cross_silo_comm_dict)
signal.signal(signal.SIGINT, _signal_handler)
get_global_context().get_cleanup_manager().start(
exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure,
expose_error_trace=cross_silo_comm_config.expose_error_trace
expose_error_trace=cross_silo_comm_config.expose_error_trace,
)

if receiver_sender_proxy_cls is not None:
set_proxy_actor_name(
job_name, cross_silo_comm_dict.get("use_global_proxy", True), True)
job_name, cross_silo_comm_dict.get("use_global_proxy", True), True
)
_start_sender_receiver_proxy(
addresses=addresses,
party=party,
Expand All @@ -231,7 +234,8 @@ def init(

receiver_proxy_cls = GrpcReceiverProxy
set_proxy_actor_name(
job_name, cross_silo_comm_dict.get("use_global_proxy", True))
job_name, cross_silo_comm_dict.get("use_global_proxy", True)
)
_start_receiver_proxy(
addresses=addresses,
party=party,
Expand All @@ -244,8 +248,7 @@ def init(

if sender_proxy_cls is None:
logger.debug(
"No sender proxy class specified, use `GrpcSenderProxy` by "
"default."
"No sender proxy class specified, use `GrpcSenderProxy` by default."
)
from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy

Expand Down Expand Up @@ -281,12 +284,12 @@ def _shutdown(intended=True):
intended: (Optional) Whether this is a intended exit. If not
a "failure handler" will be triggered.
"""
if (get_global_context() is not None):
if get_global_context() is not None:
# Job has inited, can be shutdown
failure_handler = get_global_context().get_failure_handler()
compatible_utils._clear_internal_kv()
clear_global_context()
if (not intended and failure_handler is not None):
if not intended and failure_handler is not None:
failure_handler()
logger.info('Shutdowned rayfed.')

Expand Down Expand Up @@ -472,10 +475,12 @@ def get(
values = values[0]
return values
except RayError as e:
if isinstance(e.cause, FedRemoteError):
logger.warning("Encounter RemoteError happend in other parties"
f", prepare to exit, error message: {e.cause}")
if (get_global_context().acquire_shutdown_flag()):
if isinstance(e, FedRemoteError):
logger.warning(
"Encounter RemoteError happend in other parties"
f", prepare to exit, error message: {e.cause}"
)
if get_global_context().acquire_shutdown_flag():
_shutdown(intended=False)
raise e

Expand Down
25 changes: 14 additions & 11 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
are mutable.
"""

import fed._private.compatible_utils as compatible_utils
import fed._private.constants as fed_constants
import cloudpickle
import json

from typing import Dict, List, Optional
from dataclasses import dataclass, fields
from typing import Dict, List, Optional

import cloudpickle

import fed._private.compatible_utils as compatible_utils
import fed._private.constants as fed_constants


class ClusterConfig:
Expand Down Expand Up @@ -48,24 +49,26 @@ def cross_silo_comm_config_dict(self) -> Dict:
_job_config = None


def get_cluster_config(job_name: str = None):
def get_cluster_config(job_name: str = None) -> ClusterConfig:
"""This function is not thread safe to use."""
global _cluster_config
if _cluster_config is None:
assert job_name is not None, \
"Initializing internal kv need to provide job_name."
assert (
job_name is not None
), "Initializing internal kv need to provide job_name."
compatible_utils._init_internal_kv(job_name)
raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_CLUSTER_CONFIG)
_cluster_config = ClusterConfig(raw_dict)
return _cluster_config


def get_job_config(job_name: str = None):
def get_job_config(job_name: str = None) -> JobConfig:
"""This config still acts like cluster config for now"""
global _job_config
if _job_config is None:
assert job_name is not None, \
"Initializing internal kv need to provide job_name."
assert (
job_name is not None
), "Initializing internal kv need to provide job_name."
compatible_utils._init_internal_kv(job_name)
raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_JOB_CONFIG)
_job_config = JobConfig(raw_dict)
Expand Down
4 changes: 2 additions & 2 deletions fed/proxy/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def __init__(
self._addresses = addresses
self._party = party
self._tls_config = tls_config
job_config = fed_config.get_job_config()
job_config = fed_config.get_job_config(job_name=job_name)
cross_silo_comm_config = job_config.cross_silo_comm_config_dict
self._proxy_instance = proxy_cls(
addresses, party, tls_config, cross_silo_comm_config
Expand Down Expand Up @@ -397,7 +397,7 @@ def send(
except Exception as e:
logger.error(f'Failed to {send_log_msg}, error: {e}')
return False
logger.debug(f"Succeeded to send {send_log_msg}. Response is {response}")
logger.debug(f"Succeeded to {send_log_msg}. Response is {response}")
return True # True indicates it's sent successfully.

def _get_stats(self):
Expand Down

0 comments on commit 7b73381

Please sign in to comment.