Skip to content

Commit

Permalink
add train task lists
Browse files Browse the repository at this point in the history
  • Loading branch information
wzh1994 committed Oct 18, 2024
1 parent 68fff85 commit c374c92
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
9 changes: 7 additions & 2 deletions lazyllm/engine/lightengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ def start(self, nodes: List[Dict] = [], edges: List[Dict] = [], resources: List[
self.build_node(node).func.start()
return gid

def stop(self, id, task_name: Optional[str] = None):
node = self.build_node(id)
def status(self, node_id: str, task_name: Optional[str] = None):
node = self.build_node(node_id)
assert node.kind in ('LocalLLM')
return node.func.status(task_name=task_name)

def stop(self, node_id: str, task_name: Optional[str] = None):
node = self.build_node(node_id)
if task_name:
assert node.kind in ('LocalLLM')
node.func.stop(task_name=task_name)
Expand Down
5 changes: 5 additions & 0 deletions lazyllm/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def cleanup(self):
LOG.info(f"killed job:{k}")
self.all_processes.pop(self._id)

@property
def status(self):
assert len(self.all_processes[self._id]) == 1
return self.all_processes[self._id][0].status

def wait(self):
for _, v in self.all_processes[self._id]:
v.wait()
Expand Down
9 changes: 7 additions & 2 deletions lazyllm/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,9 @@ def _get_train_or_deploy_args(self, arg_cls: str, disable: List[str] = []):
if len(set(args.keys()).intersection(set(disable))) > 0:
raise ValueError(f'Key `{", ".join(disable)}` can not be set in '
'{arg_cls}_args, please pass them from Module.__init__()')
args['launcher'] = args['launcher'].clone() if args.get('launcher') else launchers.remote(sync=False)
self._launchers['default'][arg_cls] = args['launcher']
if 'url' not in args:
args['launcher'] = args['launcher'].clone() if args.get('launcher') else launchers.remote(sync=False)
self._launchers['default'][arg_cls] = args['launcher']
return args

def _get_train_tasks_impl(self, mode: Optional[str] = None, **kw):
Expand Down Expand Up @@ -666,6 +667,10 @@ def stop(self, task_name: Optional[str] = None):
launcher = self._impl._launchers['manual' if task_name else 'default'][task_name or 'deploy']
launcher.cleanup()

def status(self, task_name: Optional[str] = None):
launcher = self._impl._launchers['manual' if task_name else 'default'][task_name or 'deploy']
return launcher.status

# modify default value to ''
def prompt(self, prompt=''):
if self.base_model != '' and prompt == '' and ModelManager.get_model_type(self.base_model) != 'llm':
Expand Down

0 comments on commit c374c92

Please sign in to comment.