From 962c1bcac559af6dfa8509be7ea5054363032b4d Mon Sep 17 00:00:00 2001
From: wangzhihong
Date: Mon, 21 Oct 2024 19:16:19 +0800
Subject: [PATCH] fix formatter bug (#312)

---
Review note: in lazyllm/engine/engine.py the edge-formatter assertion must
accept a leading `{` (as the removed two-line assertion did); the prefix
tuple previously listed `}` by mistake, which rejected every `{...}` formatter.

 lazyllm/components/formatter/formatterbase.py | 16 ++++++++---
 lazyllm/engine/engine.py                      |  3 +-
 lazyllm/flow/flow.py                          |  5 ++--
 tests/basic_tests/test_engine.py              | 28 +++++++++++++++++++
 4 files changed, 44 insertions(+), 8 deletions(-)

diff --git a/lazyllm/components/formatter/formatterbase.py b/lazyllm/components/formatter/formatterbase.py
index 9d47172b..1afdd5d5 100644
--- a/lazyllm/components/formatter/formatterbase.py
+++ b/lazyllm/components/formatter/formatterbase.py
@@ -1,4 +1,5 @@
 from ...common import LazyLLMRegisterMetaClass, package
+from typing import Optional
 
 def is_number(s: str):
     try:
@@ -29,9 +30,16 @@ class JsonLikeFormatter(LazyLLMFormatterBase):
     class _ListIdxes(tuple): pass
     class _DictKeys(tuple): pass
 
-    def __init__(self, formatter: str = None):
-        self._formatter = formatter
+    def __init__(self, formatter: Optional[str] = None):
+        if formatter and formatter.startswith('*['):
+            self._return_package = True
+            self._formatter = formatter.strip('*')
+        else:
+            self._return_package = False
+            self._formatter = formatter
+
         if self._formatter:
+            assert '*' not in self._formatter, '`*` can only be used before `[` in the beginning'
             self._formatter = self._formatter.strip().replace('{', '[{').replace('}', '}]')
             self._parse_formatter()
         else:
@@ -82,14 +90,14 @@ def _impl(data, slice):
                 assert curr_slice.start is None and curr_slice.stop is None and curr_slice.step is None, (
                     'Only {:} and [:] is supported in dict slice')
                 curr_slice = __class__._ListIdxes(data.keys())
-            elif isinstance(data, list):
+            elif isinstance(data, (tuple, list)):
                 return type(data)(self._parse_py_data_by_formatter(d, slices=slices[1:])
                                   for d in _impl(data, curr_slice))
         if isinstance(curr_slice, __class__._DictKeys):
             return {k: self._parse_py_data_by_formatter(v, slices=slices[1:])
                     for k, v in _impl(data, curr_slice).items()}
         elif isinstance(curr_slice, __class__._ListIdxes):
-            tp = list if isinstance(data, dict) else type(data)
+            tp = package if self._return_package else list if isinstance(data, dict) else type(data)
             return tp(self._parse_py_data_by_formatter(r, slices=slices[1:]) for r in _impl(data, curr_slice))
         else:
             return self._parse_py_data_by_formatter(_impl(data, curr_slice), slices=slices[1:])
diff --git a/lazyllm/engine/engine.py b/lazyllm/engine/engine.py
index 7d952c16..1650d2b9 100644
--- a/lazyllm/engine/engine.py
+++ b/lazyllm/engine/engine.py
@@ -164,8 +164,7 @@ def make_graph(nodes: List[dict], edges: List[dict], resources: List[dict] = [],
 
     for edge in edges:
         if formatter := edge.get('formatter'):
-            assert formatter.startswith('[') and formatter.endswith(']') or \
-                   formatter.startswith('{') and formatter.endswith('}')
+            assert formatter.startswith(('*[', '[', '{')) and formatter.endswith((']', '}'))
             formatter = lazyllm.formatter.JsonLike(formatter)
         g.add_edge(engine._nodes[edge['iid']].name, engine._nodes[edge['oid']].name, formatter)
 
diff --git a/lazyllm/flow/flow.py b/lazyllm/flow/flow.py
index d9c26936..f80fa446 100644
--- a/lazyllm/flow/flow.py
+++ b/lazyllm/flow/flow.py
@@ -523,7 +523,7 @@ def add_edge(self, from_node, to_node, formatter=None):
     def topological_sort(self):
         in_degree = self._in_degree.copy()
         queue = deque([node for node in self._nodes.values() if in_degree[node] == 0])
-        sorted_nodes = []
+        sorted_nodes: List[Graph.Node] = []
 
         while queue:
             node = queue.popleft()
@@ -536,7 +536,7 @@ def topological_sort(self):
         if len(sorted_nodes) != len(self._nodes):
             raise ValueError("Graph has a cycle")
 
-        return sorted_nodes
+        return [n for n in sorted_nodes if (self._in_degree[n] > 0 or n.name == Graph.start_node_name)]
 
     def compute_node(self, sid, node, intermediate_results, futures):
         globals._init_sid(sid)
@@ -548,6 +548,7 @@ def get_input(name):
                     if name not in intermediate_results['values']:
                         intermediate_results['values'][name] = r
             r = intermediate_results['values'][name]
+            if isinstance(r, Exception): raise r
             if node.inputs[name]:
                 r = node.inputs[name]((r.args or r.kw) if isinstance(r, arguments) else r)
             return r
diff --git a/tests/basic_tests/test_engine.py b/tests/basic_tests/test_engine.py
index 45306220..fa4d4168 100644
--- a/tests/basic_tests/test_engine.py
+++ b/tests/basic_tests/test_engine.py
@@ -149,6 +149,34 @@ def test_engine_edge_formatter_start(self):
         assert engine.run(3, 1) == 5
         assert engine.run(5, 3, 1) == 11
 
+    def test_engine_formatter_end(self):
+        nodes = [dict(id='1', kind='Code', name='m1', args='def test(x: int):\n return x\n'),
+                 dict(id='2', kind='Code', name='m2', args='def test1(x: int):\n return [[x, 2*x], [3*x, 4*x]]\n'),
+                 # two unused node
+                 dict(id='3', kind='Code', name='m3', args='def test2(x: int):\n return dict(a=1, b=x * x)\n'),
+                 dict(id='4', kind='Code', name='m4', args='def test3(x, y, z):\n return f"{x}{y}{z}"\n')]
+        edges = [dict(iid='__start__', oid='1'), dict(iid='__start__', oid='2'), dict(iid='2', oid='__end__'),
+                 dict(iid='1', oid='__end__')]
+
+        engine = LightEngine()
+        engine.start(nodes, edges)
+        r = engine.run(1)
+        print(r, type(r))
+        print(isinstance(r, lazyllm.package))
+
+        engine.reset()
+
+        nodes = [dict(id='1', kind='Code', name='m1', args='def test(x: int):\n return x\n'),
+                 dict(id='2', kind='Code', name='m2', args='def test1(x: int):\n return [[x, 2*x], [3*x, 4*x]]\n'),
+                 dict(id='3', kind='JoinFormatter', name='join', args=dict(type='to_dict', names=['a', 'b']))]
+        edges = [dict(iid='__start__', oid='1'), dict(iid='__start__', oid='2'), dict(iid='2', oid='3'),
+                 dict(iid='1', oid='3'), dict(iid='3', oid='__end__', formatter='*[a, b]')]
+        engine = LightEngine()
+        engine.start(nodes, edges)
+        r = engine.run(1)
+        print(r, type(r))
+        print(isinstance(r, lazyllm.package))
+
     def test_engine_join_stack(self):
         nodes = [dict(id='0', kind='Code', name='c1', args='def test(x: int): return x'),
                  dict(id='1', kind='JoinFormatter', name='join', args=dict(type='stack'))]