Skip to content

Commit

Permalink
clean up time consuming stuff in unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Ma Jie Yue committed Jan 15, 2025
1 parent 612e8a7 commit 069ccbf
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 8 deletions.
10 changes: 8 additions & 2 deletions dlrover/python/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,17 @@ class JobConstant(object):
# grpc timeout 60s
MASTER_CLIENT_GRPC_DEFAULT_TIMEOUT = 60

# master_client.check_fault_node timeout (seconds)
MASTER_CLIENT_CHECK_FAULT_NODE_TIMEOUT = 300

# master_client.check_straggler timeout (seconds)
MASTER_CLIENT_CHECK_STRAGGLER_NODE_TIMEOUT = 300

# sleep interval (seconds) between check_fault_node retries on NetworkFailureReason.WAITING_NODE or NO_INIT
MASTER_CLIENT_CHECK_FAULT_TIMEOUT = 3
MASTER_CLIENT_CHECK_FAULT_SLEEP_TIMEOUT = 1

# sleep interval (seconds) between check_straggler retries on NetworkFailureReason.WAITING_NODE
MASTER_CLIENT_CHECK_STRAGGLER_TIMEOUT = 3
MASTER_CLIENT_CHECK_STRAGGLER_SLEEP_TIMEOUT = 1

# sleep 5s before next node check round
NODE_CHECK_NEXT_ROUND_TIMEOUT = 5
Expand Down
10 changes: 6 additions & 4 deletions dlrover/python/elastic_agent/master_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def get_comm_world(self, rdzv_name, node_rank):
result: grpc.RendezvousState = self._get(request)
return result.round, result.group, result.world

def check_fault_node(self, timeout=300):
def check_fault_node(self, timeout=3):
request = grpc.NetworkReadyRequest()
start = time.time()
while True:
Expand All @@ -397,12 +397,12 @@ def check_fault_node(self, timeout=300):
result.reason == NetworkFailureReason.WAITING_NODE
or result.reason == NetworkFailureReason.NO_INIT
) and time.time() - start < timeout:
time.sleep(JobConstant.MASTER_CLIENT_CHECK_FAULT_TIMEOUT)
time.sleep(JobConstant.MASTER_CLIENT_CHECK_FAULT_SLEEP_TIMEOUT)
continue
break
return result.nodes, result.reason

def check_straggler(self, timeout=300):
def check_straggler(self, timeout=3):
request = grpc.StragglerExistRequest()
start = time.time()
while True:
Expand All @@ -411,7 +411,9 @@ def check_straggler(self, timeout=300):
result.reason == NetworkFailureReason.WAITING_NODE
and time.time() - start < timeout
):
time.sleep(JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_TIMEOUT)
time.sleep(
JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_SLEEP_TIMEOUT
)
continue
break
return result.nodes, result.reason
Expand Down
8 changes: 6 additions & 2 deletions dlrover/python/elastic_agent/torch/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,8 +1426,12 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
elapsed_time,
)
success = success or result
fault_nodes, fault_reason = self._client.check_fault_node()
stragglers, straggler_reason = self._client.check_straggler()
fault_nodes, fault_reason = self._client.check_fault_node(
timeout=JobConstant.MASTER_CLIENT_CHECK_FAULT_NODE_TIMEOUT
)
stragglers, straggler_reason = self._client.check_straggler(
timeout=JobConstant.MASTER_CLIENT_CHECK_STRAGGLER_NODE_TIMEOUT
)
logger.info(
f"Fault nodes are: {fault_nodes} with {fault_reason} "
f" and stragglers are: {stragglers} with {straggler_reason}"
Expand Down
2 changes: 2 additions & 0 deletions dlrover/python/tests/test_elastic_training_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,8 @@ def test_run_agent(self):
)

# with no fault and no stragglers
agent._client.check_fault_node = mock.MagicMock(return_value=([], ""))
agent._client.check_straggler = mock.MagicMock(return_value=([], ""))
agent._run_node_check = mock.MagicMock(return_value=(True, 100))
agent._stop_workers = mock.MagicMock(return_value=True)
self.assertTrue(agent.run())
Expand Down

0 comments on commit 069ccbf

Please sign in to comment.