diff --git a/agency_swarm/__init__.py b/agency_swarm/__init__.py index 9c6d4e1d..5416c1d5 100644 --- a/agency_swarm/__init__.py +++ b/agency_swarm/__init__.py @@ -4,3 +4,4 @@ from .util import set_openai_key from .util import set_openai_client from .util import get_openai_client +from .util.streaming import AgencyEventHandler diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 9eda4228..342c1e6d 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -3,17 +3,16 @@ import os import queue import readline -import shutil import threading import uuid from enum import Enum from typing import List, TypedDict, Callable, Any, Dict, Literal, Union -from openai.types.beta import AssistantStreamEvent from openai.types.beta.threads import Message from openai.types.beta.threads.runs import RunStep from pydantic import Field, field_validator from rich.console import Console +from typing_extensions import override from agency_swarm.agents import Agent from agency_swarm.messages import MessageOutput @@ -22,8 +21,7 @@ from agency_swarm.tools import BaseTool from agency_swarm.user import User -from agency_swarm.lib.streaming import AgencyEventHandler -from typing_extensions import override +from agency_swarm.util.streaming import AgencyEventHandler console = Console() @@ -155,6 +153,7 @@ def demo_gradio(self, height=450, dark_mode=True, share=False): share (bool, optional): Flag to determine if the interface should be shared publicly. Default is False. This method sets up and runs a Gradio interface, allowing users to interact with the agency's chatbot. It includes a text input for the user's messages and a chatbot interface for displaying the conversation. The method handles user input and chatbot responses, updating the interface dynamically. """ + try: import gradio as gr except ImportError: @@ -363,7 +362,7 @@ def run_demo(self): """ Executes agency in the terminal with autocomplete for recipient agent names. """ - + from agency_swarm import AgencyEventHandler class TermEventHandler(AgencyEventHandler): message_output = None diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 0289e75e..7f2f0910 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -1,17 +1,14 @@ -import copy import inspect import time from typing import Literal from openai import BadRequestError -from openai.types.beta.threads.runs import ToolCall +from agency_swarm.util.streaming import AgencyEventHandler from agency_swarm.agents import Agent from agency_swarm.messages import MessageOutput from agency_swarm.user import User from agency_swarm.util.oai import get_openai_client -from agency_swarm.lib.streaming import AgencyEventHandler -from typing_extensions import override class Thread: diff --git a/agency_swarm/util/__init__.py b/agency_swarm/util/__init__.py index e2e69c04..02f69b22 100644 --- a/agency_swarm/util/__init__.py +++ b/agency_swarm/util/__init__.py @@ -1,2 +1,2 @@ from .create_agent_template import create_agent_template -from .oai import set_openai_key, get_openai_client, set_openai_client +from .oai import set_openai_key, get_openai_client, set_openai_client \ No newline at end of file diff --git a/agency_swarm/util/streaming.py b/agency_swarm/util/streaming.py new file mode 100644 index 00000000..249d8a50 --- /dev/null +++ b/agency_swarm/util/streaming.py @@ -0,0 +1,14 @@ +from abc import ABC + +from openai.lib.streaming import AssistantEventHandler + + +class AgencyEventHandler(AssistantEventHandler, ABC): + agent_name = None + recipient_agent_name = None + + @classmethod + def on_all_streams_end(cls): + """Fires when streams for all agents have ended, as there can be multiple if you're agents are communicating + with each other or using tools.""" + pass diff --git a/docs/advanced-usage/agencies.md b/docs/advanced-usage/agencies.md index e9d96393..8fdeb411 100644 --- a/docs/advanced-usage/agencies.md +++ b/docs/advanced-usage/agencies.md @@ -39,7 +39,7 @@ The only difference is that you must extend the `AgencyEventHandler` class, whic ```python from typing_extensions import override -from agency_swarm.lib.streaming import AgencyEventHandler +from agency_swarm import AgencyEventHandler class EventHandler(AgencyEventHandler): @override diff --git a/tests/demos/streaming_demo.py b/tests/demos/streaming_demo.py index c5f50c05..2a190f79 100644 --- a/tests/demos/streaming_demo.py +++ b/tests/demos/streaming_demo.py @@ -1,9 +1,6 @@ import sys import unittest -from agency_swarm.lib.streaming import AgencyEventHandler -from typing_extensions import override - from agency_swarm import Agent, BaseTool from agency_swarm.agency.agency import Agency