Skip to content

Commit

Permalink
Proxy actor per job (ray-project#184)
Browse files Browse the repository at this point in the history
* proxy actor per job

Signed-off-by: NKcqx <[email protected]>

---------

Signed-off-by: NKcqx <[email protected]>
  • Loading branch information
NKcqx authored Dec 12, 2023
1 parent 6035939 commit 1baee4d
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 67 deletions.
6 changes: 6 additions & 0 deletions fed/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@
RAYFED_DEFAULT_JOB_NAME = "Anonymous_job"

RAYFED_JOB_KV_DATA_KEY_FMT = "RAYFED#{}#{}"

RAYFED_DEFAULT_SENDER_PROXY_ACTOR_NAME = "SenderProxyActor"

RAYFED_DEFAULT_RECEIVER_PROXY_ACTOR_NAME = "ReceiverProxyActor"

RAYFED_DEFAULT_SENDER_RECEIVER_PROXY_ACTOR_NAME = "SenderReceiverProxyActor"
17 changes: 10 additions & 7 deletions fed/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@
_start_receiver_proxy,
_start_sender_proxy,
_start_sender_receiver_proxy,
set_receiver_proxy_actor_name,
set_sender_proxy_actor_name,
set_proxy_actor_name,
)
from fed.proxy.base_proxy import SenderProxy, ReceiverProxy, SenderReceiverProxy
from fed.config import CrossSiloMessageConfig
Expand Down Expand Up @@ -116,6 +115,7 @@ def init(
"timeout_in_ms": 1000,
"exit_on_sending_failure": True,
"expose_error_trace": True,
"use_global_proxy": True,
},
"barrier_on_initializing": True,
}
Expand Down Expand Up @@ -170,7 +170,6 @@ def init(
'cert' in tls_config and 'key' in tls_config
), 'Cert or key are not in tls_config.'

cross_silo_comm_dict = config.get("cross_silo_comm", {})
# A Ray private accessing, should be replaced in public API.
compatible_utils._init_internal_kv(job_name)

Expand All @@ -180,6 +179,7 @@ def init(
constants.KEY_OF_TLS_CONFIG: tls_config,
}

cross_silo_comm_dict = config.get("cross_silo_comm", {})
job_config = {
constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT: cross_silo_comm_dict,
}
Expand All @@ -206,10 +206,10 @@ def init(
exit_on_sending_failure=cross_silo_comm_config.exit_on_sending_failure,
expose_error_trace=cross_silo_comm_config.expose_error_trace
)

