Skip to content

Commit

Permalink
Added support for async agency streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed May 7, 2024
1 parent b7ee514 commit 84141ac
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
3 changes: 0 additions & 3 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,6 @@ def get_completion_stream(self,
Returns:
Final response: Final response from the main thread.
"""
if self.async_mode:
raise Exception("Streaming is not supported in async mode.")

if not inspect.isclass(event_handler):
raise Exception("Event handler must not be an instance.")

Expand Down
31 changes: 29 additions & 2 deletions tests/test_agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import unittest

from openai.types.beta.threads import Text
from openai.types.beta.threads.runs import ToolCall

from agency_swarm.tools import CodeInterpreter, FileSearch
Expand Down Expand Up @@ -387,10 +388,36 @@ def test_8_async_agent_communication(self):

time.sleep(10)

message = self.__class__.agency.get_completion(
num_on_all_streams_end_calls = 0
delta_value = ""
full_text = ""

class EventHandler(AgencyEventHandler):
@override
def on_text_delta(self, delta, snapshot):
nonlocal delta_value
delta_value += delta.value

@override
def on_text_done(self, text: Text) -> None:
nonlocal full_text
full_text += text.value

@override
@classmethod
def on_all_streams_end(cls):
nonlocal num_on_all_streams_end_calls
num_on_all_streams_end_calls += 1

message = self.__class__.agency.get_completion_stream(
"Please check response. If output includes `TestAgent2's Response`, say 'success'. If the function output does not include `TestAgent2's Response`, or if you get a System Notification, or an error instead, say 'error'.",
tool_choice={"type": "function", "function": {"name": "GetResponse"}},
recipient_agent=self.__class__.agent1)
recipient_agent=self.__class__.agent1,
event_handler=EventHandler)

self.assertTrue(num_on_all_streams_end_calls == 1)

self.assertTrue(delta_value == full_text == message)

if 'error' in message.lower():
print(self.__class__.agency.get_completion("Explain why you said error."))
Expand Down

0 comments on commit 84141ac

Please sign in to comment.