diff --git a/fed/api.py b/fed/api.py index d89cc71..6b020eb 100644 --- a/fed/api.py +++ b/fed/api.py @@ -92,6 +92,33 @@ def init( 'carol': '127.0.0.1:10003', } party: optional; self party. + config: optional; a dict describing general job configurations. Currently the + supported configurations are [`cross_silo_comm`, `barrier_on_initializing`]. + * `cross_silo_comm`: optional; a dict describing the cross-silo common + configs; for the supported configs, refer to + `fed.config.CrossSiloMessageConfig` and + `fed.config.GrpcCrossSiloMessageConfig`. Note that the + `cross_silo_comm.messages_max_size_in_bytes` will be overridden + if `cross_silo_comm.grpc_channel_options` is provided and contains + `grpc.max_send_message_length` or `grpc.max_receive_message_length`. + * `barrier_on_initializing`: optional; a bool value indicating whether to + wait for all parties to be ready before starting the job. If set + to True, the job will be started after all parties are ready; + otherwise, the job will be started immediately after the current + party is ready. + + Example: + + .. code:: python + { + "cross_silo_comm": { + "messages_max_size_in_bytes": 500*1024, + "timeout_in_ms": 1000, + "exit_on_sending_failure": True, + "expose_error_trace": True, + }, + "barrier_on_initializing": True, + } tls_config: optional; a dict describes the tls config. E.g. For alice, diff --git a/fed/config.py b/fed/config.py index a4d1a23..386984c 100644 --- a/fed/config.py +++ b/fed/config.py @@ -95,8 +95,7 @@ class CrossSiloMessageConfig: cross-silo sending. If True, a SIGTERM will be signaled to self if failed to sending cross-silo data. messages_max_size_in_bytes: The maximum length in bytes of - cross-silo messages. - If None, the default value of 500 MB is specified. + cross-silo messages. If None, the default value of 500 MB is specified. timeout_in_ms: The timeout in mili-seconds of a cross-silo RPC call. It's 60000 by default. http_header: The HTTP header, e.g. 
metadata in grpc, sent with the RPC request. diff --git a/fed/proxy/grpc/grpc_proxy.py b/fed/proxy/grpc/grpc_proxy.py index f1dfc71..85b7256 100644 --- a/fed/proxy/grpc/grpc_proxy.py +++ b/fed/proxy/grpc/grpc_proxy.py @@ -60,8 +60,21 @@ def parse_grpc_options(proxy_config: CrossSiloMessageConfig): dict: A dictionary containing the gRPC channel options. """ grpc_channel_options = {} - if proxy_config is not None and isinstance( - proxy_config, GrpcCrossSiloMessageConfig): + if proxy_config is not None: + # NOTE(NKcqx): `messages_max_size_in_bytes` is a common cross-silo + # config that should be extracted and filled into proper grpc's + # channel options. + # However, `GrpcCrossSiloMessageConfig` provides a more flexible way + # to configure grpc channel options, i.e. the `grpc_channel_options` + # field, which may override the `messages_max_size_in_bytes` field. + if (isinstance(proxy_config, CrossSiloMessageConfig)): + if (proxy_config.messages_max_size_in_bytes is not None): + grpc_channel_options.update({ + 'grpc.max_send_message_length': + proxy_config.messages_max_size_in_bytes, + 'grpc.max_receive_message_length': + proxy_config.messages_max_size_in_bytes, + }) if isinstance(proxy_config, GrpcCrossSiloMessageConfig): if proxy_config.grpc_channel_options is not None: grpc_channel_options.update(proxy_config.grpc_channel_options) diff --git a/fed/tests/test_grpc_options_on_proxies.py b/fed/tests/test_grpc_options_on_proxies.py index 6a09c75..cb14e92 100644 --- a/fed/tests/test_grpc_options_on_proxies.py +++ b/fed/tests/test_grpc_options_on_proxies.py @@ -61,7 +61,7 @@ def _assert_on_proxy(proxy_actor): ray.shutdown() -def test_grpc_max_size(): +def test_grpc_max_size_by_channel_options(): p_alice = multiprocessing.Process(target=run, args=('alice',)) p_bob = multiprocessing.Process(target=run, args=('bob',)) p_alice.start() @@ -71,6 +71,101 @@ def test_grpc_max_size(): assert p_alice.exitcode == 0 and p_bob.exitcode == 0 +def run2(party): + 
compatible_utils.init_ray(address='local') + addresses = { + 'alice': '127.0.0.1:11019', + 'bob': '127.0.0.1:11018', + } + fed.init( + addresses=addresses, + party=party, + config={ + "cross_silo_comm": { + "messages_max_size_in_bytes": 100, + }, + }, + ) + + def _assert_on_proxy(proxy_actor): + config = ray.get(proxy_actor._get_proxy_config.remote()) + options = config['grpc_options'] + assert ("grpc.max_send_message_length", 100) in options + assert ("grpc.max_receive_message_length", 100) in options + assert ('grpc.so_reuseport', 0) in options + + sender_proxy = ray.get_actor(sender_proxy_actor_name()) + receiver_proxy = ray.get_actor(receiver_proxy_actor_name()) + _assert_on_proxy(sender_proxy) + _assert_on_proxy(receiver_proxy) + + a = dummpy.party('alice').remote() + b = dummpy.party('bob').remote() + fed.get([a, b]) + + fed.shutdown() + ray.shutdown() + + +def test_grpc_max_size_by_common_config(): + p_alice = multiprocessing.Process(target=run2, args=('alice',)) + p_bob = multiprocessing.Process(target=run2, args=('bob',)) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + assert p_alice.exitcode == 0 and p_bob.exitcode == 0 + + +def run3(party): + compatible_utils.init_ray(address='local') + addresses = { + 'alice': '127.0.0.1:11019', + 'bob': '127.0.0.1:11018', + } + fed.init( + addresses=addresses, + party=party, + config={ + "cross_silo_comm": { + "messages_max_size_in_bytes": 100, + "grpc_channel_options": [ + ('grpc.max_send_message_length', 200), + ], + }, + }, + ) + + def _assert_on_proxy(proxy_actor): + config = ray.get(proxy_actor._get_proxy_config.remote()) + options = config['grpc_options'] + assert ("grpc.max_send_message_length", 200) in options + assert ("grpc.max_receive_message_length", 100) in options + assert ('grpc.so_reuseport', 0) in options + + sender_proxy = ray.get_actor(sender_proxy_actor_name()) + receiver_proxy = ray.get_actor(receiver_proxy_actor_name()) + _assert_on_proxy(sender_proxy) + 
_assert_on_proxy(receiver_proxy) + + a = dummpy.party('alice').remote() + b = dummpy.party('bob').remote() + fed.get([a, b]) + + fed.shutdown() + ray.shutdown() + + +def test_grpc_max_size_by_both_config(): + p_alice = multiprocessing.Process(target=run3, args=('alice',)) + p_bob = multiprocessing.Process(target=run3, args=('bob',)) + p_alice.start() + p_bob.start() + p_alice.join() + p_bob.join() + assert p_alice.exitcode == 0 and p_bob.exitcode == 0 + + if __name__ == "__main__": import sys