Skip to content

Commit

Permalink
fix formatter bug (LazyAGI#312)
Browse files Browse the repository at this point in the history
  • Loading branch information
wzh1994 authored Oct 21, 2024
1 parent 7cfe2b4 commit 962c1bc
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 8 deletions.
16 changes: 12 additions & 4 deletions lazyllm/components/formatter/formatterbase.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ...common import LazyLLMRegisterMetaClass, package
from typing import Optional

def is_number(s: str):
try:
Expand Down Expand Up @@ -29,9 +30,16 @@ class JsonLikeFormatter(LazyLLMFormatterBase):
class _ListIdxes(tuple): pass
class _DictKeys(tuple): pass

def __init__(self, formatter: str = None):
self._formatter = formatter
def __init__(self, formatter: Optional[str] = None):
if formatter and formatter.startswith('*['):
self._return_package = True
self._formatter = formatter.strip('*')
else:
self._return_package = False
self._formatter = formatter

if self._formatter:
assert '*' not in self._formatter, '`*` can only be used before `[` in the beginning'
self._formatter = self._formatter.strip().replace('{', '[{').replace('}', '}]')
self._parse_formatter()
else:
Expand Down Expand Up @@ -82,14 +90,14 @@ def _impl(data, slice):
assert curr_slice.start is None and curr_slice.stop is None and curr_slice.step is None, (
'Only {:} and [:] is supported in dict slice')
curr_slice = __class__._ListIdxes(data.keys())
elif isinstance(data, list):
elif isinstance(data, (tuple, list)):
return type(data)(self._parse_py_data_by_formatter(d, slices=slices[1:])
for d in _impl(data, curr_slice))
if isinstance(curr_slice, __class__._DictKeys):
return {k: self._parse_py_data_by_formatter(v, slices=slices[1:])
for k, v in _impl(data, curr_slice).items()}
elif isinstance(curr_slice, __class__._ListIdxes):
tp = list if isinstance(data, dict) else type(data)
tp = package if self._return_package else list if isinstance(data, dict) else type(data)
return tp(self._parse_py_data_by_formatter(r, slices=slices[1:]) for r in _impl(data, curr_slice))
else: return self._parse_py_data_by_formatter(_impl(data, curr_slice), slices=slices[1:])

Expand Down
3 changes: 1 addition & 2 deletions lazyllm/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ def make_graph(nodes: List[dict], edges: List[dict], resources: List[dict] = [],

for edge in edges:
if formatter := edge.get('formatter'):
assert formatter.startswith('[') and formatter.endswith(']') or \
formatter.startswith('{') and formatter.endswith('}')
assert formatter.startswith(('*[', '[', '}')) and formatter.endswith((']', '}'))
formatter = lazyllm.formatter.JsonLike(formatter)
g.add_edge(engine._nodes[edge['iid']].name, engine._nodes[edge['oid']].name, formatter)

Expand Down
5 changes: 3 additions & 2 deletions lazyllm/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def add_edge(self, from_node, to_node, formatter=None):
def topological_sort(self):
in_degree = self._in_degree.copy()
queue = deque([node for node in self._nodes.values() if in_degree[node] == 0])
sorted_nodes = []
sorted_nodes: List[Graph.Node] = []

while queue:
node = queue.popleft()
Expand All @@ -536,7 +536,7 @@ def topological_sort(self):
if len(sorted_nodes) != len(self._nodes):
raise ValueError("Graph has a cycle")

return sorted_nodes
return [n for n in sorted_nodes if (self._in_degree[n] > 0 or n.name == Graph.start_node_name)]

def compute_node(self, sid, node, intermediate_results, futures):
globals._init_sid(sid)
Expand All @@ -548,6 +548,7 @@ def get_input(name):
if name not in intermediate_results['values']:
intermediate_results['values'][name] = r
r = intermediate_results['values'][name]
if isinstance(r, Exception): raise r
if node.inputs[name]:
r = node.inputs[name]((r.args or r.kw) if isinstance(r, arguments) else r)
return r
Expand Down
28 changes: 28 additions & 0 deletions tests/basic_tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,34 @@ def test_engine_edge_formatter_start(self):
assert engine.run(3, 1) == 5
assert engine.run(5, 3, 1) == 11

def test_engine_formatter_end(self):
nodes = [dict(id='1', kind='Code', name='m1', args='def test(x: int):\n return x\n'),
dict(id='2', kind='Code', name='m2', args='def test1(x: int):\n return [[x, 2*x], [3*x, 4*x]]\n'),
# two unused node
dict(id='3', kind='Code', name='m3', args='def test2(x: int):\n return dict(a=1, b=x * x)\n'),
dict(id='4', kind='Code', name='m4', args='def test3(x, y, z):\n return f"{x}{y}{z}"\n')]
edges = [dict(iid='__start__', oid='1'), dict(iid='__start__', oid='2'), dict(iid='2', oid='__end__'),
dict(iid='1', oid='__end__')]

engine = LightEngine()
engine.start(nodes, edges)
r = engine.run(1)
print(r, type(r))
print(isinstance(r, lazyllm.package))

engine.reset()

nodes = [dict(id='1', kind='Code', name='m1', args='def test(x: int):\n return x\n'),
dict(id='2', kind='Code', name='m2', args='def test1(x: int):\n return [[x, 2*x], [3*x, 4*x]]\n'),
dict(id='3', kind='JoinFormatter', name='join', args=dict(type='to_dict', names=['a', 'b']))]
edges = [dict(iid='__start__', oid='1'), dict(iid='__start__', oid='2'), dict(iid='2', oid='3'),
dict(iid='1', oid='3'), dict(iid='3', oid='__end__', formatter='*[a, b]')]
engine = LightEngine()
engine.start(nodes, edges)
r = engine.run(1)
print(r, type(r))
print(isinstance(r, lazyllm.package))

def test_engine_join_stack(self):
nodes = [dict(id='0', kind='Code', name='c1', args='def test(x: int): return x'),
dict(id='1', kind='JoinFormatter', name='join', args=dict(type='stack'))]
Expand Down

0 comments on commit 962c1bc

Please sign in to comment.