Skip to content

Commit

Permalink
Added support for streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Mar 20, 2024
1 parent 12c0e67 commit c3473d9
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 26 deletions.
120 changes: 100 additions & 20 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import inspect
import json
import os
import queue
import readline
import shutil
import threading
import uuid
from enum import Enum
from typing import List, TypedDict, Callable, Any, Dict, Literal, Union
Expand Down Expand Up @@ -92,7 +94,6 @@ def __init__(self,
self._init_agents()
self._init_threads()


def get_completion(self, message: str, message_files=None, yield_messages=True, recipient_agent=None):
"""
Retrieves the completion for a given message from the main thread.
Expand All @@ -118,26 +119,30 @@ def get_completion(self, message: str, message_files=None, yield_messages=True,

return gen

def get_completion_stream(self, message: str, event_handler: type(AgencyEventHandler), message_files=None,
                          recipient_agent=None):
    """
    Generates a stream of completions for a given message from the main thread.

    Streaming output is delivered through the callbacks of `event_handler`; this method itself
    blocks until the stream is exhausted and then returns the final response.

    Parameters:
        message (str): The message for which completion is to be retrieved.
        event_handler (type(AgencyEventHandler)): The event handler class to handle the completion stream. https://github.com/openai/openai-python/blob/main/helpers.md
        message_files (list, optional): A list of file ids to be sent as attachments with the message. Defaults to None.
        recipient_agent (Agent, optional): The agent to which the message should be sent. Defaults to the first agent in the agency chart.

    Returns:
        Final response: Final response from the main thread.

    Raises:
        Exception: If the agency is in async mode, where streaming is not supported.
    """
    if self.async_mode:
        raise Exception("Streaming is not supported in async mode.")

    gen = self.main_thread.get_completion_stream(message=message, event_handler=event_handler,
                                                 message_files=message_files, recipient_agent=recipient_agent)

    # Drive the generator manually: its return value (the final response) is only
    # available on the StopIteration raised when it finishes.
    while True:
        try:
            next(gen)
        except StopIteration as e:
            # Notify the handler class (not an instance) that every stream has finished.
            event_handler.on_all_streams_end()
            return e.value

def demo_gradio(self, height=450, dark_mode=True, share=False):
Expand Down Expand Up @@ -173,6 +178,7 @@ def demo_gradio(self, height=450, dark_mode=True, share=False):
recipient_agent = self.main_recipients[0]

with gr.Blocks(js=js) as demo:
chatbot_queue = queue.Queue()
chatbot = gr.Chatbot(height=height)
with gr.Row():
with gr.Column(scale=9):
Expand Down Expand Up @@ -229,29 +235,90 @@ def user(user_message, history):

return original_user_message, history + [[user_message, None]]

class GradioEventHandler(AgencyEventHandler):
    """
    Event handler that forwards streamed output to the Gradio chat via `chatbot_queue`.

    Every callback pushes plain strings onto `chatbot_queue`. Two sentinel strings are used:
    "[new_message]" marks the start of a new chat entry and "[end]" marks the end of all streams.
    """
    # Most recently created MessageOutput; used to build the header of each chat entry.
    message_output = None

    @override
    def on_message_created(self, message: Message) -> None:
        """Queue the header of a newly created message."""
        if message.role == "user":
            self.message_output = MessageOutput("text", self.agent_name, self.recipient_agent_name,
                                                "")
        else:
            self.message_output = MessageOutput("text", self.recipient_agent_name, self.agent_name, "")
            # NOTE(review): the sentinel is assumed to belong to this branch only - confirm
            # against the original indentation.
            chatbot_queue.put("[new_message]")

        chatbot_queue.put(self.message_output.get_formatted_header() + "\n")

    @override
    def on_text_delta(self, delta, snapshot):
        """Queue each streamed text fragment as it arrives."""
        chatbot_queue.put(delta.value)

    @override
    def on_tool_call_created(self, tool_call):
        """Start a new chat entry for a tool call and queue its header."""
        chatbot_queue.put("[new_message]")
        self.message_output = MessageOutput("function", self.recipient_agent_name, self.agent_name,
                                            str(tool_call.function))

        chatbot_queue.put(self.message_output.get_formatted_header() + "\n")

    @override
    def on_tool_call_done(self, snapshot):
        """Queue the string form of the completed tool call's function."""
        chatbot_queue.put(str(snapshot.function))

    @override
    def on_run_step_done(self, run_step: RunStep) -> None:
        """Queue the output of every finished tool call except SendMessage."""
        if run_step.type == "tool_calls":
            for tool_call in run_step.step_details.tool_calls:
                if tool_call.function.name == "SendMessage":
                    continue

                self.message_output = MessageOutput("function_output", tool_call.function.name,
                                                    self.recipient_agent_name,
                                                    tool_call.function.output)

                chatbot_queue.put(self.message_output.get_formatted_header() + "\n")
                chatbot_queue.put(tool_call.function.output)

    @override
    @classmethod
    def on_all_streams_end(cls):
        """Reset the handler state and queue the "[end]" sentinel."""
        # This is a classmethod, so there is no `self`; referencing it raised NameError
        # before "[end]" could be queued. Reset the class-level attribute instead.
        cls.message_output = None
        chatbot_queue.put("[end]")

def bot(original_message, history):
nonlocal message_file_ids
nonlocal message_file_names
nonlocal recipient_agent
print("Message files: ", message_file_ids)
# Replace this with your actual chatbot logic
gen = self.get_completion(message=original_message, message_files=message_file_ids, recipient_agent=recipient_agent)

completion_thread = threading.Thread(target=self.get_completion_stream, args=(original_message, GradioEventHandler, message_file_ids, recipient_agent))
completion_thread.start()

message_file_ids = []
message_file_names = []
try:
# Yield each message from the generator
for bot_message in gen:
if bot_message.sender_name.lower() == "user":

new_message = True
while True:
try:
bot_message = chatbot_queue.get(block=True, timeout=10)

if bot_message == "[end]":
completion_thread.join()
break

if bot_message == "[new_message]":
new_message = True
continue

message = bot_message.get_sender_emoji() + " " + bot_message.get_formatted_content()
if new_message:
history.append([None, bot_message])
new_message = False
else:
history[-1][1] += bot_message

history.append((None, message))
yield "", history
except StopIteration:
# Handle the end of the conversation if necessary

pass
except queue.Empty:
break

button.click(
user,
Expand Down Expand Up @@ -287,26 +354,35 @@ def _setup_autocomplete(self):
"""
Sets up readline with the completer function.
"""
self.recipient_agents = [agent.name for agent in self.main_recipients] # Cache recipient agents for autocomplete
self.recipient_agents = [agent.name for agent in
self.main_recipients] # Cache recipient agents for autocomplete
readline.set_completer(self._recipient_agent_completer)
readline.parse_and_bind('tab: complete')

def run_demo(self):
"""
Executes agency in the terminal with autocomplete for recipient agent names.
"""

class TermEventHandler(AgencyEventHandler):
message_output = None

@override
def on_message_created(self, message: Message) -> None:
if message.role == "user":
self.message_output = MessageOutputLive("text", self.agent_name, self.recipient_agent_name,
"")
"")
else:
self.message_output = MessageOutputLive("text", self.recipient_agent_name, self.agent_name, "")

@override
def on_message_done(self, message: Message) -> None:
self.message_output = None

@override
def on_text_delta(self, delta, snapshot):
self.message_output.cprint_update(snapshot.value)

@override
def on_tool_call_created(self, tool_call):
self.message_output = MessageOutputLive("function", self.recipient_agent_name, self.agent_name,
Expand All @@ -316,6 +392,10 @@ def on_tool_call_created(self, tool_call):
def on_tool_call_delta(self, delta, snapshot):
self.message_output.cprint_update(str(snapshot.function))

@override
def on_tool_call_done(self, snapshot):
self.message_output = None

@override
def on_run_step_done(self, run_step: RunStep) -> None:
if run_step.type == "tool_calls":
Expand Down Expand Up @@ -348,7 +428,8 @@ def on_end(self):
recipient_agent = text.split("@")[1].split(" ")[0]
text = text.replace(f"@{recipient_agent}", "").strip()
try:
recipient_agent = [agent for agent in self.recipient_agents if agent.lower() == recipient_agent.lower()][0]
recipient_agent = \
[agent for agent in self.recipient_agents if agent.lower() == recipient_agent.lower()][0]
recipient_agent = self._get_agent_by_name(recipient_agent)
except Exception as e:
print(f"Recipient agent {recipient_agent} not found.")
Expand Down Expand Up @@ -492,7 +573,7 @@ def _parse_agency_chart(self, agency_chart):
self.agents_and_threads[agent.name] = {}

if i < len(node) - 1:
other_agent = node[i+1]
other_agent = node[i + 1]
if other_agent.name == agent.name:
continue
if other_agent.name not in self.agents_and_threads[agent.name].keys():
Expand Down Expand Up @@ -633,8 +714,7 @@ def run(self):
except StopIteration as e:
message = e.value
else:
message = thread.get_completion(message=self.message, message_files=self.message_files,
event_handler=self.event_handler)
message = thread.get_completion_async(message=self.message, message_files=self.message_files)

return message or ""

Expand Down
43 changes: 43 additions & 0 deletions docs/advanced-usage/agencies.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,49 @@ agency = Agency([

All agents added to the top-level list of `agency_chart`, without being part of a nested list, can talk to the user.

## Streaming Responses

To stream the conversation between agents, use the `get_completion_stream` method with your own event handler, as shown below. The process is very similar to the one in the [official documentation](https://platform.openai.com/docs/assistants/overview/step-4-create-a-run?context=with-streaming).

The only difference is that you must extend the `AgencyEventHandler` class, which has two additional properties, `agent_name` and `recipient_agent_name`, that give the names of the agents communicating with each other. (See `on_text_created` below.)


```python
from typing_extensions import override
from agency_swarm.lib.streaming import AgencyEventHandler

class EventHandler(AgencyEventHandler):
    """Example handler that prints the streamed conversation to the terminal."""

    @override
    def on_text_created(self, text) -> None:
        # get the name of the agent that is sending the message
        print(f"\n{self.recipient_agent_name} @ {self.agent_name} > ", end="", flush=True)

    @override
    def on_text_delta(self, delta, snapshot):
        # print each text fragment as soon as it arrives
        print(delta.value, end="", flush=True)

    @override
    def on_tool_call_created(self, tool_call):
        print(f"\n{self.recipient_agent_name} > {tool_call.type}\n", flush=True)

    @override
    def on_tool_call_delta(self, delta, snapshot):
        # only code interpreter calls are printed here
        if delta.type == 'code_interpreter':
            if delta.code_interpreter.input:
                print(delta.code_interpreter.input, end="", flush=True)
            if delta.code_interpreter.outputs:
                print("\n\noutput >", flush=True)
                for output in delta.code_interpreter.outputs:
                    if output.type == "logs":
                        print(f"\n{output.logs}", flush=True)

    @classmethod
    def on_all_streams_end(cls):
        print("\n\nAll streams have ended.")  # Conversation is over and message is returned to the user.

response = agency.get_completion_stream("I want you to build me a website", event_handler=EventHandler)
```

In addition, there is a class method, `on_all_streams_end`, which is called when all streams have ended. This method is needed because, unlike in the official documentation, your event handler will be called multiple times, possibly by multiple agents.

## Asynchronous Communication

If you would like to use asynchronous communication between agents, you can specify an `async_mode` parameter. This is useful when you want your agents to execute multiple tasks concurrently. Only `threading` mode is supported for now.
Expand Down
11 changes: 5 additions & 6 deletions tests/demos/streaming_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@ class TestTool(BaseTool):
def run(self):
return "Test Successful"


self.ceo = Agent(name="ceo", instructions="You are a CEO of an agency made for testing purposes.")
self.test_agent1 = Agent(name="test_agent1", tools=[TestTool])
self.test_agent2 = Agent(name="test_agent2")

self.ceo = Agent(name="ceo", instructions="You are a CEO of an agency made for testing purposes.",
model='gpt-3.5-turbo')
self.test_agent1 = Agent(name="test_agent1", tools=[TestTool], model='gpt-3.5-turbo')
self.test_agent2 = Agent(name="test_agent2", model='gpt-3.5-turbo')

self.agency = Agency([
self.ceo,
Expand All @@ -29,7 +28,7 @@ def run(self):
])

def test_demo(self):
self.agency.demo_gradio()
self.agency.run_demo()


if __name__ == '__main__':
Expand Down

0 comments on commit c3473d9

Please sign in to comment.