Skip to content

Commit

Permalink
Support relay - Part 1 (#3198)
Browse files Browse the repository at this point in the history
Fixes # .

### Description

This PR adds support for using relays when building the NVFLARE cellnet.
Relays are nodes in the cellnet that are used only for routing messages.
They do not have any learning functions.

This is Part 1 of the relay support. It does not include the
provisioning functions for creating startup kits for relays. These
functions will be added in future PRs.

Just like regular clients, relays also register with the Server when they
are started. Once a relay has registered successfully, the Server sends an
auth token and the token signature to the relay.

Relay nodes also perform message authentication: relays validate the auth
headers of all messages going through them. Since all cross-site messages
must go through either the Server or relays (or both), enforcing message
authentication at both the Server and the relays ensures that no cross-site
message can get through without valid auth headers.

Currently peers of a CellPipe can only communicate via the Server. This
PR includes an enhancement that allows peers to communicate via CP or
Relay nodes. This can make the peer-to-peer communication more
efficient.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.
  • Loading branch information
yanchengnv authored Feb 4, 2025
1 parent 96dcd0e commit de829ce
Show file tree
Hide file tree
Showing 49 changed files with 1,468 additions and 377 deletions.
28 changes: 27 additions & 1 deletion nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,6 @@ class SecureTrainConst:
SSL_ROOT_CERT = "ssl_root_cert"
SSL_CERT = "ssl_cert"
PRIVATE_KEY = "ssl_private_key"
CONNECTION_SECURITY = "connection_security"


class FLMetaKey:
Expand All @@ -467,6 +466,7 @@ class FLMetaKey:

class CellMessageAuthHeaderKey:
CLIENT_NAME = "client_name"
SSID = "ssid"
TOKEN = "__token__"
TOKEN_SIGNATURE = "__token_signature__"

Expand Down Expand Up @@ -542,6 +542,8 @@ class SystemVarName:
WORKSPACE = "WORKSPACE" # directory of the workspace
JOB_ID = "JOB_ID" # Job ID
ROOT_URL = "ROOT_URL" # the URL of the Service Provider (server)
CP_URL = "CP_URL" # URL to CP
RELAY_URL = "RELAY_URL" # URL to relay that the CP is connected to
SECURE_MODE = "SECURE_MODE" # whether the system is running in secure mode
JOB_CUSTOM_DIR = "JOB_CUSTOM_DIR" # custom dir of the job
PYTHONPATH = "PYTHONPATH"
Expand All @@ -552,3 +554,27 @@ class RunnerTask:
INIT = "init"
TASK_EXEC = "task_exec"
END_RUN = "end_run"


class ConnPropKey:
    """String constants used as keys for connection-related properties."""

    # Keys of individual connection properties.
    # NOTE(review): grouped here by name only; confirm exact usage against callers.
    PROJECT_NAME = "project_name"
    SERVER_IDENTITY = "server_identity"
    IDENTITY = "identity"
    PARENT = "parent"
    FQCN = "fqcn"
    URL = "url"
    SCHEME = "scheme"
    ADDRESS = "address"
    CONNECTION_SECURITY = "connection_security"

    # Key of the relay configuration (usage not visible here; verify against callers).
    RELAY_CONFIG = "relay_config"

    # Names under which whole connection-property sets are stored and exported
    # (they serve as scope-property names and as keys of the exported config).
    CP_CONN_PROPS = "cp_conn_props"
    RELAY_CONN_PROPS = "relay_conn_props"
    ROOT_CONN_PROPS = "root_conn_props"


class ConnectionSecurity:
    """Recognized values of the connection security setting."""

    # NOTE(review): presumably "no TLS", "one-way TLS" and "mutual TLS" — confirm with the driver code.
    CLEAR = "clear"
    TLS = "tls"
    MTLS = "mtls"
1 change: 1 addition & 0 deletions nvflare/apis/job_launcher_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class JobProcessArgs:
CLIENT_NAME = "client_name"
ROOT_URL = "root_url"
PARENT_URL = "parent_url"
PARENT_CONN_SEC = "parent_conn_sec"
SERVICE_HOST = "service_host"
SERVICE_PORT = "service_port"
HA_MODE = "ha_mode"
Expand Down
16 changes: 2 additions & 14 deletions nvflare/app_common/executors/client_api_launcher_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
import os
from typing import Optional

from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.executors.launcher_executor import LauncherExecutor
from nvflare.app_common.utils.export_utils import update_export_props
from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType, write_config_to_file
from nvflare.client.constants import CLIENT_API_CONFIG
from nvflare.fuel.data_event.utils import get_scope_property
from nvflare.fuel.utils.attributes_exportable import ExportMode


Expand Down Expand Up @@ -126,22 +125,11 @@ def prepare_config_for_launch(self, fl_ctx: FLContext):
ConfigKey.HEARTBEAT_TIMEOUT: self.heartbeat_timeout,
}

