diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py index c96c37aa9b44..7f8b975388a8 100644 --- a/python-package/xgboost/collective.py +++ b/python-package/xgboost/collective.py @@ -29,10 +29,14 @@ class Config: Attributes ---------- retry : See `dmlc_retry` in :py:meth:`init`. - timeout : See `dmlc_timeout` in :py:meth:`init`. + + timeout : + See `dmlc_timeout` in :py:meth:`init`. This is used for both the tracker and the + communicators. + tracker_host : See :py:class:`~xgboost.tracker.RabitTracker`. + tracker_port : See :py:class:`~xgboost.tracker.RabitTracker`. - tracker_timeout : See :py:class:`~xgboost.tracker.RabitTracker`. """ @@ -41,7 +45,6 @@ class Config: tracker_host: Optional[str] = None tracker_port: Optional[int] = None - tracker_timeout: Optional[int] = None def get_comm_config(self, args: _Args) -> _Args: """Update the arguments for the communicator.""" @@ -60,7 +63,6 @@ def to_dict(self) -> _Args: "timeout", "tracker_host", "tracker_port", - "tracker_timeout", ) } @@ -85,7 +87,6 @@ def to_t(key: str, typ: T) -> Optional[T]: timeout=to_t("timeout", int()), tracker_host=to_t("tracker_host", str()), tracker_port=to_t("tracker_port", int()), - tracker_timeout=to_t("tracker_timeout", int()), ) diff --git a/python-package/xgboost/dask/__init__.py b/python-package/xgboost/dask/__init__.py index 5c98c9b8dd8a..3e5a5d0b668a 100644 --- a/python-package/xgboost/dask/__init__.py +++ b/python-package/xgboost/dask/__init__.py @@ -899,7 +899,7 @@ async def _get_rabit_args( # We assume the scheduler is a fair process and run the tracker there. env = await client.run_on_scheduler( - _start_tracker, n_workers, sched_addr, user_addr, coll_config.tracker_timeout + _start_tracker, n_workers, sched_addr, user_addr, coll_config.timeout ) env = coll_config.get_comm_config(env) return env diff --git a/src/collective/tracker.cc b/src/collective/tracker.cc index 8bc1c1d4a751..d6c045241c3a 100644 --- a/src/collective/tracker.cc +++ b/src/collective/tracker.cc @@ -361,6 +361,7 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { Json args{Object{}}; args["dmlc_tracker_uri"] = String{host_}; args["dmlc_tracker_port"] = this->Port(); + args["dmlc_timeout"] = static_cast(this->Timeout().count()); return args; } diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index 5c6554c066fb..5696f2a49e55 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -65,7 +65,7 @@ def test_federated_communicator(): port = 9091 world_size = 2 with get_reusable_executor(max_workers=world_size+1) as pool: - kwargs={"port": port, "n_workers": world_size, "blocking": False} + kwargs = {"port": port, "n_workers": world_size, "blocking": False} tracker = pool.submit(federated.run_federated_server, **kwargs) if not tracker.running(): raise RuntimeError("Error starting Federated Learning server") @@ -82,7 +82,7 @@ def test_federated_communicator(): def test_config_serialization() -> None: cfg = Config( - retry=1, timeout=2, tracker_host="127.0.0.1", tracker_port=None, tracker_timeout=3 + retry=1, timeout=2, tracker_host="127.0.0.1", tracker_port=None ) cfg1 = Config.from_dict(cfg.to_dict()) assert cfg == cfg1