Skip to content

Commit

Permalink
Merge branch 'master' into add-event-report-logs
Browse files Browse the repository at this point in the history
  • Loading branch information
BalaBalaYi authored Jan 16, 2025
2 parents b648ea4 + 227940f commit d1d35b5
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 12 deletions.
13 changes: 9 additions & 4 deletions dlrover/python/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,15 +358,20 @@ class JobConstant(object):
INSUFFICIENT_NODE_TIMEOUT_DEFAULT_MIN = 600
INSUFFICIENT_NODE_TIMEOUT_DEFAULT_MAX = 3600
PENDING_NODE_TIMEOUT_DEFAULT_MIN = 600
NODE_CHECK_TIMEOUT = 300

# grpc timeout 60s
MASTER_CLIENT_GRPC_DEFAULT_TIMEOUT = 60

# sleep 3s on NetworkFailureReason.WAITING_NODE
MASTER_CLIENT_CHECK_FAULT_TIMEOUT = 1
# master_client.check_fault_node/check_straggler timeout value
# must > NODE_CHECK_TIMEOUT
MASTER_CLIENT_CHECK_NODE_TIMEOUT = 360

# sleep 3s on NetworkFailureReason.WAITING_NODE
MASTER_CLIENT_CHECK_STRAGGLER_TIMEOUT = 1
# sleep 1s on NetworkFailureReason.WAITING_NODE
MASTER_CLIENT_CHECK_FAULT_SLEEP_TIMEOUT = 1

# sleep 1s on NetworkFailureReason.WAITING_NODE
MASTER_CLIENT_CHECK_STRAGGLER_SLEEP_TIMEOUT = 1

# sleep 5s before next node check round
NODE_CHECK_NEXT_ROUND_TIMEOUT = 5
Expand Down
10 changes: 6 additions & 4 deletions dlrover/python/elastic_agent/master_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,9 @@ def check_fault_node(self, timeout=300):
result: grpc.NetworkCheckResult = self._get(request)
if (
result.reason == NetworkFailureReason.WAITING_NODE
and time.time() - start < timeout
):
time.sleep(JobConstant.MASTER_CLIENT_CHECK_FAULT_TIMEOUT)
or result.reason == NetworkFailureReason.NO_INIT
) and time.time() - start < timeout:
time.sleep(JobConstant.MASTER_CLIENT_CHECK_FAULT_SLEEP_TIMEOUT)
continue
break
return result.nodes, result.reason
Expand All @@ -411,7 +411,9 @@ def check_straggler(self, timeout=300):
result.reason == NetworkFailureReason.WAITING_NODE
and time.time() - start < timeout
):
time.sleep(JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_TIMEOUT)
time.sleep(
JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_SLEEP_TIMEOUT
)
continue
break
return result.nodes, result.reason
Expand Down
15 changes: 12 additions & 3 deletions dlrover/python/elastic_agent/torch/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,9 @@ def __init__(
self._check_round = check_round
self._config: ElasticLaunchConfig = config

def _get_check_node_timeout(self):
return JobConstant.MASTER_CLIENT_CHECK_NODE_TIMEOUT

def run(self, role: str = DEFAULT_ROLE) -> bool:
spec = self._worker_group.spec
role = spec.role
Expand All @@ -1409,7 +1412,9 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
fault_nodes = []
stragglers = []
for i in range(self._check_round):
result, elapsed_time = self._run_node_check()
result, elapsed_time = self._run_node_check(
timeout=JobConstant.NODE_CHECK_TIMEOUT
)
elapsed_time = round(elapsed_time, 3)
logger.info(
f"Network check time of round {i} is {elapsed_time}"
Expand All @@ -1426,8 +1431,12 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
elapsed_time,
)
success = success or result
fault_nodes, fault_reason = self._client.check_fault_node()
stragglers, straggler_reason = self._client.check_straggler()
fault_nodes, fault_reason = self._client.check_fault_node(
timeout=self._get_check_node_timeout()
)
stragglers, straggler_reason = self._client.check_straggler(
timeout=self._get_check_node_timeout()
)
logger.info(
f"Fault nodes are: {fault_nodes} with {fault_reason} "
f" and stragglers are: {stragglers} with {straggler_reason}"
Expand Down
17 changes: 17 additions & 0 deletions dlrover/python/tests/test_elastic_training_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,8 @@ def test_run_agent(self):
)

# with no fault and no stragglers
agent._client.check_fault_node = mock.MagicMock(return_value=([], ""))
agent._client.check_straggler = mock.MagicMock(return_value=([], ""))
agent._run_node_check = mock.MagicMock(return_value=(True, 100))
agent._stop_workers = mock.MagicMock(return_value=True)
self.assertTrue(agent.run())
Expand Down Expand Up @@ -821,6 +823,21 @@ def test_comm_perf_test(self, mock_run):
comm_perf_check(config, entrypoint, args)
mock_run.assert_called()

def test_get_check_node_timeout(self):
config = ElasticLaunchConfig(4, 4, 8)

agent = _create_check_agent(
config=config,
entrypoint="python",
args=[],
rdzv_name="elastic-training",
check_round=2,
)
self.assertEqual(
agent._get_check_node_timeout(),
JobConstant.MASTER_CLIENT_CHECK_NODE_TIMEOUT,
)


class MasterRendezvousHandlerTest(unittest.TestCase):
def setUp(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion dlrover/python/tests/test_master_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_get(self):
self.assertEqual(len(nodes), 1)
self.assertEqual(nodes[0].type, NodeType.WORKER)

nodes, _ = self._master_client.check_fault_node()
nodes, _ = self._master_client.check_fault_node(timeout=1)
self.assertListEqual(nodes, [])

round = self._master_client.join_rendezvous(0, 8, "elastic-training")
Expand Down

0 comments on commit d1d35b5

Please sign in to comment.