if receiver_sender_proxy_cls is not None:
proxy_actor_name = 'sender_recevier_actor'
set_sender_proxy_actor_name(proxy_actor_name)
set_receiver_proxy_actor_name(proxy_actor_name)
set_proxy_actor_name(
job_name, cross_silo_comm_dict.get("use_global_proxy", True), True)
_start_sender_receiver_proxy(
addresses=addresses,
party=party,
Expand All @@ -230,6 +230,8 @@ def init(
from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy

receiver_proxy_cls = GrpcReceiverProxy
set_proxy_actor_name(
job_name, cross_silo_comm_dict.get("use_global_proxy", True))
_start_receiver_proxy(
addresses=addresses,
party=party,
Expand All @@ -242,12 +244,13 @@ def init(

if sender_proxy_cls is None:
logger.debug(
"No sender proxy class specified, use `GrpcRecvProxy` by "
"No sender proxy class specified, use `GrpcSenderProxy` by "
"default."
)
from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy

sender_proxy_cls = GrpcSenderProxy

_start_sender_proxy(
addresses=addresses,
party=party,
Expand Down
8 changes: 4 additions & 4 deletions fed/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ class CleanupManager:
The main logic path is:
A. If `fed.shutdown()` is invoked in the main thread and every thing works well,
the `stop()` will be invoked as well and the checking thread will be
notifiled to exit gracefully.
notified to exit gracefully.
B. If the main thread are broken before sending the notification flag to the
sending thread, the monitor thread will detect that and it joins until the main
thread exited, then notifys the checking thread.
B. If the main thread are broken before sending the stop flag to the sending
thread, the monitor thread will detect that and then notifys the checking
thread.
"""

def __init__(self, current_party, acquire_shutdown_flag) -> None:
Expand Down
3 changes: 3 additions & 0 deletions fed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class CrossSiloMessageConfig:
This won't override basic tcp headers, such as `user-agent`, but concat
them together.
max_concurrency: the max_concurrency of the sender/receiver proxy actor.
use_global_proxy: Whether using the global proxy actor or create new proxy
actor for current job.
"""

proxy_max_restarts: int = None
Expand All @@ -114,6 +116,7 @@ class CrossSiloMessageConfig:
http_header: Optional[Dict[str, str]] = None
max_concurrency: Optional[int] = None
expose_error_trace: Optional[bool] = False
use_global_proxy: Optional[bool] = True

def __json__(self):
return json.dumps(self.__dict__)
Expand Down
135 changes: 81 additions & 54 deletions fed/proxy/barriers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,39 @@ def set_receiver_proxy_actor_name(name: str):
_RECEIVER_PROXY_ACTOR_NAME = name


def set_proxy_actor_name(job_name: str,
use_global_proxy: bool,
sender_recvr_proxy: bool = False):
"""
Generate the name of the proxy actor.
Args:
job_name: The name of the job, used for actor name's postfix
use_global_proxy: Whether
to use a single proxy actor or not. If True, the name of the proxy
actor will be the default global name, otherwise the name will be
added with a postfix.
sender_recvr_proxy: Whether to use the sender-receiver proxy actor or
not. If True, since there's only one proxy actor, make two actor name
the same.
"""
sender_actor_name = (
constants.RAYFED_DEFAULT_SENDER_PROXY_ACTOR_NAME
if not sender_recvr_proxy
else constants.RAYFED_DEFAULT_SENDER_RECEIVER_PROXY_ACTOR_NAME
)
receiver_actor_name = (
constants.RAYFED_DEFAULT_RECEIVER_PROXY_ACTOR_NAME
if not sender_recvr_proxy
else constants.RAYFED_DEFAULT_SENDER_RECEIVER_PROXY_ACTOR_NAME
)
if not use_global_proxy:
sender_actor_name = f"{sender_actor_name}_{job_name}"
receiver_actor_name = f"{receiver_actor_name}_{job_name}"
set_sender_proxy_actor_name(sender_actor_name)
set_receiver_proxy_actor_name(receiver_actor_name)


def key_exists_in_two_dim_dict(the_dict, key_a, key_b) -> bool:
key_a, key_b = str(key_a), str(key_b)
if key_a not in the_dict:
Expand Down Expand Up @@ -217,22 +250,20 @@ def _start_receiver_proxy(
ready_timeout_second: int = 60,
):
actor_options = copy.deepcopy(_DEFAULT_RECEIVER_PROXY_OPTIONS)
if proxy_config:
proxy_config = fed_config.CrossSiloMessageConfig.from_dict(proxy_config)
if proxy_config.recv_resource_label is not None:
actor_options.update({"resources": proxy_config.recv_resource_label})
if proxy_config.max_concurrency:
actor_options.update({"max_concurrency": proxy_config.max_concurrency})
proxy_config = fed_config.CrossSiloMessageConfig.from_dict(proxy_config)
if proxy_config.recv_resource_label is not None:
actor_options.update({"resources": proxy_config.recv_resource_label})
if proxy_config.max_concurrency:
actor_options.update({"max_concurrency": proxy_config.max_concurrency})
actor_options.update({"name": receiver_proxy_actor_name()})

logger.debug(f"Starting ReceiverProxyActor with options: {actor_options}")
job_name = get_global_context().get_job_name()

global _RECEIVER_PROXY_ACTOR_NAME
receiver_proxy_actor = ReceiverProxyActor.options(
name=_RECEIVER_PROXY_ACTOR_NAME, **actor_options
).remote(
receiver_proxy_actor = ReceiverProxyActor.options(**actor_options).remote(
listening_address=addresses[party],
party=party,
job_name=get_global_context().get_job_name(),
job_name=job_name,
tls_config=tls_config,
logging_level=logging_level,
proxy_cls=proxy_cls,
Expand Down Expand Up @@ -260,30 +291,28 @@ def _start_sender_proxy(
proxy_config: Dict = None,
ready_timeout_second: int = 60,
):
if proxy_config:
proxy_config = fed_config.GrpcCrossSiloMessageConfig.from_dict(proxy_config)
actor_options = copy.deepcopy(_DEFAULT_SENDER_PROXY_OPTIONS)
if proxy_config:
if proxy_config.proxy_max_restarts:
actor_options.update(
{
"max_task_retries": proxy_config.proxy_max_restarts,
"max_restarts": 1,
}
)
if proxy_config.send_resource_label:
actor_options.update({"resources": proxy_config.send_resource_label})
if proxy_config.max_concurrency:
actor_options.update({"max_concurrency": proxy_config.max_concurrency})
proxy_config = fed_config.GrpcCrossSiloMessageConfig.from_dict(proxy_config)
if proxy_config.proxy_max_restarts:
actor_options.update(
{
"max_task_retries": proxy_config.proxy_max_restarts,
"max_restarts": 1,
}
)
if proxy_config.send_resource_label:
actor_options.update({"resources": proxy_config.send_resource_label})
if proxy_config.max_concurrency:
actor_options.update({"max_concurrency": proxy_config.max_concurrency})

job_name = get_global_context().get_job_name()
actor_options.update({"name": sender_proxy_actor_name()})

logger.debug(f"Starting SenderProxyActor with options: {actor_options}")
global _SENDER_PROXY_ACTOR
global _SENDER_PROXY_ACTOR_NAME
_SENDER_PROXY_ACTOR = SenderProxyActor.options(
name=_SENDER_PROXY_ACTOR_NAME, **actor_options
)

job_name = get_global_context().get_job_name()
_SENDER_PROXY_ACTOR = SenderProxyActor.options(**actor_options)

_SENDER_PROXY_ACTOR = _SENDER_PROXY_ACTOR.remote(
addresses=addresses,
party=party,
Expand Down Expand Up @@ -389,32 +418,32 @@ def _start_sender_receiver_proxy(
):
global _DEFAULT_SENDER_RECEIVER_PROXY_OPTIONS
actor_options = copy.deepcopy(_DEFAULT_SENDER_RECEIVER_PROXY_OPTIONS)
if proxy_config:
proxy_config = fed_config.CrossSiloMessageConfig.from_dict(proxy_config)
if proxy_config.proxy_max_restarts:
actor_options.update(
{
"max_task_retries": proxy_config.proxy_max_restarts,
"max_restarts": 1,
}
)
if proxy_config.max_concurrency:
actor_options.update({"max_concurrency": proxy_config.max_concurrency})
proxy_config = fed_config.CrossSiloMessageConfig.from_dict(proxy_config)
if proxy_config.proxy_max_restarts:
actor_options.update(
{
"max_task_retries": proxy_config.proxy_max_restarts,
"max_restarts": 1,
}
)
if proxy_config.max_concurrency:
actor_options.update({"max_concurrency": proxy_config.max_concurrency})

# NOTE(NKcqx): sender & receiver have the same name
actor_options.update({"name": receiver_proxy_actor_name()})
logger.debug(f"Starting ReceiverProxyActor with options: {actor_options}")

job_name = get_global_context().get_job_name()
global _SENDER_RECEIVER_PROXY_ACTOR
global _RECEIVER_PROXY_ACTOR_NAME

_SENDER_RECEIVER_PROXY_ACTOR = SenderReceiverProxyActor.options(
name=_RECEIVER_PROXY_ACTOR_NAME, **actor_options
).remote(
addresses=addresses,
party=party,
job_name=job_name,
tls_config=tls_config,
logging_level=logging_level,
proxy_cls=proxy_cls,
**actor_options).remote(
addresses=addresses,
party=party,
job_name=job_name,
tls_config=tls_config,
logging_level=logging_level,
proxy_cls=proxy_cls,
)
_SENDER_RECEIVER_PROXY_ACTOR.start.remote()
server_state = ray.get(
Expand All @@ -436,8 +465,7 @@ def send(
is_error: Whether the `data` is an error object or not. Default is False.
If True, the data will be sent to the error message queue.
"""
global _SENDER_PROXY_ACTOR_NAME
sender_proxy = ray.get_actor(_SENDER_PROXY_ACTOR_NAME)
sender_proxy = ray.get_actor(sender_proxy_actor_name())
res = sender_proxy.send.remote(
dest_party=dest_party,
data=data,
Expand All @@ -451,8 +479,7 @@ def send(

def recv(party: str, src_party: str, upstream_seq_id, curr_seq_id):
assert party, 'Party can not be None.'
global _RECEIVER_PROXY_ACTOR_NAME
receiver_proxy = ray.get_actor(_RECEIVER_PROXY_ACTOR_NAME)
receiver_proxy = ray.get_actor(receiver_proxy_actor_name())
return receiver_proxy.get_data.remote(src_party, upstream_seq_id, curr_seq_id)


Expand Down
2 changes: 1 addition & 1 deletion fed/proxy/grpc/grpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ async def _run_grpc_server(
port, event, all_data, party, lock, job_name,
server_ready_future, tls_config=None, grpc_options=None
):
logger.info(f"ReceiveProxy binding port {port}, options: {grpc_options}...")
logger.info(f"ReceiverProxy binding port {port}, options: {grpc_options}...")
server = grpc.aio.server(options=grpc_options)
fed_pb2_grpc.add_GrpcServiceServicer_to_server(
SendDataService(event, all_data, party, lock, job_name), server
Expand Down
Loading

0 comments on commit 1baee4d

Please sign in to comment.