diff --git a/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py b/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py
index f0808b18b..b39ac2eaf 100644
--- a/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py
+++ b/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py
@@ -264,6 +264,7 @@ def diagnose_training_failure(self) -> NodeAction:
     def _report_failure_to_master(self, failures, restart_count):
         errors = {}
         if len(failures) == 0:
+            logger.info("Skip failure report due to empty failures")
             return
         for rank, failure in failures.items():
             dt = str(datetime.fromtimestamp(int(failure.timestamp)))
@@ -287,7 +288,7 @@ def send_heartbeat(self):
             action = self._client.report_heart_beat(ts)
             self._agent_context.enqueue_diagnosis_action(action)
         except Exception as e:
-            logger.warning(f"fail to report a heartbeat: {e}")
+            logger.warning(f"Fail to report a heartbeat: {e}")
 
     def _periodically_report(self):
         logger.info("Start diagnosis agent periodically reporter.")
diff --git a/dlrover/python/master/node/dist_job_manager.py b/dlrover/python/master/node/dist_job_manager.py
index 8c1f63318..7c3e376e6 100644
--- a/dlrover/python/master/node/dist_job_manager.py
+++ b/dlrover/python/master/node/dist_job_manager.py
@@ -1198,8 +1198,8 @@ def handle_training_failure(
     ):
         """Process the training failure reported by the node."""
         node = self._job_context.job_node(node_type, node_id)
+        logger.info(f"Handle failed node: {node}")
         if node.is_released:
-            logger.info(f"The node {node.name} has been released.")
             return
         relaunch_node = self._process_error(
             node, restart_count, error_data, level
diff --git a/dlrover/python/tests/test_diagnosis_agent.py b/dlrover/python/tests/test_diagnosis_agent.py
index aad5999e0..d820d0b27 100644
--- a/dlrover/python/tests/test_diagnosis_agent.py
+++ b/dlrover/python/tests/test_diagnosis_agent.py
@@ -196,6 +196,17 @@ def test_send_heartbeat(self):
             DiagnosisActionType.RESTART_WORKER,
         )
 
+        # A heartbeat report that raises is only logged; nothing is
+        # enqueued, so the next action is expected to be NONE.
+        agent._client.report_heart_beat = mock.MagicMock(
+            side_effect=[Exception]
+        )
+        agent.send_heartbeat()
+        self.assertEqual(
+            context._diagnosis_action_queue.next_action().action_type,
+            DiagnosisActionType.NONE,
+        )
+
     def test_async_thread(self):
         DiagnosisConstant.AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS = 1
         DiagnosisConstant.AGENT_PERIODICALLY_REPORT_INTERVAL_SECS = 1