site_name = fl_ctx.get_identity_name()
auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA")
signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA")

config_data = {
ConfigKey.TASK_EXCHANGE: task_exchange_attributes,
FLMetaKey.SITE_NAME: site_name,
FLMetaKey.JOB_ID: fl_ctx.get_job_id(),
FLMetaKey.AUTH_TOKEN: auth_token,
FLMetaKey.AUTH_TOKEN_SIGNATURE: signature,
}

conn_sec = get_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY)
if conn_sec:
config_data[SecureTrainConst.CONNECTION_SECURITY] = conn_sec

update_export_props(config_data, fl_ctx)
config_file_path = self._get_external_config_file_path(fl_ctx)
write_config_to_file(config_data=config_data, config_file_path=config_file_path)

Expand Down
39 changes: 39 additions & 0 deletions nvflare/app_common/utils/export_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey
from nvflare.apis.fl_context import FLContext
from nvflare.fuel.data_event.utils import get_scope_property


def update_export_props(props: dict, fl_ctx: FLContext):
    """Add site, job, auth and connection information to ``props`` in place.

    Args:
        props: dictionary to be updated; existing entries under the same keys are overwritten.
        fl_ctx: FL context that supplies the site (identity) name and the job ID.

    The auth token and its signature are looked up as scope properties of the
    site, with "NA" passed as the lookup default. Each of the root / CP / relay
    connection-property sets is copied only when the lookup yields a truthy value.
    """
    site = fl_ctx.get_identity_name()
    token = get_scope_property(scope_name=site, key=FLMetaKey.AUTH_TOKEN, default="NA")
    token_signature = get_scope_property(scope_name=site, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA")

    props[FLMetaKey.SITE_NAME] = site
    props[FLMetaKey.JOB_ID] = fl_ctx.get_job_id()
    props[FLMetaKey.AUTH_TOKEN] = token
    props[FLMetaKey.AUTH_TOKEN_SIGNATURE] = token_signature

    # Same lookup-then-copy step for each connection-property set, in the order root, CP, relay.
    for conn_key in (ConnPropKey.ROOT_CONN_PROPS, ConnPropKey.CP_CONN_PROPS, ConnPropKey.RELAY_CONN_PROPS):
        conn_props = get_scope_property(site, conn_key)
        if conn_props:
            props[conn_key] = conn_props
15 changes: 10 additions & 5 deletions nvflare/app_common/widgets/external_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from typing import List

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLMetaKey
from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.utils.export_utils import update_export_props
from nvflare.client.config import write_config_to_file
from nvflare.client.constants import CLIENT_API_CONFIG
from nvflare.fuel.utils.attributes_exportable import ExportMode, export_components
Expand Down Expand Up @@ -47,9 +48,7 @@ def __init__(
def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.ABOUT_TO_START_RUN:
components_data = self._export_all_components(fl_ctx)
components_data[FLMetaKey.SITE_NAME] = fl_ctx.get_identity_name()
components_data[FLMetaKey.JOB_ID] = fl_ctx.get_job_id()

update_export_props(components_data, fl_ctx)
config_file_path = self._get_external_config_file_path(fl_ctx)
write_config_to_file(config_data=components_data, config_file_path=config_file_path)

Expand All @@ -65,5 +64,11 @@ def _export_all_components(self, fl_ctx: FLContext) -> dict:
engine = fl_ctx.get_engine()
all_components = engine.get_all_components()
components = {i: all_components.get(i) for i in self._component_ids}
reserved_keys = [FLMetaKey.SITE_NAME, FLMetaKey.JOB_ID]
reserved_keys = [
FLMetaKey.SITE_NAME,
FLMetaKey.JOB_ID,
ConnPropKey.CP_CONN_PROPS,
ConnPropKey.ROOT_CONN_PROPS,
ConnPropKey.RELAY_CONN_PROPS,
]
return export_components(components=components, reserved_keys=reserved_keys, export_mode=ExportMode.PEER)
6 changes: 5 additions & 1 deletion nvflare/app_opt/job_launcher/k8s_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,11 @@ def get_module_args(self, job_id, fl_ctx: FLContext):
def _job_args_dict(job_args: dict, arg_names: list) -> dict:
    """Collect the requested job args into a ``{name: value}`` dictionary.

    Args:
        job_args: maps an arg key to a ``(name, value)`` pair.
        arg_names: the arg keys to look up in ``job_args``.

    Returns:
        A dict mapping each found pair's name to its value. Keys that are
        missing from ``job_args`` (or mapped to an empty entry) are skipped
        instead of raising ``KeyError``, so optional args may be absent.
    """
    result = {}
    for name in arg_names:
        # Use .get() so that an arg that was never set is simply skipped.
        entry = job_args.get(name)
        if not entry:
            continue

        n, v = entry
        result[n] = v
    return result

Expand Down
13 changes: 11 additions & 2 deletions nvflare/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
from typing import Dict, Optional

from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst
from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey
from nvflare.fuel.utils.config_factory import ConfigFactory


Expand Down Expand Up @@ -157,7 +157,16 @@ def get_heartbeat_timeout(self):
)

def get_connection_security(self):
    """Return the config value stored under ``ConnPropKey.CONNECTION_SECURITY``.

    The key constant lives in ``ConnPropKey``; the lookup uses ``.get`` with no
    default, so a missing key presumably yields None (assuming a dict-like config).
    """
    return self.config.get(ConnPropKey.CONNECTION_SECURITY)

def get_root_conn_props(self):
    """Return the config entry stored under ``ConnPropKey.ROOT_CONN_PROPS``."""
    key = ConnPropKey.ROOT_CONN_PROPS
    return self.config.get(key)

def get_cp_conn_props(self):
    """Return the config entry stored under ``ConnPropKey.CP_CONN_PROPS``."""
    key = ConnPropKey.CP_CONN_PROPS
    return self.config.get(key)

def get_relay_conn_props(self):
    """Return the config entry stored under ``ConnPropKey.RELAY_CONN_PROPS``."""
    key = ConnPropKey.RELAY_CONN_PROPS
    return self.config.get(key)

def get_site_name(self):
return self.config.get(FLMetaKey.SITE_NAME)
Expand Down
17 changes: 13 additions & 4 deletions nvflare/client/ex_process/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, Dict, Optional, Tuple

from nvflare.apis.analytix import AnalyticsDataType
from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst
from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey
from nvflare.apis.utils.analytix_utils import create_analytic_dxo
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.client.api_spec import APISpec
Expand All @@ -39,9 +39,18 @@ def _create_client_config(config: str) -> ClientConfig:
raise ValueError(f"config should be a string but got: {type(config)}")

site_name = client_config.get_site_name()
conn_sec = client_config.get_connection_security()
if conn_sec:
set_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY, conn_sec)

