From de829cee59bd29787e7f1502daa4d0e87541ae44 Mon Sep 17 00:00:00 2001 From: Yan Cheng <58191769+yanchengnv@users.noreply.github.com> Date: Tue, 4 Feb 2025 11:25:10 -0500 Subject: [PATCH] Support relay - Part 1 (#3198) Fixes # . ### Description This PR implements support for using relays in building the NVFLARE cellnet. Relays are nodes in the cellnet that are only used for routing messages. They do not have any learning functions. This is Part 1 of the relay support. It does not include provisioning functions for creating startup kits for relays. These functions will be added in future PRs. Just like regular clients, relays also register with the Server when they are started. Once successfully registered, the Server sends an auth token and signature to the relay. Relay nodes also perform message authentication: relays validate auth headers for all messages going through them. Since all cross-site messages must go through either the server or relays (or both), enforcing message authentication at both the Server and relays ensures that no cross-site messages can go through without valid auth headers. Currently, peers of a CellPipe can only communicate via the Server. This PR includes an enhancement that allows peers to communicate via CP or Relay nodes. This can make the peer-to-peer communication more efficient. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. 
--- nvflare/apis/fl_constant.py | 28 +- nvflare/apis/job_launcher_spec.py | 1 + .../executors/client_api_launcher_executor.py | 16 +- nvflare/app_common/utils/export_utils.py | 39 +++ .../widgets/external_configurator.py | 15 +- nvflare/app_opt/job_launcher/k8s_launcher.py | 6 +- nvflare/client/config.py | 13 +- nvflare/client/ex_process/api.py | 17 +- nvflare/fuel/f3/cellnet/connector_manager.py | 22 +- nvflare/fuel/f3/cellnet/core_cell.py | 19 +- nvflare/fuel/f3/cellnet/net_manager.py | 7 + nvflare/fuel/f3/comm_config_utils.py | 36 ++ nvflare/fuel/f3/communicator.py | 16 +- nvflare/fuel/f3/drivers/aio_grpc_driver.py | 3 +- nvflare/fuel/f3/drivers/aio_http_driver.py | 3 +- nvflare/fuel/f3/drivers/driver_params.py | 6 - nvflare/fuel/f3/drivers/grpc/utils.py | 3 +- nvflare/fuel/f3/drivers/grpc_driver.py | 3 +- nvflare/fuel/f3/drivers/net_utils.py | 7 +- nvflare/fuel/f3/drivers/tcp_driver.py | 3 +- nvflare/fuel/sec/authn.py | 52 ++- nvflare/fuel/utils/config_service.py | 7 +- nvflare/fuel/utils/pipe/cell_pipe.py | 127 ++++--- nvflare/fuel/utils/url_utils.py | 95 ++++++ nvflare/job_config/script_runner.py | 77 +++-- nvflare/lighter/constants.py | 8 +- nvflare/lighter/impl/static_file.py | 4 +- nvflare/private/defs.py | 8 +- .../private/fed/app/client/worker_process.py | 8 + .../fed/app/deployer/base_client_deployer.py | 5 +- nvflare/private/fed/app/fl_conf.py | 84 ++++- nvflare/private/fed/app/relay/relay.py | 224 ++++++++++++ .../private/fed/app/server/runner_process.py | 24 +- nvflare/private/fed/authenticator.py | 322 ++++++++++++++++++ nvflare/private/fed/client/client_executor.py | 9 +- .../private/fed/client/client_json_config.py | 19 +- nvflare/private/fed/client/communicator.py | 201 +++-------- nvflare/private/fed/client/fed_client_base.py | 42 ++- nvflare/private/fed/server/client_manager.py | 15 +- nvflare/private/fed/server/fed_server.py | 88 ++--- nvflare/private/fed/server/server_state.py | 4 +- nvflare/private/fed/server/training_cmds.py | 7 + 
.../private/fed/simulator/simulator_server.py | 1 - nvflare/private/fed/utils/identity_utils.py | 16 + nvflare/private/json_configer.py | 2 + nvflare/utils/job_launcher_utils.py | 6 +- tests/unit_test/client/in_process/api_test.py | 6 +- .../fuel/f3/comm_config_utils_test.py | 49 +++ tests/unit_test/fuel/utils/url_utils_test.py | 72 ++++ 49 files changed, 1468 insertions(+), 377 deletions(-) create mode 100644 nvflare/app_common/utils/export_utils.py create mode 100644 nvflare/fuel/f3/comm_config_utils.py create mode 100644 nvflare/fuel/utils/url_utils.py create mode 100644 nvflare/private/fed/app/relay/relay.py create mode 100644 nvflare/private/fed/authenticator.py create mode 100644 tests/unit_test/fuel/f3/comm_config_utils_test.py create mode 100644 tests/unit_test/fuel/utils/url_utils_test.py diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index b45cb58582..cafc8cb2e8 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -443,7 +443,6 @@ class SecureTrainConst: SSL_ROOT_CERT = "ssl_root_cert" SSL_CERT = "ssl_cert" PRIVATE_KEY = "ssl_private_key" - CONNECTION_SECURITY = "connection_security" class FLMetaKey: @@ -467,6 +466,7 @@ class FLMetaKey: class CellMessageAuthHeaderKey: CLIENT_NAME = "client_name" + SSID = "ssid" TOKEN = "__token__" TOKEN_SIGNATURE = "__token_signature__" @@ -542,6 +542,8 @@ class SystemVarName: WORKSPACE = "WORKSPACE" # directory of the workspace JOB_ID = "JOB_ID" # Job ID ROOT_URL = "ROOT_URL" # the URL of the Service Provider (server) + CP_URL = "CP_URL" # URL to CP + RELAY_URL = "RELAY_URL" # URL to relay that the CP is connected to SECURE_MODE = "SECURE_MODE" # whether the system is running in secure mode JOB_CUSTOM_DIR = "JOB_CUSTOM_DIR" # custom dir of the job PYTHONPATH = "PYTHONPATH" @@ -552,3 +554,27 @@ class RunnerTask: INIT = "init" TASK_EXEC = "task_exec" END_RUN = "end_run" + + +class ConnPropKey: + + PROJECT_NAME = "project_name" + SERVER_IDENTITY = "server_identity" + IDENTITY 
= "identity" + PARENT = "parent" + FQCN = "fqcn" + URL = "url" + SCHEME = "scheme" + ADDRESS = "address" + CONNECTION_SECURITY = "connection_security" + + RELAY_CONFIG = "relay_config" + CP_CONN_PROPS = "cp_conn_props" + RELAY_CONN_PROPS = "relay_conn_props" + ROOT_CONN_PROPS = "root_conn_props" + + +class ConnectionSecurity: + CLEAR = "clear" + TLS = "tls" + MTLS = "mtls" diff --git a/nvflare/apis/job_launcher_spec.py b/nvflare/apis/job_launcher_spec.py index 36edc62380..cb75130e6a 100644 --- a/nvflare/apis/job_launcher_spec.py +++ b/nvflare/apis/job_launcher_spec.py @@ -32,6 +32,7 @@ class JobProcessArgs: CLIENT_NAME = "client_name" ROOT_URL = "root_url" PARENT_URL = "parent_url" + PARENT_CONN_SEC = "parent_conn_sec" SERVICE_HOST = "service_host" SERVICE_PORT = "service_port" HA_MODE = "ha_mode" diff --git a/nvflare/app_common/executors/client_api_launcher_executor.py b/nvflare/app_common/executors/client_api_launcher_executor.py index 3b470edf2c..35a99e0fe0 100644 --- a/nvflare/app_common/executors/client_api_launcher_executor.py +++ b/nvflare/app_common/executors/client_api_launcher_executor.py @@ -15,13 +15,12 @@ import os from typing import Optional -from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst from nvflare.apis.fl_context import FLContext from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.executors.launcher_executor import LauncherExecutor +from nvflare.app_common.utils.export_utils import update_export_props from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType, write_config_to_file from nvflare.client.constants import CLIENT_API_CONFIG -from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.utils.attributes_exportable import ExportMode @@ -126,22 +125,11 @@ def prepare_config_for_launch(self, fl_ctx: FLContext): ConfigKey.HEARTBEAT_TIMEOUT: self.heartbeat_timeout, } - site_name = fl_ctx.get_identity_name() - auth_token = get_scope_property(scope_name=site_name, 
key=FLMetaKey.AUTH_TOKEN, default="NA") - signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") - config_data = { ConfigKey.TASK_EXCHANGE: task_exchange_attributes, - FLMetaKey.SITE_NAME: site_name, - FLMetaKey.JOB_ID: fl_ctx.get_job_id(), - FLMetaKey.AUTH_TOKEN: auth_token, - FLMetaKey.AUTH_TOKEN_SIGNATURE: signature, } - conn_sec = get_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY) - if conn_sec: - config_data[SecureTrainConst.CONNECTION_SECURITY] = conn_sec - + update_export_props(config_data, fl_ctx) config_file_path = self._get_external_config_file_path(fl_ctx) write_config_to_file(config_data=config_data, config_file_path=config_file_path) diff --git a/nvflare/app_common/utils/export_utils.py b/nvflare/app_common/utils/export_utils.py new file mode 100644 index 0000000000..ba52632c9c --- /dev/null +++ b/nvflare/app_common/utils/export_utils.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey +from nvflare.apis.fl_context import FLContext +from nvflare.fuel.data_event.utils import get_scope_property + + +def update_export_props(props: dict, fl_ctx: FLContext): + site_name = fl_ctx.get_identity_name() + auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") + signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") + + props[FLMetaKey.SITE_NAME] = site_name + props[FLMetaKey.JOB_ID] = fl_ctx.get_job_id() + props[FLMetaKey.AUTH_TOKEN] = auth_token + props[FLMetaKey.AUTH_TOKEN_SIGNATURE] = signature + + root_conn_props = get_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS) + if root_conn_props: + props[ConnPropKey.ROOT_CONN_PROPS] = root_conn_props + + cp_conn_props = get_scope_property(site_name, ConnPropKey.CP_CONN_PROPS) + if cp_conn_props: + props[ConnPropKey.CP_CONN_PROPS] = cp_conn_props + + relay_conn_props = get_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS) + if relay_conn_props: + props[ConnPropKey.RELAY_CONN_PROPS] = relay_conn_props diff --git a/nvflare/app_common/widgets/external_configurator.py b/nvflare/app_common/widgets/external_configurator.py index b9b8c5b14d..648dba64eb 100644 --- a/nvflare/app_common/widgets/external_configurator.py +++ b/nvflare/app_common/widgets/external_configurator.py @@ -16,8 +16,9 @@ from typing import List from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLMetaKey +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey from nvflare.apis.fl_context import FLContext +from nvflare.app_common.utils.export_utils import update_export_props from nvflare.client.config import write_config_to_file from nvflare.client.constants import CLIENT_API_CONFIG from nvflare.fuel.utils.attributes_exportable import ExportMode, export_components @@ -47,9 +48,7 @@ def __init__( def handle_event(self, event_type: str, fl_ctx: FLContext): 
if event_type == EventType.ABOUT_TO_START_RUN: components_data = self._export_all_components(fl_ctx) - components_data[FLMetaKey.SITE_NAME] = fl_ctx.get_identity_name() - components_data[FLMetaKey.JOB_ID] = fl_ctx.get_job_id() - + update_export_props(components_data, fl_ctx) config_file_path = self._get_external_config_file_path(fl_ctx) write_config_to_file(config_data=components_data, config_file_path=config_file_path) @@ -65,5 +64,11 @@ def _export_all_components(self, fl_ctx: FLContext) -> dict: engine = fl_ctx.get_engine() all_components = engine.get_all_components() components = {i: all_components.get(i) for i in self._component_ids} - reserved_keys = [FLMetaKey.SITE_NAME, FLMetaKey.JOB_ID] + reserved_keys = [ + FLMetaKey.SITE_NAME, + FLMetaKey.JOB_ID, + ConnPropKey.CP_CONN_PROPS, + ConnPropKey.ROOT_CONN_PROPS, + ConnPropKey.RELAY_CONN_PROPS, + ] return export_components(components=components, reserved_keys=reserved_keys, export_mode=ExportMode.PEER) diff --git a/nvflare/app_opt/job_launcher/k8s_launcher.py b/nvflare/app_opt/job_launcher/k8s_launcher.py index 39af3675ed..12db4891f2 100644 --- a/nvflare/app_opt/job_launcher/k8s_launcher.py +++ b/nvflare/app_opt/job_launcher/k8s_launcher.py @@ -277,7 +277,11 @@ def get_module_args(self, job_id, fl_ctx: FLContext): def _job_args_dict(job_args: dict, arg_names: list) -> dict: result = {} for name in arg_names: - n, v = job_args[name] + e = job_args.get(name) + if not e: + continue + + n, v = e result[n] = v return result diff --git a/nvflare/client/config.py b/nvflare/client/config.py index 477b132dc3..06369a9954 100644 --- a/nvflare/client/config.py +++ b/nvflare/client/config.py @@ -16,7 +16,7 @@ import os from typing import Dict, Optional -from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey from nvflare.fuel.utils.config_factory import ConfigFactory @@ -157,7 +157,16 @@ def get_heartbeat_timeout(self): ) def get_connection_security(self): 
- return self.config.get(SecureTrainConst.CONNECTION_SECURITY) + return self.config.get(ConnPropKey.CONNECTION_SECURITY) + + def get_root_conn_props(self): + return self.config.get(ConnPropKey.ROOT_CONN_PROPS) + + def get_cp_conn_props(self): + return self.config.get(ConnPropKey.CP_CONN_PROPS) + + def get_relay_conn_props(self): + return self.config.get(ConnPropKey.RELAY_CONN_PROPS) def get_site_name(self): return self.config.get(FLMetaKey.SITE_NAME) diff --git a/nvflare/client/ex_process/api.py b/nvflare/client/ex_process/api.py index bb94b70939..73e9328a9e 100644 --- a/nvflare/client/ex_process/api.py +++ b/nvflare/client/ex_process/api.py @@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple from nvflare.apis.analytix import AnalyticsDataType -from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey from nvflare.apis.utils.analytix_utils import create_analytic_dxo from nvflare.app_common.abstract.fl_model import FLModel from nvflare.client.api_spec import APISpec @@ -39,9 +39,18 @@ def _create_client_config(config: str) -> ClientConfig: raise ValueError(f"config should be a string but got: {type(config)}") site_name = client_config.get_site_name() - conn_sec = client_config.get_connection_security() - if conn_sec: - set_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY, conn_sec) + + root_conn_props = client_config.get_root_conn_props() + if root_conn_props: + set_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS, root_conn_props) + + cp_conn_props = client_config.get_cp_conn_props() + if cp_conn_props: + set_scope_property(site_name, ConnPropKey.CP_CONN_PROPS, cp_conn_props) + + relay_conn_props = client_config.get_relay_conn_props() + if relay_conn_props: + set_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS, relay_conn_props) # get message auth info and put them into Databus for CellPipe to use auth_token = client_config.get_auth_token() diff --git 
a/nvflare/fuel/f3/cellnet/connector_manager.py b/nvflare/fuel/f3/cellnet/connector_manager.py index 5c0e65a48f..567b6cb0ed 100644 --- a/nvflare/fuel/f3/cellnet/connector_manager.py +++ b/nvflare/fuel/f3/cellnet/connector_manager.py @@ -15,11 +15,13 @@ import time from typing import Union +from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.common.excepts import ConfigError from nvflare.fuel.f3.cellnet.defs import ConnectorRequirementKey from nvflare.fuel.f3.cellnet.fqcn import FqcnInfo from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.f3.communicator import CommError, Communicator, Mode +from nvflare.fuel.f3.drivers.driver_params import DriverParams from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.security.logging import secure_format_exception, secure_format_traceback @@ -48,6 +50,9 @@ def __init__(self, handle, connect_url: str, active: bool, params: dict): def get_connection_url(self): return self.connect_url + def get_connection_params(self): + return self.params + class ConnectorManager: """ @@ -85,6 +90,11 @@ def __init__(self, communicator: Communicator, secure: bool, comm_configurator: self.adhoc_scheme = adhoc_conf.get(_KEY_SCHEME) self.adhoc_resources = adhoc_conf.get(_KEY_RESOURCES) + # default conn sec + conn_sec = self.int_resources.get(DriverParams.CONNECTION_SECURITY) + if not conn_sec: + self.int_resources[DriverParams.CONNECTION_SECURITY] = ConnectionSecurity.CLEAR + self.logger.debug(f"internal scheme={self.int_scheme}, resources={self.int_resources}") self.logger.debug(f"adhoc scheme={self.adhoc_scheme}, resources={self.adhoc_resources}") self.comm_config = comm_config @@ -152,7 +162,7 @@ def _validate_conn_config(config: dict, key: str) -> Union[None, dict]: return conn_config def _get_connector( - self, url: str, active: bool, internal: bool, adhoc: bool, secure: bool + self, url: str, active: bool, internal: bool, adhoc: bool, secure: bool, conn_resources=None ) -> Union[None, 
ConnectorData]: if active and not url: raise RuntimeError("url is required by not provided for active connector!") @@ -193,10 +203,10 @@ def _get_connector( try: if active: - handle, conn_params = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required) + handle, conn_params = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required, conn_resources) connect_url = url elif url: - handle, conn_params = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required) + handle, conn_params = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required, conn_resources) connect_url = url else: self.logger.info(f"{os.getpid()}: Try start_listener Listener resources: {reqs}") @@ -240,11 +250,13 @@ def get_internal_listener(self) -> Union[None, ConnectorData]: """ return self._get_connector(url="", active=False, internal=True, adhoc=False, secure=False) - def get_internal_connector(self, url: str) -> Union[None, ConnectorData]: + def get_internal_connector(self, url: str, conn_resources=None) -> Union[None, ConnectorData]: """ Try to get an internal listener. 
Args: url: """ - return self._get_connector(url=url, active=True, internal=True, adhoc=False, secure=False) + return self._get_connector( + url=url, active=True, internal=True, adhoc=False, secure=False, conn_resources=conn_resources + ) diff --git a/nvflare/fuel/f3/cellnet/core_cell.py b/nvflare/fuel/f3/cellnet/core_cell.py index 7228e7a8c8..3b5aec82c3 100644 --- a/nvflare/fuel/f3/cellnet/core_cell.py +++ b/nvflare/fuel/f3/cellnet/core_cell.py @@ -22,6 +22,7 @@ from typing import Dict, List, Tuple, Union from urllib.parse import urlparse +from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.f3.cellnet.connector_manager import ConnectorManager from nvflare.fuel.f3.cellnet.credential_manager import CredentialManager from nvflare.fuel.f3.cellnet.defs import ( @@ -43,7 +44,7 @@ from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.f3.communicator import Communicator, MessageReceiver from nvflare.fuel.f3.connection import Connection -from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams +from nvflare.fuel.f3.drivers.driver_params import DriverParams from nvflare.fuel.f3.drivers.net_utils import enhance_credential_info from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor, EndpointState from nvflare.fuel.f3.message import Message @@ -281,6 +282,7 @@ def __init__( credentials: dict, create_internal_listener: bool = False, parent_url: str = None, + parent_resources: dict = None, max_timeout=3600, bulk_check_interval=0.5, bulk_process_interval=0.5, @@ -296,6 +298,7 @@ def __init__( max_timeout: default timeout for send_and_receive create_internal_listener: whether to create an internal listener for child cells parent_url: url for connecting to parent cell + parent_resources: extra resources for making connection to parent FQCN is the names of all ancestor, concatenated with dots. @@ -331,7 +334,7 @@ def __init__( # If configured, use it; otherwise keep the original value of 'secure'. 
conn_security = credentials.get(DriverParams.CONNECTION_SECURITY.value) if conn_security: - if conn_security == ConnectionSecurity.INSECURE: + if conn_security == ConnectionSecurity.CLEAR: secure = False else: secure = True @@ -370,6 +373,7 @@ def __init__( self.root_url = root_url self.create_internal_listener = create_internal_listener self.parent_url = parent_url + self.parent_resources = parent_resources self.bulk_check_interval = bulk_check_interval self.max_bulk_size = max_bulk_size self.bulk_checker = None @@ -566,7 +570,7 @@ def _set_bb_for_client_root(self): def _set_bb_for_client_child(self, parent_url: str, create_internal_listener: bool): if parent_url: - self._create_internal_connector(parent_url) + self._create_internal_connector(parent_url, self.parent_resources) if create_internal_listener: self._create_internal_listener() @@ -693,6 +697,11 @@ def get_internal_listener_url(self) -> Union[None, str]: return None return self.int_listener.get_connection_url() + def get_internal_listener_params(self) -> Union[None, dict]: + if not self.int_listener: + return None + return self.int_listener.get_connection_params() + def _add_adhoc_connector(self, to_cell: str, url: str): if self.bb_ext_connector: # it is possible that the server root offers connect url after the bb_ext_connector is created @@ -786,8 +795,8 @@ def _create_bb_external_connector(self): else: raise RuntimeError(f"{self.my_info.fqcn}: cannot create backbone external connector to {self.root_url}") - def _create_internal_connector(self, url: str): - self.bb_int_connector = self.connector_manager.get_internal_connector(url) + def _create_internal_connector(self, url: str, resources=None): + self.bb_int_connector = self.connector_manager.get_internal_connector(url, resources) if self.bb_int_connector: self.logger.info(f"{self.my_info.fqcn}: created backbone internal connector to {url} on parent") else: diff --git a/nvflare/fuel/f3/cellnet/net_manager.py b/nvflare/fuel/f3/cellnet/net_manager.py 
index 82a1a0d980..e25232f2d3 100644 --- a/nvflare/fuel/f3/cellnet/net_manager.py +++ b/nvflare/fuel/f3/cellnet/net_manager.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvflare.fuel.data_event.data_bus import DataBus from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.net_agent import NetAgent from nvflare.fuel.f3.stats_pool import VALID_HIST_MODES, parse_hist_mode @@ -31,6 +32,12 @@ class NetManager(CommandModule): def __init__(self, agent: NetAgent, diagnose=False): self.agent = agent self.diagnose = diagnose + data_bus = DataBus() + data_bus.subscribe(["stop_cellnet"], self._stop_cellnet) + + def _stop_cellnet(self, topic: str, conn: Connection, db: DataBus): + self.agent.stop() + conn.append_string("Cellnet Stopped") def get_spec(self) -> CommandModuleSpec: return CommandModuleSpec( diff --git a/nvflare/fuel/f3/comm_config_utils.py b/nvflare/fuel/f3/comm_config_utils.py new file mode 100644 index 0000000000..5ffd4e07b0 --- /dev/null +++ b/nvflare/fuel/f3/comm_config_utils.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from nvflare.apis.fl_constant import ConnectionSecurity +from nvflare.fuel.f3.drivers.driver_params import DriverParams + + +def requires_secure_connection(resources: dict): + """Determine whether secure connection is required based on information in resources. + + Args: + resources: a dict that contains info for making connection + + Returns: whether secure connection is required + + """ + conn_sec = resources.get(DriverParams.CONNECTION_SECURITY.value) + if conn_sec: + # if connection security is specified, it takes precedence over the "secure" flag + if conn_sec == ConnectionSecurity.CLEAR: + return False + else: + return True + else: + # Connection security is not specified, check the "secure" flag. + return resources.get(DriverParams.SECURE.value, False) diff --git a/nvflare/fuel/f3/communicator.py b/nvflare/fuel/f3/communicator.py index dbd4e86298..7cd7681f36 100644 --- a/nvflare/fuel/f3/communicator.py +++ b/nvflare/fuel/f3/communicator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import atexit +import copy import logging import os import weakref @@ -152,13 +153,14 @@ def register_message_receiver(self, app_id: int, receiver: MessageReceiver): self.conn_manager.register_message_receiver(app_id, receiver) - def add_connector(self, url: str, mode: Mode, secure: bool = False) -> (str, dict): + def add_connector(self, url: str, mode: Mode, secure: bool = False, resources=None) -> (str, dict): """Load a connector. The driver is selected based on the URL Args: url: The url to listen on or connect to, like "https://0:443". Use 0 for empty host mode: Active for connecting, Passive for listening secure: True if SSL is required. 
+ resources: extra resources for creating connection Returns: A tuple of (A handle that can be used to delete connector, connector params) @@ -175,6 +177,8 @@ def add_connector(self, url: str, mode: Mode, secure: bool = False) -> (str, dic raise CommError(CommError.NOT_SUPPORTED, f"No driver found for URL {url}") params = parse_url(url) + if resources: + params.update(resources) return self.add_connector_advanced(driver_class(), mode, params, secure, False), params def start_listener(self, scheme: str, resources: dict) -> (str, str, dict): @@ -199,7 +203,9 @@ def start_listener(self, scheme: str, resources: dict) -> (str, str, dict): raise CommError(CommError.NOT_SUPPORTED, f"No driver found for scheme {scheme}") connect_url, listening_url = driver_class.get_urls(scheme, resources) - params = parse_url(listening_url) + extra_params = parse_url(listening_url) + params = copy.copy(resources) + params.update(extra_params) handle = self.add_connector_advanced(driver_class(), Mode.PASSIVE, params, False, True) @@ -223,10 +229,14 @@ def add_connector_advanced( Raises: CommError: If any errors """ - + original_conn_sec = params.get(DriverParams.CONNECTION_SECURITY) if self.local_endpoint.conn_props: params.update(self.local_endpoint.conn_props) + if original_conn_sec: + # we do not allow the connection sec to be overwritten by the endpoint's conn_props + params[DriverParams.CONNECTION_SECURITY] = original_conn_sec + params[DriverParams.SECURE] = secure handle = self.conn_manager.add_connector(driver, params, mode) diff --git a/nvflare/fuel/f3/drivers/aio_grpc_driver.py b/nvflare/fuel/f3/drivers/aio_grpc_driver.py index aa71b0441a..65544a2f84 100644 --- a/nvflare/fuel/f3/drivers/aio_grpc_driver.py +++ b/nvflare/fuel/f3/drivers/aio_grpc_driver.py @@ -22,6 +22,7 @@ import grpc from nvflare.fuel.f3.comm_config import CommConfigurator +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.comm_error import CommError from 
nvflare.fuel.f3.connection import BytesAlike, Connection from nvflare.fuel.f3.drivers.aio_context import AioContext @@ -409,7 +410,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = resources.get(DriverParams.SECURE) + secure = requires_secure_connection(resources) if secure: if use_aio_grpc(): scheme = "grpcs" diff --git a/nvflare/fuel/f3/drivers/aio_http_driver.py b/nvflare/fuel/f3/drivers/aio_http_driver.py index 61c953a867..383f95d7cb 100644 --- a/nvflare/fuel/f3/drivers/aio_http_driver.py +++ b/nvflare/fuel/f3/drivers/aio_http_driver.py @@ -18,6 +18,7 @@ import websockets from websockets.exceptions import ConnectionClosedOK +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import BytesAlike, Connection from nvflare.fuel.f3.drivers import net_utils @@ -120,7 +121,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = resources.get(DriverParams.SECURE) + secure = requires_secure_connection(resources) if secure: scheme = "https" diff --git a/nvflare/fuel/f3/drivers/driver_params.py b/nvflare/fuel/f3/drivers/driver_params.py index 54118855e3..c03c89618d 100644 --- a/nvflare/fuel/f3/drivers/driver_params.py +++ b/nvflare/fuel/f3/drivers/driver_params.py @@ -44,12 +44,6 @@ class DriverParams(str, Enum): IMPLEMENTED_CONN_SEC = "implemented_conn_sec" -class ConnectionSecurity: - INSECURE = "insecure" - TLS = "tls" - MTLS = "mtls" - - class DriverCap(str, Enum): SEND_HEARTBEAT = "send_heartbeat" diff --git a/nvflare/fuel/f3/drivers/grpc/utils.py b/nvflare/fuel/f3/drivers/grpc/utils.py index bbad1a522f..7c1f0de896 100644 --- a/nvflare/fuel/f3/drivers/grpc/utils.py +++ b/nvflare/fuel/f3/drivers/grpc/utils.py @@ -13,8 +13,9 @@ # limitations under the License. 
import grpc +from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.f3.comm_config import CommConfigurator -from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams +from nvflare.fuel.f3.drivers.driver_params import DriverParams def use_aio_grpc(): diff --git a/nvflare/fuel/f3/drivers/grpc_driver.py b/nvflare/fuel/f3/drivers/grpc_driver.py index 7cdd444360..bc37e54fce 100644 --- a/nvflare/fuel/f3/drivers/grpc_driver.py +++ b/nvflare/fuel/f3/drivers/grpc_driver.py @@ -20,6 +20,7 @@ import grpc from nvflare.fuel.f3.comm_config import CommConfigurator +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import Connection from nvflare.fuel.f3.drivers.driver import ConnectorInfo @@ -274,7 +275,7 @@ def connect(self, connector: ConnectorInfo): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = resources.get(DriverParams.SECURE) + secure = requires_secure_connection(resources) if secure: if use_aio_grpc(): scheme = "nagrpcs" diff --git a/nvflare/fuel/f3/drivers/net_utils.py b/nvflare/fuel/f3/drivers/net_utils.py index 8aa4e1f60f..a649566647 100644 --- a/nvflare/fuel/f3/drivers/net_utils.py +++ b/nvflare/fuel/f3/drivers/net_utils.py @@ -20,8 +20,9 @@ from typing import Any, Optional from urllib.parse import parse_qsl, urlencode, urlparse +from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.f3.comm_error import CommError -from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams +from nvflare.fuel.f3.drivers.driver_params import DriverParams from nvflare.fuel.utils.argument_utils import str2bool from nvflare.security.logging import secure_format_exception @@ -64,6 +65,10 @@ def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: ca_path = params.get(DriverParams.CA_CERT.value) cert_path = params.get(DriverParams.SERVER_CERT.value) 
key_path = params.get(DriverParams.SERVER_KEY.value) + + if not cert_path or not key_path: + raise RuntimeError(f"not cert or key for SSL server: {params=}") + if conn_security == ConnectionSecurity.TLS: # do not require client auth ctx.verify_mode = ssl.CERT_NONE diff --git a/nvflare/fuel/f3/drivers/tcp_driver.py b/nvflare/fuel/f3/drivers/tcp_driver.py index f7aff1a75d..09256bca88 100644 --- a/nvflare/fuel/f3/drivers/tcp_driver.py +++ b/nvflare/fuel/f3/drivers/tcp_driver.py @@ -17,6 +17,7 @@ from socketserver import TCPServer, ThreadingTCPServer from typing import Any, Dict, List +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.drivers.base_driver import BaseDriver from nvflare.fuel.f3.drivers.driver import ConnectorInfo, Driver from nvflare.fuel.f3.drivers.driver_params import DriverCap, DriverParams @@ -100,7 +101,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = resources.get(DriverParams.SECURE) + secure = requires_secure_connection(resources) if secure: scheme = "stcp" diff --git a/nvflare/fuel/sec/authn.py b/nvflare/fuel/sec/authn.py index 4bbe2f66b3..2aa466a5ce 100644 --- a/nvflare/fuel/sec/authn.py +++ b/nvflare/fuel/sec/authn.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvflare.apis.fl_constant import CellMessageAuthHeaderKey +from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.message import Message +from nvflare.fuel.utils.validation_utils import check_object_type, check_str -def add_authentication_headers(msg: Message, client_name: str, auth_token, token_signature): +def add_authentication_headers(msg: Message, client_name: str, auth_token, token_signature, ssid=None): """Add authentication headers to the specified message. 
Args: @@ -23,6 +25,7 @@ def add_authentication_headers(msg: Message, client_name: str, auth_token, token client_name: name of the client auth_token: authentication token token_signature: token signature + ssid: optional SSID Returns: @@ -30,5 +33,52 @@ def add_authentication_headers(msg: Message, client_name: str, auth_token, token if client_name: msg.set_header(CellMessageAuthHeaderKey.CLIENT_NAME, client_name) + if ssid: + msg.set_header(CellMessageAuthHeaderKey.SSID, ssid) + msg.set_header(CellMessageAuthHeaderKey.TOKEN, auth_token if auth_token else "NA") msg.set_header(CellMessageAuthHeaderKey.TOKEN_SIGNATURE, token_signature if token_signature else "NA") + + +def set_add_auth_headers_filters(cell: Cell, client_name: str, auth_token: str, token_signature: str, ssid=None): + """Set filters for adding auth headers. + + Args: + cell: the cell to add the filters to. + client_name: name of the client + auth_token: authentication token + token_signature: token signature + ssid: SSID, optional + + Returns: None + + """ + check_object_type("cell", cell, Cell) + + if client_name: + check_str("client_name", client_name) + + check_str("auth_token", auth_token) + check_str("token_signature", token_signature) + + if ssid: + check_str("ssid", ssid) + + cell.core_cell.add_outgoing_reply_filter( + channel="*", + topic="*", + cb=add_authentication_headers, + client_name=client_name, + auth_token=auth_token, + token_signature=token_signature, + ssid=ssid, + ) + cell.core_cell.add_outgoing_request_filter( + channel="*", + topic="*", + cb=add_authentication_headers, + client_name=client_name, + auth_token=auth_token, + token_signature=token_signature, + ssid=ssid, + ) diff --git a/nvflare/fuel/utils/config_service.py b/nvflare/fuel/utils/config_service.py index 7095c59e8f..d8f59aaf6c 100644 --- a/nvflare/fuel/utils/config_service.py +++ b/nvflare/fuel/utils/config_service.py @@ -116,7 +116,9 @@ def initialize(cls, section_files: Dict[str, str], config_path: List[str], parse if 
not os.path.isdir(d): raise ValueError(f"'{d}' is not a valid directory") - cls._config_path = config_path + for d in config_path: + if d not in cls._config_path: + cls._config_path.append(d) for section, file_basename in section_files.items(): cls._sections[section] = cls.load_config_dict(file_basename, cls._config_path) @@ -185,7 +187,8 @@ def load_configuration(cls, file_basename: str) -> Optional[Config]: Returns: config data loaded, or None if the config file is not found. """ - return ConfigFactory.load_config(file_basename, cls._config_path) + result = ConfigFactory.load_config(file_basename, cls._config_path) + return result @classmethod def load_config_dict( diff --git a/nvflare/fuel/utils/pipe/cell_pipe.py b/nvflare/fuel/utils/pipe/cell_pipe.py index 1554a0dd44..7d266166b9 100644 --- a/nvflare/fuel/utils/pipe/cell_pipe.py +++ b/nvflare/fuel/utils/pipe/cell_pipe.py @@ -17,15 +17,16 @@ import time from typing import Tuple, Union -from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst, SystemVarName +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey, SystemVarName from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.cell import Message as CellMessage from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode +from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.net_agent import NetAgent from nvflare.fuel.f3.cellnet.utils import make_reply from nvflare.fuel.f3.drivers.driver_params import DriverParams -from nvflare.fuel.sec.authn import add_authentication_headers +from nvflare.fuel.sec.authn import set_add_auth_headers_filters from nvflare.fuel.utils.attributes_exportable import ExportMode from nvflare.fuel.utils.config_service import search_file from nvflare.fuel.utils.constants import Mode @@ -44,12 +45,16 @@ _HEADER_HB_SEQ = _PREFIX + "hb_seq" -def _cell_fqcn(mode, site_name, token): +def _cell_fqcn(mode, site_name, token, 
parent_fqcn): # The FQCN of the cell must be unique in the whole cellnet. # We use the combination of mode, site_name, and token to derive the value of FQCN # Since the token is usually used across all sites, the "site_name" differentiate cell on one site from another. # The two peer pipes on the same site share the same site_name and token, but are differentiated by their modes. - return f"{site_name}_{token}_{mode}" + base = f"{site_name}_{token}_{mode}" + if parent_fqcn == FQCN.ROOT_SERVER: + return base + else: + return FQCN.join([parent_fqcn, base]) def _to_cell_message(msg: Message, extra=None) -> CellMessage: @@ -77,8 +82,11 @@ class _CellInfo: A cell could be used by multiple pipes (e.g. one pipe for task interaction, another for metrics logging). """ - def __init__(self, cell, net_agent): + def __init__(self, site_name, cell, net_agent, auth_token, token_signature): + self.site_name = site_name self.cell = cell + self.auth_token = auth_token + self.token_signature = token_signature self.net_agent = net_agent self.started = False self.pipes = [] @@ -114,21 +122,15 @@ class CellPipe(Pipe): _lock = threading.Lock() _cells_info = {} # (root_url, site_name, token) => _CellInfo - _auth_token = None - _token_signature = None - _site_name = None @classmethod - def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_dir): + def _build_cell(cls, site_name, fqcn, parent_conn_props, secure_mode, workspace_dir, logger): """Build a cell if necessary. The combination of (root_url, site_name, token) uniquely determine one cell. There can be multiple pipes on the same cell. Args: - root_url: root url of the cell net - mode: mode (passive or active) of the pipe - site_name: name of the site - token: the unique token + parent_conn_props: dict of connection properties (URL, connection security) of the endpoint this cell connects to secure_mode: whether cellnet is in secure mode workspace_dir: workspace that contains startup kit for connecting to server.
Needed only if secure_mode @@ -136,11 +138,8 @@ def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_di """ with cls._lock: - cls._site_name = site_name - cell_key = f"{root_url}.{site_name}.{token}" - ci = cls._cells_info.get(cell_key) + ci = cls._cells_info.get(fqcn) if not ci: - credentials = {} if secure_mode: root_cert_path = search_file(SSL_ROOT_CERT, workspace_dir) if not root_cert_path: @@ -149,35 +148,42 @@ def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_di credentials = { DriverParams.CA_CERT.value: root_cert_path, } + else: + credentials = {} - conn_sec = get_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY) - if conn_sec: - credentials[DriverParams.CONNECTION_SECURITY.value] = conn_sec + conn_sec = parent_conn_props.get(ConnPropKey.CONNECTION_SECURITY) + if conn_sec: + credentials[DriverParams.CONNECTION_SECURITY.value] = conn_sec + + parent_url = parent_conn_props.get(ConnPropKey.URL) + + if FQCN.get_parent(fqcn): + # the cell has a parent: connect to the parent + cell_root = None + cell_parent_url = parent_url + else: + # the cell has no parent: the parent_url is the root of the cellnet + cell_root = parent_url + cell_parent_url = None cell = Cell( - fqcn=_cell_fqcn(mode, site_name, token), - root_url=root_url, + fqcn=fqcn, + root_url=cell_root, secure=secure_mode, credentials=credentials, + parent_url=cell_parent_url, create_internal_listener=False, ) - # set filter to add additional auth headers - cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=cls._add_auth_headers) - cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=cls._add_auth_headers) + auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") + token_signature = get_scope_property(site_name, FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") net_agent = NetAgent(cell) - ci = _CellInfo(cell, net_agent) - cls._cells_info[cell_key] = ci - return ci - - 
@classmethod - def _add_auth_headers(cls, message: CellMessage): - if not cls._auth_token: - cls._auth_token = get_scope_property(scope_name=cls._site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") - cls._token_signature = get_scope_property(cls._site_name, FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") + ci = _CellInfo(site_name, cell, net_agent, auth_token, token_signature) + cls._cells_info[fqcn] = ci - add_authentication_headers(message, cls._site_name, cls._auth_token, cls._token_signature) + set_add_auth_headers_filters(cell, ci.site_name, ci.auth_token, ci.token_signature) + return ci def __init__( self, @@ -203,9 +209,9 @@ def __init__( self.site_name = site_name self.token = token - self.root_url = root_url self.secure_mode = secure_mode self.workspace_dir = workspace_dir + self.root_url = root_url # this section is needed by job config to prevent building cell when using SystemVarName arguments # TODO: enhance this part @@ -219,8 +225,49 @@ def __init__( check_str("site_name", site_name) check_str("workspace_dir", workspace_dir) + # determine the endpoint for this pipe to connect to + root_conn_props = get_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS) + + if root_conn_props: + # Not in simulator + if not isinstance(root_conn_props, dict): + raise RuntimeError(f"expect root_conn_props for {site_name} to be dict but got {type(root_conn_props)}") + + cp_conn_props = get_scope_property(site_name, ConnPropKey.CP_CONN_PROPS) + if cp_conn_props: + if not isinstance(cp_conn_props, dict): + raise RuntimeError(f"expect cp_conn_props to be dict but got {type(cp_conn_props)}") + + url_to_conns = { + root_conn_props.get(ConnPropKey.URL): root_conn_props, + cp_conn_props.get(ConnPropKey.URL): cp_conn_props, + } + + relay_conn_props = get_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS) + if relay_conn_props: + if not isinstance(relay_conn_props, dict): + raise RuntimeError(f"expect relay_conn_props to be dict but got {type(relay_conn_props)}") + 
url_to_conns[relay_conn_props.get(ConnPropKey.URL)] = relay_conn_props + + if not root_url: + # root_url not specified - use CP! + root_url = cp_conn_props.get(ConnPropKey.URL) + self.root_url = root_url + + conn_props = url_to_conns.get(self.root_url) + if not conn_props: + raise RuntimeError(f"cannot determine conn props for '{root_url}'") + else: + # this is running in simulator + conn_props = { + ConnPropKey.URL: root_url, + ConnPropKey.FQCN: FQCN.ROOT_SERVER, + } + mode = f"{mode}".strip().lower() # convert to lower case string - self.ci = self._build_cell(mode, root_url, site_name, token, secure_mode, workspace_dir) + fqcn = _cell_fqcn(mode, site_name, token, conn_props.get(ConnPropKey.FQCN)) + + self.ci = self._build_cell(site_name, fqcn, conn_props, secure_mode, workspace_dir, self.logger) self.cell = self.ci.cell self.ci.add_pipe(self) @@ -231,7 +278,7 @@ def __init__( else: raise ValueError(f"invalid mode {mode} - must be 'active' or 'passive'") - self.peer_fqcn = _cell_fqcn(peer_mode, site_name, token) + self.peer_fqcn = _cell_fqcn(peer_mode, site_name, token, conn_props.get(ConnPropKey.FQCN)) self.received_msgs = queue.Queue() # contains Message(s), not CellMessage(s)! 
self.channel = None # the cellnet message channel self.pipe_lock = threading.Lock() # used to ensure no msg to be sent after closed @@ -366,16 +413,14 @@ def close(self): def export(self, export_mode: str) -> Tuple[str, dict]: if export_mode == ExportMode.SELF: mode = self.mode - root_url = self.root_url else: mode = Mode.ACTIVE if self.mode == Mode.PASSIVE else Mode.PASSIVE - root_url = self.cell.get_root_url_for_child() export_args = { "mode": mode, "site_name": self.site_name, "token": self.token, - "root_url": root_url, + "root_url": self.root_url, "secure_mode": self.cell.core_cell.secure, "workspace_dir": self.workspace_dir, } diff --git a/nvflare/fuel/utils/url_utils.py b/nvflare/fuel/utils/url_utils.py new file mode 100644 index 0000000000..e80b7fbfa9 --- /dev/null +++ b/nvflare/fuel/utils/url_utils.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +_SECURE_SCHEME_MAPPING = {"tcp": "stcp", "grpc": "grpcs", "http": "https"} +_CLEAR_SCHEME_MAPPING = {"stcp": "tcp", "grpcs": "grpc", "https": "http"} + + +def make_url(scheme: str, address, secure: bool) -> str: + """Make a full URL based on specified info + + Args: + scheme: scheme of the url + address: host address. Multiple formats are supported: + str: this is a string that contains host name and optionally port number (e.g. 
localhost:1234) + dict: contains item "host" and optionally "port" + tuple or list: contains 1 or 2 items for host and port + secure: whether secure connection is required + + Returns: + the full URL string in the form scheme://host[:port], with the scheme mapped to its secure or clear variant + """ + if secure: + if scheme in _SECURE_SCHEME_MAPPING.values(): + # already secure scheme + secure_scheme = scheme + else: + secure_scheme = _SECURE_SCHEME_MAPPING.get(scheme) + + if not secure_scheme: + raise ValueError(f"unsupported scheme '{scheme}'") + + scheme = secure_scheme + else: + if scheme in _CLEAR_SCHEME_MAPPING.values(): + # already clear scheme + clear_scheme = scheme + else: + clear_scheme = _CLEAR_SCHEME_MAPPING.get(scheme) + + if not clear_scheme: + raise ValueError(f"unsupported scheme '{scheme}'") + + scheme = clear_scheme + + if isinstance(address, str): + if not address: + raise ValueError("address must not be empty") + return f"{scheme}://{address}" + else: + port = None + if isinstance(address, (tuple, list)): + if len(address) < 1: + raise ValueError("address must not be empty") + if len(address) > 2: + raise ValueError(f"invalid address {address}") + host = address[0] + if len(address) > 1: + port = address[1] + elif isinstance(address, dict): + if len(address) < 1: + raise ValueError("address must not be empty") + if len(address) > 2: + raise ValueError(f"invalid address {address}") + + host = address.get("host") + if not host: + raise ValueError(f"invalid address {address}: missing 'host'") + + port = address.get("port") + if not port and len(address) > 1: + raise ValueError(f"invalid address {address}: missing 'port'") + else: + raise ValueError(f"invalid address: {address}") + + if not isinstance(host, str): + raise ValueError(f"invalid host '{host}': must be str but got {type(host)}") + + if port: + if not isinstance(port, (str, int)): + raise ValueError(f"invalid port '{port}': must be str or int but got {type(port)}") + port_str = f":{port}" + else: + port_str = "" + return f"{scheme}://{host}{port_str}" diff --git
a/nvflare/job_config/script_runner.py b/nvflare/job_config/script_runner.py index d39f65db61..7b3ca4f40b 100644 --- a/nvflare/job_config/script_runner.py +++ b/nvflare/job_config/script_runner.py @@ -14,6 +14,7 @@ from typing import Optional, Type, Union +from nvflare.apis.fl_constant import SystemVarName from nvflare.app_common.abstract.launcher import Launcher from nvflare.app_common.executors.client_api_launcher_executor import ClientAPILauncherExecutor from nvflare.app_common.executors.in_process_client_api_executor import InProcessClientAPIExecutor @@ -22,8 +23,9 @@ from nvflare.app_common.widgets.metric_relay import MetricRelay from nvflare.client.config import ExchangeFormat, TransferType from nvflare.fuel.utils.import_utils import optional_import -from nvflare.fuel.utils.pipe.cell_pipe import CellPipe +from nvflare.fuel.utils.pipe.cell_pipe import CellPipe, Mode from nvflare.fuel.utils.pipe.pipe import Pipe +from nvflare.fuel.utils.validation_utils import check_str from .api import FedJob, validate_object_for_job @@ -35,6 +37,19 @@ class FrameworkType: TENSORFLOW = "tensorflow" +class PipeConnectType: + VIA_ROOT = "via_root" + VIA_CP = "via_cp" + VIA_RELAY = "via_relay" + + +_PIPE_CONNECT_URL = { + PipeConnectType.VIA_CP: "{" + SystemVarName.CP_URL + "}", + PipeConnectType.VIA_RELAY: "{" + SystemVarName.RELAY_URL + "}", + PipeConnectType.VIA_ROOT: "{" + SystemVarName.ROOT_URL + "}", +} + + class BaseScriptRunner: def __init__( self, @@ -52,6 +67,7 @@ def __init__( launcher: Optional[Launcher] = None, metric_relay: Optional[MetricRelay] = None, metric_pipe: Optional[Pipe] = None, + pipe_connect_type: str = None, ): """BaseScriptRunner is used with FedJob API to run or launch a script. @@ -102,6 +118,12 @@ def __init__( metric_pipe (Optional[Pipe], optional): An optional Pipe instance for passing metric data between components. This allows for real-time metric handling during execution. Defaults to `None`. 
+ + pipe_connect_type: how pipe peers are to be connected: + Via Root: peers are both connected to the root of the cellnet + Via Relay: peers are both connected to the relay if a relay is used; otherwise via root. + Via CP: peers are both connected to the CP + If not specified, will be via CP. """ self._script = script self._script_args = script_args @@ -112,6 +134,7 @@ def __init__( self._params_exchange_format = params_exchange_format self._from_nvflare_converter_id = from_nvflare_converter_id self._to_nvflare_converter_id = to_nvflare_converter_id + self._pipe_connect_type = pipe_connect_type if self._framework == FrameworkType.PYTORCH: _, torch_ok = optional_import(module="torch") @@ -151,12 +174,35 @@ def __init__( elif executor is not None: validate_object_for_job("executor", executor, InProcessClientAPIExecutor) + if pipe_connect_type: + check_str("pipe_connect_type", pipe_connect_type) + valid_connect_types = [PipeConnectType.VIA_CP, PipeConnectType.VIA_RELAY, PipeConnectType.VIA_ROOT] + if pipe_connect_type not in valid_connect_types: + raise ValueError(f"invalid pipe_connect_type '{pipe_connect_type}': must be {valid_connect_types}") + self._metric_pipe = metric_pipe self._metric_relay = metric_relay self._task_pipe = task_pipe self._executor = executor self._launcher = launcher + def _create_cell_pipe(self): + ct = self._pipe_connect_type + if not ct: + ct = PipeConnectType.VIA_CP + conn_url = _PIPE_CONNECT_URL.get(ct) + if not conn_url: + raise RuntimeError(f"cannot determine pipe connect url for {self._pipe_connect_type}") + + return CellPipe( + mode=Mode.PASSIVE, + site_name="{" + SystemVarName.SITE_NAME + "}", + token="{" + SystemVarName.JOB_ID + "}", + root_url=conn_url, + secure_mode="{" + SystemVarName.SECURE_MODE + "}", + workspace_dir="{" + SystemVarName.WORKSPACE + "}", + ) + def add_to_fed_job(self, job: FedJob, ctx, **kwargs): """This method is used by Job API.
@@ -172,18 +218,7 @@ def add_to_fed_job(self, job: FedJob, ctx, **kwargs): comp_ids = {} if self._launch_external_process: - task_pipe = ( - self._task_pipe - if self._task_pipe - else CellPipe( - mode="PASSIVE", - site_name="{SITE_NAME}", - token="{JOB_ID}", - root_url="{ROOT_URL}", - secure_mode="{SECURE_MODE}", - workspace_dir="{WORKSPACE}", - ) - ) + task_pipe = self._task_pipe if self._task_pipe else self._create_cell_pipe() task_pipe_id = job.add_component("pipe", task_pipe, ctx) comp_ids["pipe_id"] = task_pipe_id @@ -211,18 +246,7 @@ def add_to_fed_job(self, job: FedJob, ctx, **kwargs): ) job.add_executor(executor, tasks=tasks, ctx=ctx) - metric_pipe = ( - self._metric_pipe - if self._metric_pipe - else CellPipe( - mode="PASSIVE", - site_name="{SITE_NAME}", - token="{JOB_ID}", - root_url="{ROOT_URL}", - secure_mode="{SECURE_MODE}", - workspace_dir="{WORKSPACE}", - ) - ) + metric_pipe = self._metric_pipe if self._metric_pipe else self._create_cell_pipe() metric_pipe_id = job.add_component("metrics_pipe", metric_pipe, ctx) comp_ids["metric_pipe_id"] = metric_pipe_id @@ -295,6 +319,7 @@ def __init__( framework: FrameworkType = FrameworkType.PYTORCH, params_exchange_format: ExchangeFormat = ExchangeFormat.NUMPY, params_transfer_type: str = TransferType.FULL, + pipe_connect_type: str = PipeConnectType.VIA_CP, ): """ScriptRunner is used with FedJob API to run or launch a script. @@ -310,6 +335,7 @@ def __init__( params_exchange_format (str): The format to exchange the parameters. Defaults to ExchangeFormat.NUMPY. params_transfer_type (str): How to transfer the parameters. FULL means the whole model parameters are sent. DIFF means that only the difference is sent. Defaults to TransferType.FULL. 
+ pipe_connect_type (str): how pipe peers are to be connected """ super().__init__( script=script, @@ -319,4 +345,5 @@ def __init__( framework=framework, params_exchange_format=params_exchange_format, params_transfer_type=params_transfer_type, + pipe_connect_type=pipe_connect_type, ) diff --git a/nvflare/lighter/constants.py b/nvflare/lighter/constants.py index 99d765f770..76131ff14c 100644 --- a/nvflare/lighter/constants.py +++ b/nvflare/lighter/constants.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvflare.apis.fl_constant import ConnectionSecurity class WorkDir: @@ -66,10 +67,9 @@ class ProvisionMode: class ConnSecurity: - CLEAR = "clear" - INSECURE = "insecure" - TLS = "tls" - MTLS = "mtls" + CLEAR = ConnectionSecurity.CLEAR + TLS = ConnectionSecurity.TLS + MTLS = ConnectionSecurity.MTLS class AdminRole: diff --git a/nvflare/lighter/impl/static_file.py b/nvflare/lighter/impl/static_file.py index c80208ba6a..7c2f00dce8 100644 --- a/nvflare/lighter/impl/static_file.py +++ b/nvflare/lighter/impl/static_file.py @@ -115,7 +115,7 @@ def _build_overseer(self, overseer: Participant, ctx: ProvisionContext): @staticmethod def _build_conn_properties(site: Participant, ctx: ProvisionContext, site_config: dict): - valid_values = [ConnSecurity.CLEAR, ConnSecurity.INSECURE, ConnSecurity.TLS, ConnSecurity.MTLS] + valid_values = [ConnSecurity.CLEAR, ConnSecurity.TLS, ConnSecurity.MTLS] conn_security = site.get_prop_fb(PropKey.CONN_SECURITY) if conn_security: assert isinstance(conn_security, str) @@ -124,8 +124,6 @@ def _build_conn_properties(site: Participant, ctx: ProvisionContext, site_config if conn_security not in valid_values: raise ValueError(f"invalid connection_security '{conn_security}': must be in {valid_values}") - if conn_security in [ConnSecurity.CLEAR, ConnSecurity.INSECURE]: - conn_security = 
ConnSecurity.INSECURE site_config["connection_security"] = conn_security custom_ca_cert = site.get_prop_fb(PropKey.CUSTOM_CA_CERT) diff --git a/nvflare/private/defs.py b/nvflare/private/defs.py index 0a452e262a..3a2af90122 100644 --- a/nvflare/private/defs.py +++ b/nvflare/private/defs.py @@ -143,11 +143,12 @@ class AppFolderConstants: class CellMessageHeaderKeys: CLIENT_NAME = CellMessageAuthHeaderKey.CLIENT_NAME + CLIENT_TYPE = "client_type" TOKEN = CellMessageAuthHeaderKey.TOKEN TOKEN_SIGNATURE = CellMessageAuthHeaderKey.TOKEN_SIGNATURE CLIENT_IP = "client_ip" PROJECT_NAME = "project_name" - SSID = "ssid" + SSID = CellMessageAuthHeaderKey.SSID UNAUTHENTICATED = "unauthenticated" JOB_ID = "job_id" JOB_IDS = "job_ids" @@ -155,6 +156,11 @@ class CellMessageHeaderKeys: ABORT_JOBS = "abort_jobs" +class ClientType: + RELAY = "relay" + REGULAR = "regular" + + AUTH_CLIENT_NAME_FOR_SJ = "server_job" diff --git a/nvflare/private/fed/app/client/worker_process.py b/nvflare/private/fed/app/client/worker_process.py index b43133abe5..4aa8bea5a6 100644 --- a/nvflare/private/fed/app/client/worker_process.py +++ b/nvflare/private/fed/app/client/worker_process.py @@ -154,6 +154,14 @@ def parse_arguments(): parser.add_argument("--sp_target", "-g", type=str, help="Sp target", required=True) parser.add_argument("--sp_scheme", "-scheme", type=str, help="Sp connection scheme", required=True) parser.add_argument("--parent_url", "-p", type=str, help="parent_url", required=True) + parser.add_argument( + "--parent_conn_sec", + "-pcs", + type=str, + help="parent conn security", + required=False, + default="", + ) parser.add_argument( "--fed_client", "-s", type=str, help="an aggregation server specification json file", required=True ) diff --git a/nvflare/private/fed/app/deployer/base_client_deployer.py b/nvflare/private/fed/app/deployer/base_client_deployer.py index b4c684c093..353d63ce20 100644 --- a/nvflare/private/fed/app/deployer/base_client_deployer.py +++ 
b/nvflare/private/fed/app/deployer/base_client_deployer.py @@ -45,8 +45,9 @@ def build(self, build_ctx): self.components = build_ctx["client_components"] self.handlers = build_ctx["client_handlers"] - def set_model_manager(self, model_manager): - self.model_manager = model_manager + relay_config = build_ctx.get("relay_config") + if relay_config: + self.client_config["relay_config"] = relay_config def create_fed_client(self, args, sp_target=None): if sp_target: diff --git a/nvflare/private/fed/app/fl_conf.py b/nvflare/private/fed/app/fl_conf.py index 7a60876637..d87a2d9b48 100644 --- a/nvflare/private/fed/app/fl_conf.py +++ b/nvflare/private/fed/app/fl_conf.py @@ -19,11 +19,14 @@ import sys from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import FilterKey, SiteType, SystemConfigs +from nvflare.apis.fl_constant import ConnectionSecurity, ConnPropKey, FilterKey, SiteType, SystemConfigs from nvflare.apis.workspace import Workspace +from nvflare.fuel.data_event.utils import set_scope_property +from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.json_scanner import Node +from nvflare.fuel.utils.url_utils import make_url from nvflare.fuel.utils.wfconf import ConfigContext, ConfigError from nvflare.private.defs import SSLConstants from nvflare.private.json_configer import JsonConfigurator @@ -225,6 +228,8 @@ def __init__(self, workspace: Workspace, args, kv_list=None): config_files = workspace.get_config_files_for_startup(is_server=False, for_job=True if args.job_id else False) + print(f"got all config files: {config_files}") + JsonConfigurator.__init__( self, config_file_name=config_files, @@ -283,6 +288,72 @@ def build_component(self, config_dict): self.handlers.append(t) return t + def _determine_conn_props(self, client_name, config_data: dict): + relay_fqcn = None + relay_url = None + relay_conn_security 
= None + + # relay info is set in the client's relay__resources.json. + # If relay is used, then connect via the specified relay; if not, try to connect the Server directly + print(f"Config data: {config_data=}") + print(f"Args: {self.args=}") + relay_config = config_data.get(ConnPropKey.RELAY_CONFIG) + self.logger.info(f"got relay config: {relay_config}") + if relay_config: + if relay_config: + relay_fqcn = relay_config.get(ConnPropKey.FQCN) + scheme = relay_config.get(ConnPropKey.SCHEME) + addr = relay_config.get(ConnPropKey.ADDRESS) + relay_conn_security = relay_config.get(ConnPropKey.CONNECTION_SECURITY) + secure = True + if relay_conn_security == ConnectionSecurity.CLEAR: + secure = False + relay_url = make_url(scheme, addr, secure) + print(f"connect to server via relay: {relay_url=} {relay_fqcn=}") + else: + print("no relay defined: connect to server directly") + else: + print("no relay_config: connect to server directly") + + if relay_fqcn: + cp_fqcn = FQCN.join([relay_fqcn, client_name]) + else: + cp_fqcn = client_name + + if relay_fqcn: + relay_conn_props = { + ConnPropKey.FQCN: relay_fqcn, + ConnPropKey.URL: relay_url, + ConnPropKey.CONNECTION_SECURITY: relay_conn_security, + } + set_scope_property(client_name, ConnPropKey.RELAY_CONN_PROPS, relay_conn_props) + + client = self.config_data["client"] + + if hasattr(self.args, "job_id") and self.args.job_id: + # this is CJ + sp_scheme = self.args.sp_scheme + sp_target = self.args.sp_target + root_url = f"{sp_scheme}://{sp_target}" + root_conn_props = { + ConnPropKey.FQCN: FQCN.ROOT_SERVER, + ConnPropKey.URL: root_url, + ConnPropKey.CONNECTION_SECURITY: client.get(ConnPropKey.CONNECTION_SECURITY), + } + set_scope_property(client_name, ConnPropKey.ROOT_CONN_PROPS, root_conn_props) + + cp_conn_props = { + ConnPropKey.FQCN: cp_fqcn, + ConnPropKey.URL: self.args.parent_url, + ConnPropKey.CONNECTION_SECURITY: self.args.parent_conn_sec, + } + else: + # this is CP + cp_conn_props = { + ConnPropKey.FQCN: cp_fqcn, + } 
+ set_scope_property(client_name, ConnPropKey.CP_CONN_PROPS, cp_conn_props) + def start_config(self, config_ctx: ConfigContext): """Start the config process. @@ -301,6 +372,17 @@ def start_config(self, config_ctx: ConfigContext): client[SSLConstants.CERT] = self.workspace.get_file_path_in_startup(client[SSLConstants.CERT]) if client.get(SSLConstants.ROOT_CERT): client[SSLConstants.ROOT_CERT] = self.workspace.get_file_path_in_startup(client[SSLConstants.ROOT_CERT]) + + client_name = self.cmd_vars.get("uid", None) + if not client_name: + raise ConfigError("missing 'uid' from command args") + + conn_sec = client.get(ConnPropKey.CONNECTION_SECURITY) + if conn_sec: + set_scope_property(client_name, ConnPropKey.CONNECTION_SECURITY, conn_sec) + + self._determine_conn_props(client_name, self.config_data) + except Exception: raise ValueError(f"Client config error: '{self.client_config_file_names}'") diff --git a/nvflare/private/fed/app/relay/relay.py b/nvflare/private/fed/app/relay/relay.py new file mode 100644 index 0000000000..1fc2d4a0da --- /dev/null +++ b/nvflare/private/fed/app/relay/relay.py @@ -0,0 +1,224 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import argparse
import json
import logging
import os
import sys
import threading

from nvflare.apis.fl_constant import ConnectionSecurity, ConnPropKey, ReservedKey, WorkspaceConstants
from nvflare.apis.fl_context import FLContext
from nvflare.apis.signal import Signal
from nvflare.apis.utils.decomposers import flare_decomposers
from nvflare.apis.workspace import Workspace
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.f3.cellnet.net_agent import NetAgent
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.drivers.net_utils import SSL_ROOT_CERT, enhance_credential_info
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.sec.authn import set_add_auth_headers_filters
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.config_service import ConfigService, search_file
from nvflare.fuel.utils.log_utils import configure_logging
from nvflare.fuel.utils.url_utils import make_url
from nvflare.private.defs import ClientType
from nvflare.private.fed.authenticator import Authenticator, validate_auth_headers
from nvflare.private.fed.utils.identity_utils import TokenVerifier


class CellnetMonitor:
    """Callback holder used to react when the relay's cellnet is closed.

    ``cellnet_stopped`` is registered with ``NetAgent`` as ``agent_closed_cb``.
    """

    def __init__(self, stop_event: threading.Event, workspace: str):
        """Remember the event to set and the workspace to mark on shutdown.

        Args:
            stop_event: event that the main thread waits on
            workspace: workspace folder in which the shutdown file is touched
        """
        self.stop_event = stop_event
        self.workspace = workspace

    def cellnet_stopped(self):
        """Touch the workspace shutdown file and wake up the main thread."""
        touch_file = os.path.join(self.workspace, WorkspaceConstants.SHUTDOWN_FILE)
        try:
            with open(touch_file, "a"):
                os.utime(touch_file, None)
        finally:
            # Always release the main thread: if touching the file fails, the relay
            # must still be able to leave stop_event.wait() and stop its cell.
            self.stop_event.set()


def parse_arguments():
    """Parse the relay's command line.

    Returns:
        The parsed argparse namespace with ``workspace``, ``relay_config`` and ``set``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)
    parser.add_argument("--relay_config", "-s", type=str, help="relay config json file", required=True)
    parser.add_argument("--set", metavar="KEY=VALUE", nargs="*")
    args = parser.parse_args()
    return args


def _require_config_value(config: dict, key: str, config_file: str, label: str = None):
    """Return config[key], or raise RuntimeError naming the missing (or empty) entry.

    Args:
        config: the dict to look into
        key: the key to fetch
        config_file: config file name, used only in the error message
        label: name to report in the error message; defaults to the key itself
    """
    value = config.get(key)
    if not value:
        raise RuntimeError(f"invalid relay config file {config_file}: missing {label if label else key}")
    return value


def main(args):
    """Run a relay node: load its config, join the cellnet, register with the server, then wait.

    Args:
        args: parsed command line arguments (see parse_arguments)

    Raises:
        RuntimeError: if the relay config is invalid or the token verifier is not usable in secure mode
        ValueError: if the root cert cannot be found in the workspace
    """
    workspace = Workspace(root_dir=args.workspace)

    # remove stale restart/shutdown markers left in the workspace root
    for name in [WorkspaceConstants.RESTART_FILE, WorkspaceConstants.SHUTDOWN_FILE]:
        try:
            stale_file = workspace.get_file_path_in_root(name)
            if os.path.exists(stale_file):
                os.remove(stale_file)
        except Exception as ex:
            print(f"Could not remove file '{name}': {ex}. Please check your system before starting FL.")
            sys.exit(-1)

    configure_logging(workspace, workspace.get_root_dir())
    logger = logging.getLogger()

    relay_config_file = workspace.get_file_path_in_startup(args.relay_config)
    with open(relay_config_file, "rt") as f:
        relay_config = json.load(f)

    if not isinstance(relay_config, dict):
        raise RuntimeError(f"invalid relay config file {args.relay_config}")

    config_name = args.relay_config
    project_name = _require_config_value(relay_config, ConnPropKey.PROJECT_NAME, config_name)
    server_identity = _require_config_value(relay_config, ConnPropKey.SERVER_IDENTITY, config_name)
    my_identity = _require_config_value(relay_config, ConnPropKey.IDENTITY, config_name)

    parent = _require_config_value(relay_config, ConnPropKey.PARENT, config_name)
    if not isinstance(parent, dict):
        # the parent entry is read with .get() below, so it must be a dict
        raise RuntimeError(f"invalid relay config file {config_name}: {ConnPropKey.PARENT} must be a dict")

    parent_address = _require_config_value(parent, ConnPropKey.ADDRESS, config_name, "parent.address")
    parent_scheme = _require_config_value(parent, ConnPropKey.SCHEME, config_name, "parent.scheme")
    parent_fqcn = _require_config_value(parent, ConnPropKey.FQCN, config_name, "parent.fqcn")

    cmd_vars = parse_vars(args.set)
    secure_train = cmd_vars.get("secure_train", False)
    logger.info(f"{cmd_vars=} {secure_train=}")

    stop_event = threading.Event()
    monitor = CellnetMonitor(stop_event, args.workspace)

    ConfigService.initialize(
        section_files={},
        config_path=[args.workspace],
    )

    root_cert_path = search_file(SSL_ROOT_CERT, args.workspace)
    if not root_cert_path:
        raise ValueError(f"cannot find {SSL_ROOT_CERT} from config path {args.workspace}")

    credentials = {
        DriverParams.CA_CERT.value: root_cert_path,
    }
    enhance_credential_info(credentials)

    logger.info(f"{credentials=}")

    # the connection to the parent is secure unless the config explicitly says "clear"
    conn_security = parent.get(ConnPropKey.CONNECTION_SECURITY)
    secure_conn = True
    if conn_security:
        credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security
        if conn_security == ConnectionSecurity.CLEAR:
            secure_conn = False
    parent_url = make_url(parent_scheme, parent_address, secure_conn)

    if parent_fqcn == FQCN.ROOT_SERVER:
        # directly under the server: the parent's URL is the root URL
        my_fqcn = my_identity
        root_url = parent_url
        parent_url = None
    else:
        # under another node: this relay's FQCN is a child of the parent's FQCN
        my_fqcn = FQCN.join([parent_fqcn, my_identity])
        root_url = None

    flare_decomposers.register()

    cell = Cell(
        fqcn=my_fqcn,
        root_url=root_url,
        secure=secure_conn,
        credentials=credentials,
        create_internal_listener=True,
        parent_url=parent_url,
    )
    NetAgent(cell, agent_closed_cb=monitor.cellnet_stopped)
    cell.start()

    # authenticate
    authenticator = Authenticator(
        cell=cell,
        project_name=project_name,
        client_name=my_identity,
        client_type=ClientType.RELAY,
        expected_sp_identity=server_identity,
        secure_mode=secure_train,
        root_cert_file=credentials.get(DriverParams.CA_CERT.value),
        private_key_file=credentials.get(DriverParams.CLIENT_KEY.value),
        cert_file=credentials.get(DriverParams.CLIENT_CERT.value),
        msg_timeout=5.0,
        retry_interval=2.0,
    )

    abort_signal = Signal()
    shared_fl_ctx = FLContext()
    shared_fl_ctx.set_public_props({ReservedKey.IDENTITY_NAME: my_identity})
    token, token_signature, ssid, token_verifier = authenticator.authenticate(
        shared_fl_ctx=shared_fl_ctx,
        abort_signal=abort_signal,
    )

    if secure_train:
        if not isinstance(token_verifier, TokenVerifier):
            raise RuntimeError(f"expect token_verifier to be TokenVerifier but got {type(token_verifier)}")

    # add this relay's own auth headers to all outgoing messages
    set_add_auth_headers_filters(cell, my_identity, token, token_signature, ssid)

    # validate auth headers of all messages going through this relay
    # NOTE(review): when secure_train is false, no token verifier is checked for above and the one
    # passed here may be None - confirm that validate_auth_headers is expected to cope with that.
    cell.core_cell.add_incoming_filter(
        channel="*",
        topic="*",
        cb=_validate_auth_headers,
        token_verifier=token_verifier,
        logger=logger,
    )

    # the auth token is a credential: do not write it to the log
    logger.info(f"Successfully authenticated to {server_identity}: {ssid=}")

    # wait until stopped
    logger.info(f"Started relay {my_identity=} {my_fqcn=} {root_url=} {parent_url=} {parent_fqcn=}")
    stop_event.wait()
    cell.stop()
    logger.info(f"Relay {my_fqcn} stopped.")


def _validate_auth_headers(message: CellMessage, token_verifier: TokenVerifier, logger):
    """Validate auth headers of messages that go through this relay.

    Args:
        message: the message to validate
        token_verifier: the TokenVerifier used to verify the token and signature
        logger: the logger passed on to validate_auth_headers

    Returns:
        Whatever validate_auth_headers returns for the message.
    """
    return validate_auth_headers(message, token_verifier, logger)


if __name__ == "__main__":
    args = parse_arguments()
    rc = mpm.run(main_func=main, run_dir=args.workspace, args=args)
    sys.exit(rc)
token_signature=args.token_signature, + ssid=args.ssid, + ) server.server_state = HotState(host=args.host, port=args.port, ssid=args.ssid) @@ -145,16 +149,6 @@ def main(args): raise e -def _add_auth_headers(message: CellMessage, config): - message.set_header(CellMessageHeaderKeys.SSID, config.ssid) - add_authentication_headers( - message, - client_name=AUTH_CLIENT_NAME_FOR_SJ, - auth_token=config.job_id, - token_signature=config.token_signature, - ) - - def parse_arguments(): """FL Server program starting point.""" parser = argparse.ArgumentParser() diff --git a/nvflare/private/fed/authenticator.py b/nvflare/private/fed/authenticator.py new file mode 100644 index 0000000000..8a905b4963 --- /dev/null +++ b/nvflare/private/fed/authenticator.py @@ -0,0 +1,322 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import socket
import time
import traceback
import uuid

from nvflare.apis.fl_constant import ServerCommandKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_exception import FLCommunicationError
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.core_cell import make_reply as make_cellnet_reply
from nvflare.fuel.f3.cellnet.defs import IdentityChallengeKey, MessageHeaderKey
from nvflare.fuel.f3.cellnet.defs import ReturnCode
from nvflare.fuel.f3.cellnet.defs import ReturnCode as F3ReturnCode
from nvflare.fuel.f3.cellnet.fqcn import FQCN
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.private.defs import CellChannel, CellChannelTopic, CellMessageHeaderKeys, new_cell_message
from nvflare.private.fed.utils.identity_utils import IdentityAsserter, IdentityVerifier, TokenVerifier, load_crt_bytes


def _get_client_ip():
    """Return localhost IP.

    More robust than ``socket.gethostbyname(socket.gethostname())``. See
    https://stackoverflow.com/questions/166506/finding-local-ip-addresses-using-pythons-stdlib/28950776#28950776
    for more details.

    Returns:
        The host IP, or "127.0.0.1" if it cannot be determined

    """
    # the context manager closes the socket on every path
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        try:
            s.connect(("10.255.255.255", 1))  # doesn't even have to be reachable
            return s.getsockname()[0]
        except OSError:
            return "127.0.0.1"


class Authenticator:
    def __init__(
        self,
        cell: Cell,
        project_name: str,
        client_name: str,
        client_type: str,
        expected_sp_identity: str,
        secure_mode: bool,
        root_cert_file: str,
        private_key_file: str,
        cert_file: str,
        msg_timeout: float,
        retry_interval: float,
    ):
        """Authenticator is to be used to register a client to the Server.

        Args:
            cell: the communication cell
            project_name: name of the project
            client_name: name of the client
            client_type: type of the client: regular or relay
            expected_sp_identity: identity of the service provider (i.e. server)
            secure_mode: whether the project is in secure training mode
            root_cert_file: file path of the root cert
            private_key_file: file path of the private key
            cert_file: file path of the client's certificate
            msg_timeout: timeout for authentication messages
            retry_interval: interval between tries
        """
        self.cell = cell
        self.project_name = project_name
        self.client_name = client_name
        self.client_type = client_type
        self.expected_sp_identity = expected_sp_identity
        self.root_cert_file = root_cert_file
        self.private_key_file = private_key_file
        self.cert_file = cert_file
        self.msg_timeout = msg_timeout
        self.retry_interval = retry_interval
        self.secure_mode = secure_mode
        self.logger = get_obj_logger(self)

    def _challenge_server(self):
        """Challenge the server with a fresh nonce and verify the identity it presents.

        Returns:
            A tuple of (server_nonce, TokenVerifier built from the server cert); or (None, None) when
            the server could not be reached and the caller should retry.

        Raises:
            FLCommunicationError: if the challenge is rejected, the reply is malformed, or the
                server's common name is not the expected one.
        """
        # ask server for its info and make sure that it matches expected host
        my_nonce = str(uuid.uuid4())
        headers = {IdentityChallengeKey.COMMON_NAME: self.client_name, IdentityChallengeKey.NONCE: my_nonce}
        challenge = new_cell_message(headers, None)
        result = self.cell.send_request(
            target=FQCN.ROOT_SERVER,
            channel=CellChannel.SERVER_MAIN,
            topic=CellChannelTopic.Challenge,
            request=challenge,
            timeout=self.msg_timeout,
        )
        return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
        error = result.get_header(MessageHeaderKey.ERROR, "")
        self.logger.info(f"challenge result: {return_code} {error}")
        if return_code != ReturnCode.OK:
            if return_code in [ReturnCode.TARGET_UNREACHABLE, ReturnCode.COMM_ERROR]:
                # trigger retry
                return None, None
            raise FLCommunicationError(f"failed to challenge server: {return_code}: {error}")

        reply = result.payload
        if not isinstance(reply, Shareable):
            # explicit check instead of assert: asserts are removed when Python runs with -O
            raise FLCommunicationError(f"expect challenge reply to be Shareable but got {type(reply)}")

        server_nonce = reply.get(IdentityChallengeKey.NONCE)
        cert_bytes = reply.get(IdentityChallengeKey.CERT)
        server_cert = load_crt_bytes(cert_bytes)
        server_signature = reply.get(IdentityChallengeKey.SIGNATURE)
        server_cn = reply.get(IdentityChallengeKey.COMMON_NAME)

        if server_cn != self.expected_sp_identity:
            raise FLCommunicationError(
                f"expected server identity is '{self.expected_sp_identity}' but got '{server_cn}'"
            )

        # Use IdentityVerifier to validate:
        # - the server cert can be validated with the root cert. Note that all sites have the same root cert!
        # - the asserted CN matches the CN on the server cert
        # - signature received from the server is valid
        id_verifier = IdentityVerifier(root_cert_file=self.root_cert_file)
        id_verifier.verify_common_name(
            asserter_cert=server_cert, asserted_cn=server_cn, nonce=my_nonce, signature=server_signature
        )

        self.logger.info(f"verified server identity '{self.expected_sp_identity}'")
        return server_nonce, TokenVerifier(server_cert)

    def authenticate(self, shared_fl_ctx: FLContext, abort_signal: Signal):
        """Register the client with the FLARE Server.

        Note that the client no longer needs to be directly connected with the Server!

        Since the client may be connected with the Server indirectly (e.g. via bridge nodes or proxy), in the secure
        mode, the client authentication cannot be based on the connection's TLS cert. Instead, the server and the
        client will explicitly authenticate each other using their provisioned PKI credentials, as follows:

        1. Make sure that the Server is authentic. The client sends a Challenge request with a random nonce.
        The server is expected to return the following in its reply:
            - its cert and common name (Server_CN)
            - signature on the received client nonce + Server_CN
            - a random Server Nonce. This will be used for the server to validate the client's identity in the
            Registration request.

        The client then validates to make sure:
            - the Server_CN is the same as presented in the server cert
            - the Server_CN is the same as configured in the client's config (fed_client.json)
            - the signature is valid

        2. Client sends Registration request that contains:
            - client cert and common name (Client_CN)
            - signature on the received Server Nonce + Client_CN

        The Server then validates to make sure:
            - the Client_CN is the same as presented in the client cert
            - the signature is valid

        NOTE: we do not explicitly validate certs' expiration time. This is because currently the same certs are
        also used for SSL connections, which already validate expiration.

        Args:
            shared_fl_ctx: the FLContext to be sent to the server with the registration request
            abort_signal: signal checked between tries; when triggered, retrying stops

        Returns:
            A tuple of (token, token_signature, ssid, token_verifier). The token_verifier is None when
            not in secure mode. If the abort signal is triggered while challenging the server, all four
            are None.

        Raises:
            FLCommunicationError: if the server cannot be challenged or the registration fails
            RuntimeError: if the received token and signature cannot be verified

        """
        local_ip = _get_client_ip()
        shareable = Shareable()
        shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx)

        token_verifier = None
        if self.secure_mode:
            # explicitly authenticate with the Server
            while True:
                server_nonce, token_verifier = self._challenge_server()

                if abort_signal.triggered:
                    return None, None, None, None

                if server_nonce is not None:
                    break

                # retry
                self.logger.info(f"re-challenge after {self.retry_interval} seconds")
                time.sleep(self.retry_interval)

            id_asserter = IdentityAsserter(private_key_file=self.private_key_file, cert_file=self.cert_file)
            cn_signature = id_asserter.sign_common_name(nonce=server_nonce)
            shareable[IdentityChallengeKey.CERT] = id_asserter.cert_data
            shareable[IdentityChallengeKey.SIGNATURE] = cn_signature
            shareable[IdentityChallengeKey.COMMON_NAME] = id_asserter.cn
            self.logger.info(f"sent identity info for client {self.client_name}")

        headers = {
            CellMessageHeaderKeys.CLIENT_NAME: self.client_name,
            CellMessageHeaderKeys.CLIENT_TYPE: self.client_type,
            CellMessageHeaderKeys.CLIENT_IP: local_ip,
            CellMessageHeaderKeys.PROJECT_NAME: self.project_name,
        }
        login_message = new_cell_message(headers, shareable)

        self.logger.info("Trying to register with server ...")
        while True:
            try:
                result = self.cell.send_request(
                    target=FQCN.ROOT_SERVER,
                    channel=CellChannel.SERVER_MAIN,
                    topic=CellChannelTopic.Register,
                    request=login_message,
                    timeout=self.msg_timeout,
                )

                if not isinstance(result, Message):
                    raise FLCommunicationError(f"expect result to be Message but got {type(result)}")

                return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
                self.logger.info(f"register RC: {return_code}")
                if return_code == ReturnCode.UNAUTHENTICATED:
                    # the error header may be absent: never concatenate None into the message
                    reason = result.get_header(MessageHeaderKey.ERROR) or ""
                    self.logger.error(f"registration rejected: {reason}")
                    raise FLCommunicationError(f"error:client_registration {reason}")

                payload = result.payload
                if not isinstance(payload, dict):
                    raise FLCommunicationError(f"expect payload to be dict but got {type(payload)}")

                token = payload.get(CellMessageHeaderKeys.TOKEN)
                token_signature = payload.get(CellMessageHeaderKeys.TOKEN_SIGNATURE, "NA")
                ssid = payload.get(CellMessageHeaderKeys.SSID)

                # stop trying once a token is received, or when aborted (possibly without a token)
                if token or abort_signal.triggered:
                    break
                time.sleep(self.retry_interval)

            except Exception as ex:
                traceback.print_exc()
                raise FLCommunicationError("error:client_registration", ex) from ex

        # make sure token_verifier works
        if token_verifier:
            if not isinstance(token_verifier, TokenVerifier):
                raise RuntimeError(f"expect token_verifier to be TokenVerifier but got {type(token_verifier)}")

            if token_signature:
                valid = token_verifier.verify(client_name=self.client_name, token=token, signature=token_signature)
                if not valid:
                    raise RuntimeError("invalid token or verifier!")
                self.logger.info("Verified received token and signature successfully")

        return token, token_signature, ssid, token_verifier


def validate_auth_headers(message: CellMessage, token_verifier: TokenVerifier, logger):
    """Validate auth headers from messages that go through the server.

    Register and Challenge messages on the SERVER_MAIN channel are not checked, since the sender is not
    registered yet.

    Args:
        message: the message to validate
        token_verifier: the TokenVerifier to be used to verify the token and signature
        logger: the logger used to report the validation outcome

    Returns:
        None if the message is exempt or its auth headers are valid; otherwise a reply with return code
        UNAUTHENTICATED and the reason.
    """
    headers = message.headers
    logger.debug(f"**** _validate_auth_headers: {headers=}")
    topic = message.get_header(MessageHeaderKey.TOPIC)
    channel = message.get_header(MessageHeaderKey.CHANNEL)

    origin = message.get_header(MessageHeaderKey.ORIGIN)

    if topic in [CellChannelTopic.Register, CellChannelTopic.Challenge] and channel == CellChannel.SERVER_MAIN:
        # skip: client not registered yet
        logger.debug(f"skip special message {topic=} {channel=}")
        return None

    err_text = f"unauthenticated msg ({channel=} {topic=}) received from {origin}"

    def _reject(err: str):
        # log the reason and build the UNAUTHENTICATED reply
        logger.error(f"{err_text}: {err}")
        return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err)

    client_name = message.get_header(CellMessageHeaderKeys.CLIENT_NAME)
    if not client_name:
        return _reject("missing client name")

    token = message.get_header(CellMessageHeaderKeys.TOKEN)
    if not token:
        return _reject("missing auth token")

    signature = message.get_header(CellMessageHeaderKeys.TOKEN_SIGNATURE)
    if not signature:
        return _reject("missing auth token signature")

    if not token_verifier.verify(client_name, token, signature):
        return _reject("invalid auth token signature")

    # all good
    logger.debug(f"auth headers valid from {origin}: {topic=} {channel=}")
    return None
abstractmethod from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, RunProcessKey, SystemConfigs +from nvflare.apis.fl_constant import AdminCommandNames, ConnPropKey, FLContextKey, RunProcessKey, SystemConfigs from nvflare.apis.fl_context import FLContext from nvflare.apis.job_launcher_spec import JobLauncherSpec, JobProcessArgs from nvflare.apis.resource_manager_spec import ResourceManagerSpec @@ -201,6 +201,13 @@ def start_app( JobProcessArgs.STARTUP_CONFIG_FILE: ("-s", "fed_client.json"), JobProcessArgs.OPTIONS: ("--set", command_options), } + + params = client.cell.get_internal_listener_params() + if params: + parent_conn_sec = params.get(ConnPropKey.CONNECTION_SECURITY) + if parent_conn_sec: + job_args[JobProcessArgs.PARENT_CONN_SEC] = ("-pcs", parent_conn_sec) + fl_ctx.set_prop(key=FLContextKey.JOB_PROCESS_ARGS, value=job_args, private=True, sticky=False) job_handle = job_launcher.launch_job(job_meta, fl_ctx) self.logger.info(f"Launch job_id: {job_id} with job launcher: {type(job_launcher)} ") diff --git a/nvflare/private/fed/client/client_json_config.py b/nvflare/private/fed/client/client_json_config.py index acb66a3b05..2bc028ee50 100644 --- a/nvflare/private/fed/client/client_json_config.py +++ b/nvflare/private/fed/client/client_json_config.py @@ -16,8 +16,9 @@ from nvflare.apis.executor import Executor from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import SystemConfigs, SystemVarName +from nvflare.apis.fl_constant import ConnPropKey, SystemConfigs, SystemVarName from nvflare.apis.workspace import Workspace +from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.json_scanner import Node @@ -64,11 +65,27 @@ def __init__( sp_target = args.sp_target sp_url = f"{sp_scheme}://{sp_target}" + # determine relay URL + # if 
relay is not used, use the root URL as relay URL + relay_conn_props = get_scope_property(args.client_name, ConnPropKey.RELAY_CONN_PROPS) + relay_url = None + if relay_conn_props: + relay_url = relay_conn_props.get(ConnPropKey.URL) + if not relay_url: + relay_url = sp_url + + if hasattr(args, "parent_url") and args.parent_url: + parent_url = args.parent_url + else: + parent_url = sp_url + sys_vars = { SystemVarName.JOB_ID: args.job_id, SystemVarName.SITE_NAME: args.client_name, SystemVarName.WORKSPACE: args.workspace, SystemVarName.ROOT_URL: sp_url, + SystemVarName.CP_URL: parent_url, + SystemVarName.RELAY_URL: relay_url, SystemVarName.SECURE_MODE: self.cmd_vars.get("secure_train", True), SystemVarName.JOB_CUSTOM_DIR: workspace_obj.get_app_custom_dir(args.job_id), } diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index 50d86cf30e..14a1b6fd67 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import socket import time -import traceback -import uuid from typing import List, Optional from nvflare.apis.event_type import EventType @@ -26,43 +23,28 @@ from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import FLCommunicationError from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx from nvflare.fuel.data_event.utils import get_scope_property, set_scope_property from nvflare.fuel.f3.cellnet.cell import Cell -from nvflare.fuel.f3.cellnet.defs import IdentityChallengeKey, MessageHeaderKey, ReturnCode +from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.utils import format_size -from nvflare.fuel.f3.message import Message as CellMessage -from nvflare.fuel.sec.authn import add_authentication_headers +from nvflare.fuel.sec.authn import set_add_auth_headers_filters from nvflare.fuel.utils.log_utils import get_obj_logger -from nvflare.private.defs import CellChannel, CellChannelTopic, CellMessageHeaderKeys, SpecialTaskName, new_cell_message +from nvflare.private.defs import ( + CellChannel, + CellChannelTopic, + CellMessageHeaderKeys, + ClientType, + SpecialTaskName, + new_cell_message, +) +from nvflare.private.fed.authenticator import Authenticator from nvflare.private.fed.client.client_engine_internal_spec import ClientEngineInternalSpec -from nvflare.private.fed.utils.identity_utils import IdentityAsserter, IdentityVerifier, load_crt_bytes from nvflare.security.logging import secure_format_exception -def _get_client_ip(): - """Return localhost IP. - - More robust than ``socket.gethostbyname(socket.gethostname())``. See - https://stackoverflow.com/questions/166506/finding-local-ip-addresses-using-pythons-stdlib/28950776#28950776 - for more details. 
- - Returns: - The host IP - - """ - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - s.connect(("10.255.255.255", 1)) # doesn't even have to be reachable - ip = s.getsockname()[0] - except Exception: - ip = "127.0.0.1" - finally: - s.close() - return ip - - class Communicator: def __init__( self, @@ -88,7 +70,6 @@ def __init__( self.secure_train = secure_train self.verbose = False - self.should_stop = False self.heartbeat_done = False self.client_state_processors = client_state_processors self.compression = compression @@ -102,78 +83,37 @@ def __init__( self.token_signature = None self.ssid = None self.client_name = None + self.token_verifier = None + self.abort_signal = Signal() self.logger = get_obj_logger(self) + """ + To call set_add_auth_headers_filters, both cell and token must be available. + The set_cell is called when cell becomes available, set_auth is called when token becomes available. + In CP, set_cell happens before set_auth, hence we call set_add_auth_headers_filters in set_auth for CP. + In CJ, set_auth happens before set_cell, hence we call set_add_auth_headers_filters in set_cell for CJ. 
+ """ + def set_auth(self, client_name, token, token_signature, ssid): self.ssid = ssid self.token_signature = token_signature self.token = token self.client_name = client_name - # put auth properties in databus so that they can be used elsewhere + if self.cell: + # for CP + set_add_auth_headers_filters(self.cell, client_name, token, token_signature, ssid) + + # put auth properties in data bus so that they can be used elsewhere set_scope_property(scope_name=client_name, key=FLMetaKey.AUTH_TOKEN, value=token) set_scope_property(scope_name=client_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, value=token_signature) def set_cell(self, cell): self.cell = cell - - # set filter to add additional auth headers - cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=self._add_auth_headers) - cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=self._add_auth_headers) - - def _add_auth_headers(self, message: CellMessage): - if self.ssid: - message.set_header(CellMessageHeaderKeys.SSID, self.ssid) - - # Note that auth info (client_name, token and signature) is not available until the client is fully - # authenticated. 
- add_authentication_headers(message, self.client_name, self.token, self.token_signature) - - def _challenge_server(self, client_name, expected_host, root_cert_file): - # ask server for its info and make sure that it matches expected host - my_nonce = str(uuid.uuid4()) - headers = {IdentityChallengeKey.COMMON_NAME: client_name, IdentityChallengeKey.NONCE: my_nonce} - challenge = new_cell_message(headers, None) - result = self.cell.send_request( - target=FQCN.ROOT_SERVER, - channel=CellChannel.SERVER_MAIN, - topic=CellChannelTopic.Challenge, - request=challenge, - timeout=self.maint_msg_timeout, - ) - return_code = result.get_header(MessageHeaderKey.RETURN_CODE) - error = result.get_header(MessageHeaderKey.ERROR, "") - self.logger.info(f"challenge result: {return_code} {error}") - if return_code != ReturnCode.OK: - if return_code in [ReturnCode.TARGET_UNREACHABLE, ReturnCode.COMM_ERROR]: - # trigger retry - return None - err = result.get_header(MessageHeaderKey.ERROR, "") - raise FLCommunicationError(f"failed to challenge server: {return_code}: {err}") - - reply = result.payload - assert isinstance(reply, Shareable) - server_nonce = reply.get(IdentityChallengeKey.NONCE) - cert_bytes = reply.get(IdentityChallengeKey.CERT) - server_cert = load_crt_bytes(cert_bytes) - server_signature = reply.get(IdentityChallengeKey.SIGNATURE) - server_cn = reply.get(IdentityChallengeKey.COMMON_NAME) - - if server_cn != expected_host: - raise FLCommunicationError(f"expected server identity is '{expected_host}' but got '{server_cn}'") - - # Use IdentityVerifier to validate: - # - the server cert can be validated with the root cert. Note that all sites have the same root cert! 
- # - the asserted CN matches the CN on the server cert - # - signature received from the server is valid - id_verifier = IdentityVerifier(root_cert_file=root_cert_file) - id_verifier.verify_common_name( - asserter_cert=server_cert, asserted_cn=server_cn, nonce=my_nonce, signature=server_signature - ) - - self.logger.info(f"verified server identity '{expected_host}'") - return server_nonce + if self.token: + # for CJ + set_add_auth_headers_filters(self.cell, self.client_name, self.token, self.token_signature, self.ssid) def client_registration(self, client_name, project_name, fl_ctx: FLContext): """Register the client with the FLARE Server. @@ -223,12 +163,14 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext): raise RuntimeError("Client cell could not be created. Failed to login the client.") time.sleep(0.5) - local_ip = _get_client_ip() - shareable = Shareable() shared_fl_ctx = gen_new_peer_ctx(fl_ctx) - shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) + private_key_file = None + root_cert_file = None + cert_file = None secure_mode = fl_ctx.get_prop(FLContextKey.SECURE_MODE, False) + expected_host = None + if secure_mode: # explicitly authenticate with the Server expected_host = None @@ -253,60 +195,24 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext): cert_file = client_config.get(SecureTrainConst.SSL_CERT) root_cert_file = client_config.get(SecureTrainConst.SSL_ROOT_CERT) - while True: - server_nonce = self._challenge_server(client_name, expected_host, root_cert_file) - if server_nonce is None and not self.should_stop: - # retry - self.logger.info(f"re-challenge after {self.client_register_interval} seconds") - time.sleep(self.client_register_interval) - else: - break - - id_asserter = IdentityAsserter(private_key_file=private_key_file, cert_file=cert_file) - cn_signature = id_asserter.sign_common_name(nonce=server_nonce) - shareable[IdentityChallengeKey.CERT] = id_asserter.cert_data - 
shareable[IdentityChallengeKey.SIGNATURE] = cn_signature - shareable[IdentityChallengeKey.COMMON_NAME] = id_asserter.cn - self.logger.info(f"sent identity info for client {client_name}") - - headers = { - CellMessageHeaderKeys.CLIENT_NAME: client_name, - CellMessageHeaderKeys.CLIENT_IP: local_ip, - CellMessageHeaderKeys.PROJECT_NAME: project_name, - } - login_message = new_cell_message(headers, shareable) - - self.logger.info("Trying to register with server ...") - while True: - try: - result = self.cell.send_request( - target=FQCN.ROOT_SERVER, - channel=CellChannel.SERVER_MAIN, - topic=CellChannelTopic.Register, - request=login_message, - timeout=self.maint_msg_timeout, - ) - return_code = result.get_header(MessageHeaderKey.RETURN_CODE) - self.logger.info(f"register RC: {return_code}") - if return_code == ReturnCode.UNAUTHENTICATED: - reason = result.get_header(MessageHeaderKey.ERROR) - self.logger.error(f"registration rejected: {reason}") - raise FLCommunicationError("error:client_registration " + reason) - - token = result.get_header(CellMessageHeaderKeys.TOKEN) - token_signature = result.get_header(CellMessageHeaderKeys.TOKEN_SIGNATURE, "NA") - ssid = result.get_header(CellMessageHeaderKeys.SSID) - if not token and not self.should_stop: - time.sleep(self.client_register_interval) - else: - self.set_auth(client_name, token, token_signature, ssid) - break - - except Exception as ex: - traceback.print_exc() - raise FLCommunicationError("error:client_registration", ex) - - return token, token_signature, ssid + authenticator = Authenticator( + cell=self.cell, + project_name=project_name, + client_name=client_name, + client_type=ClientType.REGULAR, + expected_sp_identity=expected_host, + secure_mode=secure_mode, + root_cert_file=root_cert_file, + private_key_file=private_key_file, + cert_file=cert_file, + msg_timeout=self.maint_msg_timeout, + retry_interval=self.client_register_interval, + ) + + token, signature, ssid, token_verifier = 
authenticator.authenticate(shared_fl_ctx, self.abort_signal) + self.token_verifier = token_verifier + self.set_auth(client_name, token, signature, ssid) + return token, signature, ssid def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None): """Get a task from server. @@ -326,7 +232,6 @@ def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None): shareable = Shareable() shared_fl_ctx = gen_new_peer_ctx(fl_ctx) shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) - client_name = fl_ctx.get_identity_name() task_message = new_cell_message( { CellMessageHeaderKeys.PROJECT_NAME: project_name, @@ -444,10 +349,10 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): server's reply to the last message """ + self.abort_signal.trigger(True) shared_fl_ctx = gen_new_peer_ctx(fl_ctx) shareable = Shareable() shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) - client_name = fl_ctx.get_identity_name() quit_message = new_cell_message( { CellMessageHeaderKeys.PROJECT_NAME: task_name, diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index 6a40717457..14215bed55 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -18,13 +18,13 @@ from nvflare.apis.filter import Filter from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import FLContextKey, SecureTrainConst, ServerCommandKey +from nvflare.apis.fl_constant import ConnPropKey, FLContextKey, SecureTrainConst, ServerCommandKey from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import FLCommunicationError from nvflare.apis.overseer_spec import SP from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal -from nvflare.fuel.data_event.utils import set_scope_property +from nvflare.fuel.data_event.utils import get_scope_property, 
set_scope_property from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.net_agent import NetAgent @@ -188,29 +188,36 @@ def _create_cell(self, location, scheme): """ # Determine the CP's fqcn root_url = scheme + "://" + location + root_conn_security = self.client_args.get(ConnPropKey.CONNECTION_SECURITY) - # bridge_fqcn and bridge_url are set in the client's local/resources.json. - # If they are set, then connect via the specified bridge; if not, try to connect the Server directly - bridge_fqcn = self.client_args.get("bridge_fqcn") - bridge_url = self.client_args.get("bridge_url") - if bridge_fqcn: - cp_fqcn = FQCN.join([bridge_fqcn, self.client_name]) - root_url = None # do not connect to server if bridge is used - else: - cp_fqcn = self.client_name + relay_conn_props = get_scope_property(self.client_name, ConnPropKey.RELAY_CONN_PROPS, {}) + self.logger.info(f"got {ConnPropKey.RELAY_CONN_PROPS}: {relay_conn_props}") + + relay_fqcn = relay_conn_props.get(ConnPropKey.FQCN) + if relay_fqcn: + root_url = None # do not connect to server if relay is used + cp_conn_props = get_scope_property(self.client_name, ConnPropKey.CP_CONN_PROPS) + cp_fqcn = cp_conn_props.get(ConnPropKey.FQCN) + parent_resources = None if self.args.job_id: # I am CJ me = "CJ" my_fqcn = FQCN.join([cp_fqcn, self.args.job_id]) - parent_url = self.args.parent_url + parent_url = cp_conn_props.get(ConnPropKey.URL) + parent_conn_sec = cp_conn_props.get(ConnPropKey.CONNECTION_SECURITY) create_internal_listener = False + if parent_conn_sec: + parent_resources = {DriverParams.CONNECTION_SECURITY.value: parent_conn_sec} else: # I am CP me = "CP" my_fqcn = cp_fqcn - parent_url = bridge_url + parent_url = relay_conn_props.get(ConnPropKey.URL) create_internal_listener = True + relay_conn_security = relay_conn_props.get(ConnPropKey.CONNECTION_SECURITY) + if relay_conn_security: + parent_resources = {DriverParams.CONNECTION_SECURITY.value: 
relay_conn_security} if self.secure_train: root_cert = self.client_args[SecureTrainConst.SSL_ROOT_CERT] @@ -222,13 +229,13 @@ def _create_cell(self, location, scheme): DriverParams.CLIENT_CERT.value: ssl_cert, DriverParams.CLIENT_KEY.value: private_key, } - conn_security = self.client_args.get(SecureTrainConst.CONNECTION_SECURITY) - if conn_security: - credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security - set_scope_property(self.client_name, SecureTrainConst.CONNECTION_SECURITY, conn_security) else: credentials = {} + if root_conn_security: + # this is the default conn sec + credentials[DriverParams.CONNECTION_SECURITY.value] = root_conn_security + self.logger.info(f"{me=}: {my_fqcn=} {root_url=} {parent_url=}") self.cell = Cell( fqcn=my_fqcn, @@ -237,6 +244,7 @@ def _create_cell(self, location, scheme): credentials=credentials, create_internal_listener=create_internal_listener, parent_url=parent_url, + parent_resources=parent_resources, ) self.cell.start() self.communicator.set_cell(self.cell) diff --git a/nvflare/private/fed/server/client_manager.py b/nvflare/private/fed/server/client_manager.py index 1e1c58ad11..2dade0ad57 100644 --- a/nvflare/private/fed/server/client_manager.py +++ b/nvflare/private/fed/server/client_manager.py @@ -23,7 +23,7 @@ from nvflare.apis.shareable import Shareable from nvflare.fuel.f3.cellnet.defs import IdentityChallengeKey, MessageHeaderKey from nvflare.fuel.utils.log_utils import get_obj_logger -from nvflare.private.defs import CellMessageHeaderKeys, ClientRegSession, InternalFLContextKey +from nvflare.private.defs import CellMessageHeaderKeys, ClientRegSession, ClientType, InternalFLContextKey from nvflare.private.fed.utils.identity_utils import IdentityVerifier, load_crt_bytes from nvflare.security.logging import secure_format_exception @@ -57,10 +57,17 @@ def authenticate(self, request, fl_ctx: FLContext) -> Optional[Client]: # new client join with self.lock: - self.clients.update({client.token: client}) + 
client_type = request.get_header(CellMessageHeaderKeys.CLIENT_TYPE) + if client_type == ClientType.REGULAR: + self.clients.update({client.token: client}) + client_kind = "client" + else: + # do not update self.clients for non-regular clients + client_kind = client_type + self.logger.info( - "Client: New client {} joined. Sent token: {}. Total clients: {}".format( - client.name + "@" + client_ip, client.token, len(self.clients) + "Client: New {} {} joined. Sent token: {}. Total clients: {}".format( + client_kind, client.name + "@" + client_ip, client.token, len(self.clients) ) ) return client diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index 354df3fda3..f1008e5951 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -25,6 +25,7 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import ( ConfigVarName, + ConnPropKey, FLContextKey, MachineStatus, RunProcessKey, @@ -63,13 +64,15 @@ CellChannelTopic, CellMessageHeaderKeys, ClientRegSession, + ClientType, InternalFLContextKey, JobFailureMsgKey, new_cell_message, ) +from nvflare.private.fed.authenticator import validate_auth_headers from nvflare.private.fed.server.server_command_agent import ServerCommandAgent from nvflare.private.fed.server.server_runner import ServerRunner -from nvflare.private.fed.utils.identity_utils import IdentityAsserter +from nvflare.private.fed.utils.identity_utils import IdentityAsserter, TokenVerifier from nvflare.security.logging import secure_format_exception from nvflare.widgets.fed_event import ServerFedEventRunner @@ -169,7 +172,7 @@ def deploy(self, args, grpc_args=None, secure_train=False): DriverParams.SERVER_KEY.value: private_key, } - conn_security = grpc_args.get(SecureTrainConst.CONNECTION_SECURITY) + conn_security = grpc_args.get(ConnPropKey.CONNECTION_SECURITY) if conn_security: credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security 
else: @@ -401,12 +404,15 @@ def _add_auth_headers(self, message: Message): """ origin = message.get_header(MessageHeaderKey.ORIGIN) dest = message.get_header(MessageHeaderKey.DESTINATION) - if origin == FQCN.ROOT_SERVER and dest == origin: - if not self.my_own_token_signature: - self.my_own_token_signature = self.sign_auth_token(self.my_own_auth_client_name, self.my_own_token) - add_authentication_headers( - message, self.my_own_auth_client_name, self.my_own_token, self.my_own_token_signature - ) + channel = message.get_header(MessageHeaderKey.CHANNEL) + topic = message.get_header(MessageHeaderKey.TOPIC) + if not self.my_own_token_signature: + self.my_own_token_signature = self.sign_auth_token(self.my_own_auth_client_name, self.my_own_token) + + add_authentication_headers( + message, self.my_own_auth_client_name, self.my_own_token, self.my_own_token_signature + ) + self.logger.debug(f"added auth headers: {origin=} {dest=} {channel=} {topic=}") def _validate_auth_headers(self, message: Message): """Validate auth headers from messages that go through the server. 
@@ -414,45 +420,17 @@ def _validate_auth_headers(self, message: Message): message: the message to validate Returns: """ - headers = message.headers - self.logger.debug(f"**** _validate_auth_headers: {headers=}") - topic = message.get_header(MessageHeaderKey.TOPIC) - channel = message.get_header(MessageHeaderKey.CHANNEL) - - origin = message.get_header(MessageHeaderKey.ORIGIN) - - if topic in [CellChannelTopic.Register, CellChannelTopic.Challenge] and channel == CellChannel.SERVER_MAIN: - # skip: client not registered yet - self.logger.debug(f"skip special message {topic=} {channel=}") + id_asserter = self._get_id_asserter() + if not id_asserter: return None - client_name = message.get_header(CellMessageHeaderKeys.CLIENT_NAME) - err_text = f"unauthenticated msg ({channel=} {topic=}) received from {origin}" - if not client_name: - err = "missing client name" - self.logger.error(f"{err_text}: {err}") - return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) - - token = message.get_header(CellMessageHeaderKeys.TOKEN) - if not token: - err = "missing auth token" - self.logger.error(f"{err_text}: {err}") - return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) - - signature = message.get_header(CellMessageHeaderKeys.TOKEN_SIGNATURE) - if not signature: - err = "missing auth token signature" - self.logger.error(f"{err_text}: {err}") - return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) - - if not self.verify_auth_token(client_name, token, signature): - err = "invalid auth token signature" - self.logger.error(f"{err_text}: {err}") - return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) - - # all good - self.logger.debug(f"auth valid from {origin}: {topic=} {channel=}") - return None + token_verifier = TokenVerifier(id_asserter.cert) + + return validate_auth_headers( + message=message, + token_verifier=token_verifier, + logger=self.logger, + ) def sign_auth_token(self, client_name: str, token: str): id_asserter 
= self._get_id_asserter() @@ -464,7 +442,9 @@ def verify_auth_token(self, client_name: str, token: str, signature): id_asserter = self._get_id_asserter() if not id_asserter: return True - return id_asserter.verify_signature(client_name + token, signature) + + token_verifier = TokenVerifier(id_asserter.cert) + return token_verifier.verify(client_name, token, signature) def _check_regs(self): while True: @@ -541,7 +521,7 @@ def create_job_cell(self, job_id, root_url, parent_url, secure_train, server_con DriverParams.SERVER_KEY.value: private_key, } - conn_security = server_config.get(SecureTrainConst.CONNECTION_SECURITY) + conn_security = server_config.get(ConnPropKey.CONNECTION_SECURITY) if conn_security: credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security else: @@ -710,20 +690,22 @@ def register_client(self, request: Message) -> Message: client = self.client_manager.authenticate(request, fl_ctx) if client and client.token: - self.tokens[client.token] = self.task_meta_info(client.name) - if self.admin_server: - self.admin_server.client_heartbeat(client.token, client.name, client.get_fqcn()) + client_type = request.get_header(CellMessageHeaderKeys.CLIENT_TYPE) + if client_type == ClientType.REGULAR: + self.tokens[client.token] = self.task_meta_info(client.name) + if self.admin_server: + self.admin_server.client_heartbeat(client.token, client.name, client.get_fqcn()) token_signature = self.sign_auth_token(client.name, client.token) - headers = { + result = { CellMessageHeaderKeys.TOKEN: client.token, CellMessageHeaderKeys.TOKEN_SIGNATURE: token_signature, CellMessageHeaderKeys.SSID: self.server_state.ssid, } else: - headers = {} + result = {} self.engine.fire_event(EventType.CLIENT_REGISTER_PROCESSED, fl_ctx=fl_ctx) - return self._generate_reply(headers=headers, payload=None, fl_ctx=fl_ctx) + return self._generate_reply(headers={}, payload=result, fl_ctx=fl_ctx) except NotAuthenticated as e: self.logger.error(f"Failed to authenticate the 
register_client: {secure_format_exception(e)}") return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error="register_client unauthenticated") diff --git a/nvflare/private/fed/server/server_state.py b/nvflare/private/fed/server/server_state.py index 5b0ba3ca9e..fa2a8cd414 100644 --- a/nvflare/private/fed/server/server_state.py +++ b/nvflare/private/fed/server/server_state.py @@ -91,11 +91,11 @@ def aux_communicate(self, fl_ctx: FLContext) -> dict: def handle_sd_callback(self, sp: SP, fl_ctx: FLContext) -> ServerState: if sp: - self.logger.info( + self.logger.debug( f"handle_sd_callback Got SP: {sp.name=} {sp.fl_port=} {sp.primary=} {self.host=} {self.service_port=}" ) else: - self.logger.info("handle_sd_callback no SP!") + self.logger.debug("handle_sd_callback no SP!") if sp and sp.primary is True: if sp.name == self.host and sp.fl_port == self.service_port: diff --git a/nvflare/private/fed/server/training_cmds.py b/nvflare/private/fed/server/training_cmds.py index 8ad9be7b67..7d31af86e0 100644 --- a/nvflare/private/fed/server/training_cmds.py +++ b/nvflare/private/fed/server/training_cmds.py @@ -18,6 +18,7 @@ from nvflare.apis.client import Client from nvflare.apis.fl_constant import AdminCommandNames, SiteType +from nvflare.fuel.data_event.data_bus import DataBus from nvflare.fuel.hci.conn import Connection from nvflare.fuel.hci.proto import ConfirmMethod, MetaKey, MetaStatusValue, make_meta from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandSpec @@ -160,6 +161,12 @@ def shutdown(self, conn: Connection, args: List[str]): conn.update_meta(make_meta(MetaStatusValue.ERROR, "failed to shut down all clients")) return + if target_type in [self.TARGET_TYPE_ALL]: + # shutdown the cellnet + data_bus = DataBus() + data_bus.publish(["stop_cellnet"], conn) + # time.sleep(2.0) + if target_type in [self.TARGET_TYPE_SERVER, self.TARGET_TYPE_ALL]: # shut down the server err = self._shutdown_app_on_server(conn) diff --git 
a/nvflare/private/fed/simulator/simulator_server.py b/nvflare/private/fed/simulator/simulator_server.py index 881eed6039..13b273fa1b 100644 --- a/nvflare/private/fed/simulator/simulator_server.py +++ b/nvflare/private/fed/simulator/simulator_server.py @@ -79,7 +79,6 @@ def create_job_processing_context_properties(self, workspace, job_id): class SimulatorIdentityAsserter(IdentityAsserter): - def __init__(self, private_key_file: str, cert_file: str): self.private_key_file = private_key_file self.cert_file = cert_file diff --git a/nvflare/private/fed/utils/identity_utils.py b/nvflare/private/fed/utils/identity_utils.py index d8a8a44850..10d948d6a3 100644 --- a/nvflare/private/fed/utils/identity_utils.py +++ b/nvflare/private/fed/utils/identity_utils.py @@ -14,6 +14,7 @@ from cryptography.x509.oid import NameOID +from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.lighter.utils import ( load_crt, load_crt_bytes, @@ -110,3 +111,18 @@ def verify_common_name(self, asserted_cn: str, nonce: str, asserter_cert, signat except Exception as ex: raise InvalidCNSignature(f"cannot verify common name signature: {secure_format_exception(ex)}") return True + + +class TokenVerifier: + def __init__(self, cert): + self.cert = cert + self.public_key = cert.public_key() + self.logger = get_obj_logger(self) + + def verify(self, client_name, token, signature): + try: + verify_content(content=client_name + token, signature=signature, public_key=self.public_key) + return True + except Exception as ex: + self.logger.error(f"exception verifying token: {client_name=} {token=}: {secure_format_exception(ex)}") + return False diff --git a/nvflare/private/json_configer.py b/nvflare/private/json_configer.py index aca538658b..aaecbcbd1b 100644 --- a/nvflare/private/json_configer.py +++ b/nvflare/private/json_configer.py @@ -22,6 +22,7 @@ from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.dict_utils import augment from nvflare.fuel.utils.json_scanner 
import JsonObjectProcessor, JsonScanner, Node +from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.fuel.utils.wfconf import resolve_var_refs from nvflare.security.logging import secure_format_exception @@ -54,6 +55,7 @@ def __init__( sys_vars: system vars """ JsonObjectProcessor.__init__(self) + self.logger = get_obj_logger(self) if not isinstance(num_passes, int): raise TypeError(f"num_passes must be int but got {num_passes}") diff --git a/nvflare/utils/job_launcher_utils.py b/nvflare/utils/job_launcher_utils.py index 6865673869..a3013e25aa 100644 --- a/nvflare/utils/job_launcher_utils.py +++ b/nvflare/utils/job_launcher_utils.py @@ -24,7 +24,10 @@ def _job_args_str(job_args, arg_names) -> str: result = "" sep = "" for name in arg_names: - n, v = job_args[name] + e = job_args.get(name) + if not e: + continue + n, v = e result += f"{sep}{n} {v}" sep = " " return result @@ -45,6 +48,7 @@ def get_client_job_args(include_exe_module=True, include_set_options=True): JobProcessArgs.JOB_ID, JobProcessArgs.CLIENT_NAME, JobProcessArgs.PARENT_URL, + JobProcessArgs.PARENT_CONN_SEC, JobProcessArgs.TARGET, JobProcessArgs.SCHEME, JobProcessArgs.STARTUP_CONFIG_FILE, diff --git a/tests/unit_test/client/in_process/api_test.py b/tests/unit_test/client/in_process/api_test.py index 835bd80b5d..a873d78499 100644 --- a/tests/unit_test/client/in_process/api_test.py +++ b/tests/unit_test/client/in_process/api_test.py @@ -65,8 +65,10 @@ def test_init_with_custom_interval(self): def test_init_subscriptions(self): client_api = InProcessClientAPI(self.task_metadata) xs = list(client_api.data_bus.subscribers.keys()) - xs.sort() - assert xs == [TOPIC_ABORT, TOPIC_GLOBAL_RESULT, TOPIC_STOP] + + # Depending on the timing of this test, the data bus may have other subscribed topics + # since the data bus is a singleton! 
+ assert set(xs).issuperset([TOPIC_ABORT, TOPIC_GLOBAL_RESULT, TOPIC_STOP]) def local_result_callback(self, data, topic): pass diff --git a/tests/unit_test/fuel/f3/comm_config_utils_test.py b/tests/unit_test/fuel/f3/comm_config_utils_test.py new file mode 100644 index 0000000000..b8da01935f --- /dev/null +++ b/tests/unit_test/fuel/f3/comm_config_utils_test.py @@ -0,0 +1,49 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from nvflare.apis.fl_constant import ConnectionSecurity +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection +from nvflare.fuel.f3.drivers.driver_params import DriverParams + +CS = DriverParams.CONNECTION_SECURITY.value +S = DriverParams.SECURE.value +IS = ConnectionSecurity.CLEAR +T = ConnectionSecurity.TLS +M = ConnectionSecurity.MTLS + + +class TestCommConfigUtils: + + @pytest.mark.parametrize( + "resources, expected", + [ + ({}, False), + ({"x": 1, "y": 2}, False), + ({S: True}, True), + ({S: False}, False), + ({CS: IS}, False), + ({CS: T}, True), + ({CS: M}, True), + ({CS: M, S: False}, True), + ({CS: M, S: True}, True), + ({CS: T, S: False}, True), + ({CS: T, S: True}, True), + ({CS: IS, S: False}, False), + ({CS: IS, S: True}, False), + ], + ) + def test_requires_secure_connection(self, resources, expected): + result = requires_secure_connection(resources) + assert result == expected diff --git a/tests/unit_test/fuel/utils/url_utils_test.py 
b/tests/unit_test/fuel/utils/url_utils_test.py new file mode 100644 index 0000000000..a466d4a5c4 --- /dev/null +++ b/tests/unit_test/fuel/utils/url_utils_test.py @@ -0,0 +1,72 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from nvflare.fuel.utils.url_utils import make_url + + +class TestUrlUtils: + + @pytest.mark.parametrize( + "scheme, address, secure, expected", + [ + ("tcp", "xyz.com", False, "tcp://xyz.com"), + ("tcp", "xyz.com:1234", False, "tcp://xyz.com:1234"), + ("tcp", "xyz.com:1234", True, "stcp://xyz.com:1234"), + ("grpc", "xyz.com", False, "grpc://xyz.com"), + ("grpc", "xyz.com:1234", False, "grpc://xyz.com:1234"), + ("grpc", "xyz.com:1234", True, "grpcs://xyz.com:1234"), + ("http", "xyz.com", False, "http://xyz.com"), + ("http", "xyz.com:1234", False, "http://xyz.com:1234"), + ("http", "xyz.com:1234", True, "https://xyz.com:1234"), + ("tcp", ("xyz.com",), False, "tcp://xyz.com"), + ("tcp", ("xyz.com", 1234), False, "tcp://xyz.com:1234"), + ("tcp", ["xyz.com"], False, "tcp://xyz.com"), + ("tcp", ["xyz.com", 1234], False, "tcp://xyz.com:1234"), + ("tcp", {"host": "xyz.com"}, False, "tcp://xyz.com"), + ("tcp", {"host": "xyz.com", "port": 1234}, False, "tcp://xyz.com:1234"), + ("stcp", {"host": "xyz.com"}, False, "tcp://xyz.com"), + ("https", {"host": "xyz.com"}, False, "http://xyz.com"), + ("grpcs", {"host": "xyz.com"}, False, "grpc://xyz.com"), + ("stcp", {"host": 
"xyz.com"}, True, "stcp://xyz.com"), + ("https", {"host": "xyz.com"}, True, "https://xyz.com"), + ("grpcs", {"host": "xyz.com"}, True, "grpcs://xyz.com"), + ], + ) + def test_make_url(self, scheme, address, secure, expected): + result = make_url(scheme, address, secure) + assert result == expected + + @pytest.mark.parametrize( + "scheme, address, secure", + [ + ("tcp", "", False), + ("abc", "xyz.com:1234", False), + ("tcp", 1234, True), + ("grpc", [], False), + ("grpc", (), False), + ("grpc", {}, True), + ("http", [1234], False), + ("http", [1234, "xyz.com"], False), + ("http", ["xyz.com", 1234, 22], True), + ("http", (1234,), False), + ("http", (1234, "xyz.com"), False), + ("http", ("xyz.com", 1234, 22), True), + ("tcp", {"hosts": "xyz.com"}, False), + ("tcp", {"host": "xyz.com", "port": 1234, "extra": 2323}, False), + ], + ) + def test_make_url_error(self, scheme, address, secure): + with pytest.raises(ValueError): + make_url(scheme, address, secure)