Skip to content

Commit

Permalink
Fix network check status report issue. (#1447)
Browse files Browse the repository at this point in the history
* fix network check status report

* add ut
  • Loading branch information
BalaBalaYi authored Jan 22, 2025
1 parent e889a69 commit 17ab888
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 9 deletions.
39 changes: 30 additions & 9 deletions dlrover/python/elastic_agent/torch/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,24 @@ def _set_paral_config():
def _get_local_ip():
local_ip = os.getenv("POD_IP", "")
if not local_ip:
local_ip = socket.gethostbyname(_get_fq_hostname())
try:
local_ip = socket.gethostbyname(_get_fq_hostname())
except socket.gaierror:
logger.warning(
"Can not resolve host IP. " "Use default '127.0.0.1' instead."
)
local_ip = "127.0.0.1"
return local_ip


class RendezvousOutSyncError(Exception):
pass


class NodeCheckFailedError(RuntimeError):
pass


@dataclass
class ElasticLaunchConfig(LaunchConfig):
"""
Expand Down Expand Up @@ -1269,6 +1279,7 @@ def launch_agent(
)

shutdown_rdzv = True
is_node_check_failed = False
result = None
try:
metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))
Expand Down Expand Up @@ -1298,17 +1309,23 @@ def launch_agent(
shutdown_rdzv = False
events.record(agent.get_event_failed())
raise
except NodeCheckFailedError:
is_node_check_failed = True
raise
except Exception:
events.record(agent.get_event_failed())
raise
finally:
exc_type, exc_value, exc_traceback = sys.exc_info()
client = MasterClient.singleton_instance()
if (exc_type is not None) or (
result is not None and result.is_failed()
):
if (
(exc_type is not None)
or (result is not None and result.is_failed())
) and not is_node_check_failed:
client.report_failed_exited()
logger.info("Failed and exit.")
elif is_node_check_failed:
logger.info("Node check failed and exit.")

if shutdown_rdzv:
spec.rdzv_handler.shutdown()
Expand Down Expand Up @@ -1420,17 +1437,19 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
f"Network check time of round {i} is {elapsed_time}"
f" and succeed is {result}."
)

success = success or result
status = (
NodeEventType.NODE_CHECK_SUCCEEDED
if result
if success
else NodeEventType.NODE_CHECK_FAILED
)
self._client.report_network_check_status(
self._node_rank,
status,
elapsed_time,
)
success = success or result

fault_nodes, fault_reason = self._client.check_fault_node(
timeout=self._get_check_node_timeout()
)
Expand All @@ -1452,7 +1471,7 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
"No need for another round of network "
"check because the nodes is less than 3."
)
raise RuntimeError("This node is down.")
raise NodeCheckFailedError("This node is down.")
else:
# Run the next round check to detect the fault node.
time.sleep(JobConstant.NODE_CHECK_NEXT_ROUND_TIMEOUT)
Expand All @@ -1465,11 +1484,13 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
NodeErrorMessage.NETWORKER_ERROR,
level=TrainingExceptionLevel.NODE_ERROR,
)
raise RuntimeError("This node is down.")
raise NodeCheckFailedError("This node is down.")
elif self._node_rank in stragglers:
logger.warning("This node is a straggler!")
if self._config.exclude_straggler:
raise RuntimeError("The node is a straggler and exits.")
raise NodeCheckFailedError(
"The node is a straggler " "and exits."
)
return success

def _run_node_check(self, monitor_interval=3, timeout=300):
Expand Down
33 changes: 33 additions & 0 deletions dlrover/python/tests/test_elastic_training_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,14 @@
ElasticTrainingAgent,
MasterRendezvousHandler,
NodeCheckElasticAgent,
NodeCheckFailedError,
RendezvousOutSyncError,
_create_check_agent,
_create_worker_spec,
_get_local_ip,
_set_paral_config,
comm_perf_check,
launch_agent,
node_health_check,
)
from dlrover.python.tests.test_utils import start_local_master
Expand Down Expand Up @@ -678,6 +680,37 @@ def test_diagnosis(self):
1,
)

@patch(
"dlrover.python.elastic_agent.master_client"
".MasterClient.report_failed_exited"
)
@patch(
"dlrover.python.elastic_agent.torch.training"
".ElasticTrainingAgent.run"
)
def test_node_status_report(self, mock_run, mock_report_failed_exited):
config = ElasticLaunchConfig(1, 1, 1)
entrypoint = "python"

mock_run.side_effect = RuntimeError("test")
mock_report_failed_exited.return_value = True
try:
launch_agent(config, entrypoint, [])
self.fail()
except RuntimeError:
self.assertTrue(True)
mock_run.assert_called_once()
mock_report_failed_exited.assert_called_once()

mock_run.side_effect = NodeCheckFailedError("test")
try:
launch_agent(config, entrypoint, [])
self.fail()
except NodeCheckFailedError:
self.assertTrue(True)
self.assertEqual(mock_run.call_count, 2)
mock_report_failed_exited.assert_called_once()


class NodeCheckElasticAgentTest(unittest.TestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit 17ab888

Please sign in to comment.