root_conn_props = client_config.get_root_conn_props()
if root_conn_props:
set_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS, root_conn_props)

cp_conn_props = client_config.get_cp_conn_props()
if cp_conn_props:
set_scope_property(site_name, ConnPropKey.CP_CONN_PROPS, cp_conn_props)

relay_conn_props = client_config.get_relay_conn_props()
if relay_conn_props:
set_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS, relay_conn_props)

# get message auth info and put them into Databus for CellPipe to use
auth_token = client_config.get_auth_token()
Expand Down
22 changes: 17 additions & 5 deletions nvflare/fuel/f3/cellnet/connector_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import time
from typing import Union

from nvflare.apis.fl_constant import ConnectionSecurity
from nvflare.fuel.common.excepts import ConfigError
from nvflare.fuel.f3.cellnet.defs import ConnectorRequirementKey
from nvflare.fuel.f3.cellnet.fqcn import FqcnInfo
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.communicator import CommError, Communicator, Mode
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.security.logging import secure_format_exception, secure_format_traceback

Expand Down Expand Up @@ -48,6 +50,9 @@ def __init__(self, handle, connect_url: str, active: bool, params: dict):
def get_connection_url(self):
return self.connect_url

def get_connection_params(self):
    """Return the ``params`` object held by this instance (the same object, not a copy)."""
    params = self.params
    return params


