diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index 27b79a851..a162e0fbe 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -1424,9 +1424,11 @@ def run(self, role: str = DEFAULT_ROLE) -> bool: f"Network check time of round {i} is {elapsed_time}" f" and succeed is {result}." ) + + success = success or result status = ( NodeEventType.NODE_CHECK_SUCCEEDED - if result + if success else NodeEventType.NODE_CHECK_FAILED ) self._client.report_network_check_status( @@ -1434,7 +1436,7 @@ def run(self, role: str = DEFAULT_ROLE) -> bool: status, elapsed_time, ) - success = success or result + fault_nodes, fault_reason = self._client.check_fault_node( timeout=self._get_check_node_timeout() ) @@ -1474,6 +1476,9 @@ def run(self, role: str = DEFAULT_ROLE) -> bool: logger.warning("This node is a straggler!") if self._config.exclude_straggler: raise RuntimeError("The node is a straggler and exits.") + + + return success def _run_node_check(self, monitor_interval=3, timeout=300):