Skip to content

Commit

Permalink
Refactor client's node-check invocation timeout default value. (#1442)
Browse files Browse the repository at this point in the history
* refactor mc check node timeout value

* refactor mc check node timeout value

* lint

* fix ut

* fix ut

* revert
  • Loading branch information
BalaBalaYi authored Jan 16, 2025
1 parent 7257b60 commit 227940f
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 9 deletions.
9 changes: 4 additions & 5 deletions dlrover/python/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,15 +358,14 @@ class JobConstant(object):
INSUFFICIENT_NODE_TIMEOUT_DEFAULT_MIN = 600
INSUFFICIENT_NODE_TIMEOUT_DEFAULT_MAX = 3600
PENDING_NODE_TIMEOUT_DEFAULT_MIN = 600
NODE_CHECK_TIMEOUT = 300

# grpc timeout 60s
MASTER_CLIENT_GRPC_DEFAULT_TIMEOUT = 60

# master_client.check_straggler timeout
MASTER_CLIENT_CHECK_FAULT_NODE_TIMEOUT = 300

# master_client.check_fault_node timeout
MASTER_CLIENT_CHECK_STRAGGLER_NODE_TIMEOUT = 300
# master_client.check_fault_node/check_straggler timeout value
# must > NODE_CHECK_TIMEOUT
MASTER_CLIENT_CHECK_NODE_TIMEOUT = 360

# sleep 1s on NetworkFailureReason.WAITING_NODE
MASTER_CLIENT_CHECK_FAULT_SLEEP_TIMEOUT = 1
Expand Down
2 changes: 1 addition & 1 deletion dlrover/python/elastic_agent/master_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def check_fault_node(self, timeout=300):
break
return result.nodes, result.reason

def check_straggler(self, timeout=3):
def check_straggler(self, timeout=300):
request = grpc.StragglerExistRequest()
start = time.time()
while True:
Expand Down
11 changes: 8 additions & 3 deletions dlrover/python/elastic_agent/torch/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,9 @@ def __init__(
self._check_round = check_round
self._config: ElasticLaunchConfig = config

def _get_check_node_timeout(self):
return JobConstant.MASTER_CLIENT_CHECK_NODE_TIMEOUT

def run(self, role: str = DEFAULT_ROLE) -> bool:
spec = self._worker_group.spec
role = spec.role
Expand All @@ -1409,7 +1412,9 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
fault_nodes = []
stragglers = []
for i in range(self._check_round):
result, elapsed_time = self._run_node_check()
result, elapsed_time = self._run_node_check(
timeout=JobConstant.NODE_CHECK_TIMEOUT
)
elapsed_time = round(elapsed_time, 3)
logger.info(
f"Network check time of round {i} is {elapsed_time}"
Expand All @@ -1427,10 +1432,10 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
)
success = success or result
fault_nodes, fault_reason = self._client.check_fault_node(
timeout=JobConstant.MASTER_CLIENT_CHECK_FAULT_NODE_TIMEOUT
timeout=self._get_check_node_timeout()
)
stragglers, straggler_reason = self._client.check_straggler(
timeout=JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_NODE_TIMEOUT
timeout=self._get_check_node_timeout()
)
logger.info(
f"Fault nodes are: {fault_nodes} with {fault_reason} "
Expand Down
15 changes: 15 additions & 0 deletions dlrover/python/tests/test_elastic_training_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,21 @@ def test_comm_perf_test(self, mock_run):
comm_perf_check(config, entrypoint, args)
mock_run.assert_called()

def test_get_check_node_timeout(self):
config = ElasticLaunchConfig(4, 4, 8)

agent = _create_check_agent(
config=config,
entrypoint="python",
args=[],
rdzv_name="elastic-training",
check_round=2,
)
self.assertEqual(
agent._get_check_node_timeout(),
JobConstant.MASTER_CLIENT_CHECK_NODE_TIMEOUT,
)


class MasterRendezvousHandlerTest(unittest.TestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit 227940f

Please sign in to comment.