diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py index d6978bbbb..ebba0be08 100644 --- a/dlrover/python/tests/test_job_manager.py +++ b/dlrover/python/tests/test_job_manager.py @@ -240,6 +240,9 @@ def test_relaunch_node(self): manager = create_job_manager(params, SpeedMonitor()) self.assertEqual(manager._ps_relaunch_max_num, 1) manager.start() + + # reset failed nodes for testing + self.job_context._failed_nodes = {} self.assertEqual(manager._job_args.job_uuid, _MOCK_JOB_UUID) job_nodes = self.job_context.job_nodes() @@ -296,18 +299,18 @@ def test_relaunch_node(self): should_relaunch = manager._should_relaunch(node, NODE_STATE_FLOWS[6]) self.assertFalse(should_relaunch) - self.assertEqual(self.job_context.get_failed_node_cnt(), 2) + self.assertEqual(self.job_context.get_failed_node_cnt(), 0) manager.handle_training_failure( NodeType.WORKER, 0, level=TrainingExceptionLevel.NODE_ERROR ) manager.handle_training_failure( NodeType.WORKER, 0, level=TrainingExceptionLevel.NODE_ERROR ) - self.assertEqual(self.job_context.get_failed_node_cnt(), 3) + self.assertEqual(self.job_context.get_failed_node_cnt(), 1) manager.handle_training_failure( NodeType.WORKER, 1, level=TrainingExceptionLevel.NODE_ERROR ) - self.assertEqual(self.job_context.get_failed_node_cnt(), 3) + self.assertEqual(self.job_context.get_failed_node_cnt(), 2) def test_relaunch_under_deleted_event(self): params = MockK8sPSJobArgs()