Skip to content

Commit

Permalink
Reuse the timeout option.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Nov 12, 2024
1 parent 1c99225 commit d80471d
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
11 changes: 6 additions & 5 deletions python-package/xgboost/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ class Config:
Attributes
----------
retry : See `dmlc_retry` in :py:meth:`init`.
timeout : See `dmlc_timeout` in :py:meth:`init`.
timeout :
See `dmlc_timeout` in :py:meth:`init`. This is used for both the tracker and the
communicators.
tracker_host : See :py:class:`~xgboost.tracker.RabitTracker`.
tracker_port : See :py:class:`~xgboost.tracker.RabitTracker`.
tracker_timeout : See :py:class:`~xgboost.tracker.RabitTracker`.
"""

Expand All @@ -41,7 +45,6 @@ class Config:

tracker_host: Optional[str] = None
tracker_port: Optional[int] = None
tracker_timeout: Optional[int] = None

def get_comm_config(self, args: _Args) -> _Args:
"""Update the arguments for the communicator."""
Expand All @@ -60,7 +63,6 @@ def to_dict(self) -> _Args:
"timeout",
"tracker_host",
"tracker_port",
"tracker_timeout",
)
}

Expand All @@ -85,7 +87,6 @@ def to_t(key: str, typ: T) -> Optional[T]:
timeout=to_t("timeout", int()),
tracker_host=to_t("tracker_host", str()),
tracker_port=to_t("tracker_port", int()),
tracker_timeout=to_t("tracker_timeout", int()),
)


Expand Down
2 changes: 1 addition & 1 deletion python-package/xgboost/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ async def _get_rabit_args(

# We assume the scheduler is a fair process and run the tracker there.
env = await client.run_on_scheduler(
_start_tracker, n_workers, sched_addr, user_addr, coll_config.tracker_timeout
_start_tracker, n_workers, sched_addr, user_addr, coll_config.timeout
)
env = coll_config.get_comm_config(env)
return env
Expand Down
1 change: 1 addition & 0 deletions src/collective/tracker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
Json args{Object{}};
args["dmlc_tracker_uri"] = String{host_};
args["dmlc_tracker_port"] = this->Port();
args["dmlc_timeout"] = static_cast<Integer::Int>(this->Timeout().count());
return args;
}

Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_federated_communicator():
port = 9091
world_size = 2
with get_reusable_executor(max_workers=world_size+1) as pool:
kwargs={"port": port, "n_workers": world_size, "blocking": False}
kwargs = {"port": port, "n_workers": world_size, "blocking": False}
tracker = pool.submit(federated.run_federated_server, **kwargs)
if not tracker.running():
raise RuntimeError("Error starting Federated Learning server")
Expand All @@ -82,7 +82,7 @@ def test_federated_communicator():

def test_config_serialization() -> None:
cfg = Config(
retry=1, timeout=2, tracker_host="127.0.0.1", tracker_port=None, tracker_timeout=3
retry=1, timeout=2, tracker_host="127.0.0.1", tracker_port=None
)
cfg1 = Config.from_dict(cfg.to_dict())
assert cfg == cfg1
Expand Down

0 comments on commit d80471d

Please sign in to comment.