diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index 733bfb044..2cdee6082 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -135,7 +135,13 @@ def _set_paral_config(): def _get_local_ip(): local_ip = os.getenv("POD_IP", "") if not local_ip: - local_ip = socket.gethostbyname(_get_fq_hostname()) + try: + local_ip = socket.gethostbyname(_get_fq_hostname()) + except socket.gaierror: + logger.warning( + "Can not resolve host IP. " "Use default '127.0.0.1' instead." + ) + local_ip = "127.0.0.1" return local_ip @@ -143,6 +149,10 @@ class RendezvousOutSyncError(Exception): pass +class NodeCheckFailedError(RuntimeError): + pass + + @dataclass class ElasticLaunchConfig(LaunchConfig): """ @@ -1269,6 +1279,7 @@ def launch_agent( ) shutdown_rdzv = True + is_node_check_failed = False result = None try: metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg)) @@ -1298,17 +1309,23 @@ def launch_agent( shutdown_rdzv = False events.record(agent.get_event_failed()) raise + except NodeCheckFailedError: + is_node_check_failed = True + raise except Exception: events.record(agent.get_event_failed()) raise finally: exc_type, exc_value, exc_traceback = sys.exc_info() client = MasterClient.singleton_instance() - if (exc_type is not None) or ( - result is not None and result.is_failed() - ): + if ( + (exc_type is not None) + or (result is not None and result.is_failed()) + ) and not is_node_check_failed: client.report_failed_exited() logger.info("Failed and exit.") + elif is_node_check_failed: + logger.info("Node check failed and exit.") if shutdown_rdzv: spec.rdzv_handler.shutdown() @@ -1420,9 +1437,11 @@ def run(self, role: str = DEFAULT_ROLE) -> bool: f"Network check time of round {i} is {elapsed_time}" f" and succeed is {result}." ) + + success = success or result status = ( NodeEventType.NODE_CHECK_SUCCEEDED - if result + if success else NodeEventType.NODE_CHECK_FAILED ) self._client.report_network_check_status( @@ -1430,7 +1449,7 @@ def run(self, role: str = DEFAULT_ROLE) -> bool: status, elapsed_time, ) - success = success or result + fault_nodes, fault_reason = self._client.check_fault_node( timeout=self._get_check_node_timeout() ) @@ -1452,7 +1471,7 @@ def run(self, role: str = DEFAULT_ROLE) -> bool: "No need for another round of network " "check because the nodes is less than 3." ) - raise RuntimeError("This node is down.") + raise NodeCheckFailedError("This node is down.") else: # Run the next round check to detect the fault node. time.sleep(JobConstant.NODE_CHECK_NEXT_ROUND_TIMEOUT) @@ -1465,11 +1484,13 @@ def run(self, role: str = DEFAULT_ROLE) -> bool: NodeErrorMessage.NETWORKER_ERROR, level=TrainingExceptionLevel.NODE_ERROR, ) - raise RuntimeError("This node is down.") + raise NodeCheckFailedError("This node is down.") elif self._node_rank in stragglers: logger.warning("This node is a straggler!") if self._config.exclude_straggler: - raise RuntimeError("The node is a straggler and exits.") + raise NodeCheckFailedError( + "The node is a straggler " "and exits." + ) return success def _run_node_check(self, monitor_interval=3, timeout=300): diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index 1e960e8d0..adbbe57fd 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -59,12 +59,14 @@ ElasticTrainingAgent, MasterRendezvousHandler, NodeCheckElasticAgent, + NodeCheckFailedError, RendezvousOutSyncError, _create_check_agent, _create_worker_spec, _get_local_ip, _set_paral_config, comm_perf_check, + launch_agent, node_health_check, ) from dlrover.python.tests.test_utils import start_local_master @@ -678,6 +680,37 @@ def test_diagnosis(self): 1, ) + @patch( + "dlrover.python.elastic_agent.master_client" + ".MasterClient.report_failed_exited" + ) + @patch( + "dlrover.python.elastic_agent.torch.training" + ".ElasticTrainingAgent.run" + ) + def test_node_status_report(self, mock_run, mock_report_failed_exited): + config = ElasticLaunchConfig(1, 1, 1) + entrypoint = "python" + + mock_run.side_effect = RuntimeError("test") + mock_report_failed_exited.return_value = True + try: + launch_agent(config, entrypoint, []) + self.fail() + except RuntimeError: + self.assertTrue(True) + mock_run.assert_called_once() + mock_report_failed_exited.assert_called_once() + + mock_run.side_effect = NodeCheckFailedError("test") + try: + launch_agent(config, entrypoint, []) + self.fail() + except NodeCheckFailedError: + self.assertTrue(True) + self.assertEqual(mock_run.call_count, 2) + mock_report_failed_exited.assert_called_once() + class NodeCheckElasticAgentTest(unittest.TestCase): def setUp(self) -> None: