From 84674e82ed2e2b415140a3e06478c47a8a9f4180 Mon Sep 17 00:00:00 2001
From: Ma Jie Yue
Date: Sun, 5 Jan 2025 19:02:15 +0800
Subject: [PATCH] add exit_barrier_timeout unittest

---
 dlrover/python/elastic_agent/torch/training.py |  5 ++++-
 .../tests/test_elastic_training_agent.py       | 18 ++++++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py
index 32ebfeeaf..4255359ea 100644
--- a/dlrover/python/elastic_agent/torch/training.py
+++ b/dlrover/python/elastic_agent/torch/training.py
@@ -453,7 +453,10 @@ def __init__(
         with_diagnostician: bool = True,
     ):
         if version_less_than_230():
-            super().__init__(spec, exit_barrier_timeout)
+            super().__init__(
+                spec=spec,
+                exit_barrier_timeout=exit_barrier_timeout,
+            )
         else:
             super().__init__(
                 spec=spec,
diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py
index 3369ae599..9c3d18359 100644
--- a/dlrover/python/tests/test_elastic_training_agent.py
+++ b/dlrover/python/tests/test_elastic_training_agent.py
@@ -219,6 +219,24 @@ def _set_store(store):
         self.assertEqual(store.get("MASTER_ADDR").decode(), "127.0.0.1")
         self.assertEqual(store.get("MASTER_PORT").decode(), "12345")
 
+    def test_exit_barrier(self):
+        agent = ElasticTrainingAgent(
+            node_rank=0,
+            config=self.config,
+            entrypoint="python",
+            spec=self.spec,
+            start_method=self.config.start_method,
+            log_dir=self.config.log_dir,
+            exit_barrier_timeout=1,
+        )
+        self.rdzv_handler._client._node_id = 1
+        self.rdzv_handler._client.join_rendezvous(
+            1, 8, self.rdzv_handler._name
+        )
+        agent._client._node_id = 0
+        agent._rendezvous(agent._worker_group)
+        agent._exit_barrier()
+
     def test_get_local_ip(self):
         local_ip = _get_local_ip()
         self.assertNotEqual(local_ip, "")