class ConnectorManager:
"""
Expand Down Expand Up @@ -85,6 +90,11 @@ def __init__(self, communicator: Communicator, secure: bool, comm_configurator:
self.adhoc_scheme = adhoc_conf.get(_KEY_SCHEME)
self.adhoc_resources = adhoc_conf.get(_KEY_RESOURCES)

# default conn sec
conn_sec = self.int_resources.get(DriverParams.CONNECTION_SECURITY)
if not conn_sec:
self.int_resources[DriverParams.CONNECTION_SECURITY] = ConnectionSecurity.CLEAR

self.logger.debug(f"internal scheme={self.int_scheme}, resources={self.int_resources}")
self.logger.debug(f"adhoc scheme={self.adhoc_scheme}, resources={self.adhoc_resources}")
self.comm_config = comm_config
Expand Down Expand Up @@ -152,7 +162,7 @@ def _validate_conn_config(config: dict, key: str) -> Union[None, dict]:
return conn_config

def _get_connector(
self, url: str, active: bool, internal: bool, adhoc: bool, secure: bool
self, url: str, active: bool, internal: bool, adhoc: bool, secure: bool, conn_resources=None
) -> Union[None, ConnectorData]:
if active and not url:
raise RuntimeError("url is required by not provided for active connector!")
Expand Down Expand Up @@ -193,10 +203,10 @@ def _get_connector(

try:
if active:
handle, conn_params = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required)
handle, conn_params = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required, conn_resources)
connect_url = url
elif url:
handle, conn_params = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required)
handle, conn_params = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required, conn_resources)
connect_url = url
else:
self.logger.info(f"{os.getpid()}: Try start_listener Listener resources: {reqs}")
Expand Down Expand Up @@ -240,11 +250,13 @@ def get_internal_listener(self) -> Union[None, ConnectorData]:
"""
return self._get_connector(url="", active=False, internal=True, adhoc=False, secure=False)

def get_internal_connector(self, url: str, conn_resources=None) -> Union[None, ConnectorData]:
    """Try to get an active internal connector for the given URL.

    Args:
        url: the URL to connect to.
        conn_resources: optional connection resources; forwarded unchanged to
            ``_get_connector``.

    Returns:
        Whatever ``_get_connector`` returns for an active, internal, non-adhoc,
        non-secure connector request (a ConnectorData, or None).
    """
    return self._get_connector(
        url=url, active=True, internal=True, adhoc=False, secure=False, conn_resources=conn_resources
    )
Loading

0 comments on commit de829ce

Please sign in to comment.