From f49bdb0f9aa9374ba5e9d2be72c68b1f094fd1ac Mon Sep 17 00:00:00 2001 From: "Buchnik, Yehonatan" Date: Thu, 23 Jan 2025 01:42:07 -0800 Subject: [PATCH 1/3] introduce a mechanism for relearning the agg certificate Signed-off-by: Buchnik, Yehonatan --- openfl/transport/grpc/aggregator_client.py | 46 ++++++++-------------- 1 file changed, 16 insertions(+), 30 deletions(-) diff --git a/openfl/transport/grpc/aggregator_client.py b/openfl/transport/grpc/aggregator_client.py index 4b61cb888b..bee3b3c1cb 100644 --- a/openfl/transport/grpc/aggregator_client.py +++ b/openfl/transport/grpc/aggregator_client.py @@ -148,15 +148,15 @@ def wrapper(self, *args, **kwargs): while True: try: response = func(self, *args, **kwargs) + break except grpc.RpcError as e: - if e.code() == grpc.StatusCode.UNKNOWN: - self.logger.info( - f"Attempting to resend data request to aggregator at {self.uri}" - ) - elif e.code() == grpc.StatusCode.UNAUTHENTICATED: - raise - continue - break + self.logger.info( + f"Failed to send data request to aggregator at {self.uri}, error code {e.code()}" + ) + if self.refetch_server_cert_callback is not None: + self.logger.info("Refetching server certificate") + self.root_certificate = self.refetch_server_cert_callback() + self.sleeping_policy.sleep() return response return wrapper @@ -197,6 +197,7 @@ def __init__( aggregator_uuid=None, federation_uuid=None, single_col_cert_common_name=None, + refetch_server_cert_callback=None, **kwargs, ): """ @@ -226,7 +227,7 @@ def __init__( self.root_certificate = root_certificate self.certificate = certificate self.private_key = private_key - + self.sleeping_policy = ConstantBackoff(int(kwargs.get("client_reconnect_interval", 1)), getLogger(__name__), self.uri) self.logger = getLogger(__name__) if not self.use_tls: @@ -245,21 +246,8 @@ def __init__( self.aggregator_uuid = aggregator_uuid self.federation_uuid = federation_uuid self.single_col_cert_common_name = single_col_cert_common_name - - # Adding an interceptor for RPC 
Errors - self.interceptors = ( - RetryOnRpcErrorClientInterceptor( - sleeping_policy=ConstantBackoff( - logger=self.logger, - reconnect_interval=int(kwargs.get("client_reconnect_interval", 1)), - uri=self.uri, - ), - status_for_retry=(grpc.StatusCode.UNAVAILABLE,), - ), - ) - self.stub = aggregator_pb2_grpc.AggregatorStub( - grpc.intercept_channel(self.channel, *self.interceptors) - ) + self.refetch_server_cert_callback = refetch_server_cert_callback + self.stub = aggregator_pb2_grpc.AggregatorStub(self.channel) def create_insecure_channel(self, uri): """Set an insecure gRPC channel (i.e. no TLS) if desired. @@ -377,12 +365,10 @@ def reconnect(self): self.logger.debug("Connecting to gRPC at %s", self.uri) - self.stub = aggregator_pb2_grpc.AggregatorStub( - grpc.intercept_channel(self.channel, *self.interceptors) - ) + self.stub = aggregator_pb2_grpc.AggregatorStub(self.channel) - @_atomic_connection @_resend_data_on_reconnection + @_atomic_connection def get_tasks(self, collaborator_name): """Get tasks from the aggregator. 
@@ -406,8 +392,8 @@ def get_tasks(self, collaborator_name): response.quit, ) - @_atomic_connection @_resend_data_on_reconnection + @_atomic_connection def get_aggregated_tensor( self, collaborator_name, @@ -447,8 +433,8 @@ def get_aggregated_tensor( return response.tensor - @_atomic_connection @_resend_data_on_reconnection + @_atomic_connection def send_local_task_results( self, collaborator_name, From af3ed2b38b85c79a74ee5435c39ef02e683e02ce Mon Sep 17 00:00:00 2001 From: "Buchnik, Yehonatan" Date: Thu, 23 Jan 2025 02:27:42 -0800 Subject: [PATCH 2/3] formatting Signed-off-by: Buchnik, Yehonatan --- openfl/transport/grpc/aggregator_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/openfl/transport/grpc/aggregator_client.py b/openfl/transport/grpc/aggregator_client.py index bee3b3c1cb..b4ece200d9 100644 --- a/openfl/transport/grpc/aggregator_client.py +++ b/openfl/transport/grpc/aggregator_client.py @@ -227,7 +227,9 @@ def __init__( self.root_certificate = root_certificate self.certificate = certificate self.private_key = private_key - self.sleeping_policy = ConstantBackoff(int(kwargs.get("client_reconnect_interval", 1)), getLogger(__name__), self.uri) + self.sleeping_policy = ConstantBackoff( + int(kwargs.get("client_reconnect_interval", 1)), getLogger(__name__), self.uri + ) self.logger = getLogger(__name__) if not self.use_tls: From 29c997a3d99e9f30b65467fc695f57ed11435503 Mon Sep 17 00:00:00 2001 From: "Buchnik, Yehonatan" Date: Thu, 23 Jan 2025 06:12:36 -0800 Subject: [PATCH 3/3] formatting Signed-off-by: Buchnik, Yehonatan --- openfl/transport/grpc/aggregator_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfl/transport/grpc/aggregator_client.py b/openfl/transport/grpc/aggregator_client.py index b4ece200d9..225c2e4228 100644 --- a/openfl/transport/grpc/aggregator_client.py +++ b/openfl/transport/grpc/aggregator_client.py @@ -151,7 +151,7 @@ def wrapper(self, *args, **kwargs): break except 
grpc.RpcError as e: self.logger.info( - f"Failed to send data request to aggregator at {self.uri}, error code {e.code()}" + f"Failed to send data request to aggregator {self.uri}, error code {e.code()}" ) if self.refetch_server_cert_callback is not None: self.logger.info("Refetching server certificate")