From a249fa4dc1115bc251d67abda9ad4ccbbf48e21e Mon Sep 17 00:00:00 2001 From: "chentianyi.cty" Date: Tue, 21 Jan 2025 19:55:46 +0800 Subject: [PATCH] add ut --- .../python/elastic_agent/torch/training.py | 11 +++++-- .../tests/test_elastic_training_agent.py | 33 +++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index fd29fe3ff..2cdee6082 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -135,7 +135,13 @@ def _set_paral_config(): def _get_local_ip(): local_ip = os.getenv("POD_IP", "") if not local_ip: - local_ip = socket.gethostbyname(_get_fq_hostname()) + try: + local_ip = socket.gethostbyname(_get_fq_hostname()) + except socket.gaierror: + logger.warning( + "Can not resolve host IP. " "Use default '127.0.0.1' instead." + ) + local_ip = "127.0.0.1" return local_ip @@ -1315,8 +1321,7 @@ def launch_agent( if ( (exc_type is not None) or (result is not None and result.is_failed()) - and not is_node_check_failed - ): + ) and not is_node_check_failed: client.report_failed_exited() logger.info("Failed and exit.") elif is_node_check_failed: diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index 1e960e8d0..adbbe57fd 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -59,12 +59,14 @@ ElasticTrainingAgent, MasterRendezvousHandler, NodeCheckElasticAgent, + NodeCheckFailedError, RendezvousOutSyncError, _create_check_agent, _create_worker_spec, _get_local_ip, _set_paral_config, comm_perf_check, + launch_agent, node_health_check, ) from dlrover.python.tests.test_utils import start_local_master @@ -678,6 +680,37 @@ def test_diagnosis(self): 1, ) + @patch( + "dlrover.python.elastic_agent.master_client" + ".MasterClient.report_failed_exited" + ) + @patch( + "dlrover.python.elastic_agent.torch.training" + ".ElasticTrainingAgent.run" + ) + def test_node_status_report(self, mock_run, mock_report_failed_exited): + config = ElasticLaunchConfig(1, 1, 1) + entrypoint = "python" + + mock_run.side_effect = RuntimeError("test") + mock_report_failed_exited.return_value = True + try: + launch_agent(config, entrypoint, []) + self.fail() + except RuntimeError: + self.assertTrue(True) + mock_run.assert_called_once() + mock_report_failed_exited.assert_called_once() + + mock_run.side_effect = NodeCheckFailedError("test") + try: + launch_agent(config, entrypoint, []) + self.fail() + except NodeCheckFailedError: + self.assertTrue(True) + self.assertEqual(mock_run.call_count, 2) + mock_report_failed_exited.assert_called_once() + class NodeCheckElasticAgentTest(unittest.TestCase): def setUp(self) -> None: