Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into refactor_failed_nod…
Browse files Browse the repository at this point in the history
…e_counter

# Conflicts:
#	dlrover/python/elastic_agent/torch/training.py
  • Loading branch information
BalaBalaYi committed Jan 22, 2025
2 parents d101246 + 17ab888 commit bd904a7
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 10 deletions.
36 changes: 26 additions & 10 deletions dlrover/python/elastic_agent/torch/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,24 @@ def _set_paral_config():
def _get_local_ip():
local_ip = os.getenv("POD_IP", "")
if not local_ip:
local_ip = socket.gethostbyname(_get_fq_hostname())
try:
local_ip = socket.gethostbyname(_get_fq_hostname())
except socket.gaierror:
logger.warning(
"Can not resolve host IP. " "Use default '127.0.0.1' instead."
)
local_ip = "127.0.0.1"
return local_ip


class RendezvousOutSyncError(Exception):
pass


class NodeCheckFailedError(RuntimeError):
pass


@dataclass
class ElasticLaunchConfig(LaunchConfig):
"""
Expand Down Expand Up @@ -1273,6 +1283,7 @@ def launch_agent(
)

shutdown_rdzv = True
is_node_check_failed = False
result = None
try:
metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))
Expand Down Expand Up @@ -1302,17 +1313,23 @@ def launch_agent(
shutdown_rdzv = False
events.record(agent.get_event_failed())
raise
except NodeCheckFailedError:
is_node_check_failed = True
raise
except Exception:
events.record(agent.get_event_failed())
raise
finally:
exc_type, exc_value, exc_traceback = sys.exc_info()
client = MasterClient.singleton_instance()
if (exc_type is not None) or (
result is not None and result.is_failed()
):
if (
(exc_type is not None)
or (result is not None and result.is_failed())
) and not is_node_check_failed:
client.report_failed_exited()
logger.info("Failed and exit.")
elif is_node_check_failed:
logger.info("Node check failed and exit.")

if shutdown_rdzv:
spec.rdzv_handler.shutdown()
Expand Down Expand Up @@ -1458,7 +1475,7 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
"No need for another round of network "
"check because the nodes is less than 3."
)
raise RuntimeError("This node is down.")
raise NodeCheckFailedError("This node is down.")
else:
# Run the next round check to detect the fault node.
time.sleep(JobConstant.NODE_CHECK_NEXT_ROUND_TIMEOUT)
Expand All @@ -1471,14 +1488,13 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
NodeErrorMessage.NETWORKER_ERROR,
level=TrainingExceptionLevel.NODE_ERROR,
)
raise RuntimeError("This node is down.")
raise NodeCheckFailedError("This node is down.")
elif self._node_rank in stragglers:
logger.warning("This node is a straggler!")
if self._config.exclude_straggler:
raise RuntimeError("The node is a straggler and exits.")



raise NodeCheckFailedError(
"The node is a straggler and exits."
)
return success

def _run_node_check(self, monitor_interval=3, timeout=300):
Expand Down
33 changes: 33 additions & 0 deletions dlrover/python/tests/test_elastic_training_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,14 @@
ElasticTrainingAgent,
MasterRendezvousHandler,
NodeCheckElasticAgent,
NodeCheckFailedError,
RendezvousOutSyncError,
_create_check_agent,
_create_worker_spec,
_get_local_ip,
_set_paral_config,
comm_perf_check,
launch_agent,
node_health_check,
)
from dlrover.python.tests.test_utils import start_local_master
Expand Down Expand Up @@ -678,6 +680,37 @@ def test_diagnosis(self):
1,
)

@patch(
"dlrover.python.elastic_agent.master_client"
".MasterClient.report_failed_exited"
)
@patch(
"dlrover.python.elastic_agent.torch.training"
".ElasticTrainingAgent.run"
)
def test_node_status_report(self, mock_run, mock_report_failed_exited):
config = ElasticLaunchConfig(1, 1, 1)
entrypoint = "python"

mock_run.side_effect = RuntimeError("test")
mock_report_failed_exited.return_value = True
try:
launch_agent(config, entrypoint, [])
self.fail()
except RuntimeError:
self.assertTrue(True)
mock_run.assert_called_once()
mock_report_failed_exited.assert_called_once()

mock_run.side_effect = NodeCheckFailedError("test")
try:
launch_agent(config, entrypoint, [])
self.fail()
except NodeCheckFailedError:
self.assertTrue(True)
self.assertEqual(mock_run.call_count, 2)
mock_report_failed_exited.assert_called_once()


class NodeCheckElasticAgentTest(unittest.TestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit bd904a7

Please sign in to comment.