From ba305dc91f78471776038c191d974e7e9ee49c90 Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Mon, 6 May 2024 21:06:28 +0400 Subject: [PATCH] refactor callbacks --- libs/superagent/app/agents/llm.py | 251 ++++++++++++++++--------- libs/superagent/app/api/agents.py | 6 +- libs/superagent/app/api/workflows.py | 109 +++++------ libs/superagent/app/utils/callbacks.py | 158 +++++++++++++--- libs/superagent/app/workflows/base.py | 35 ++-- 5 files changed, 374 insertions(+), 185 deletions(-) diff --git a/libs/superagent/app/agents/llm.py b/libs/superagent/app/agents/llm.py index 63856b348..a21293f1a 100644 --- a/libs/superagent/app/agents/llm.py +++ b/libs/superagent/app/agents/llm.py @@ -1,10 +1,9 @@ -import asyncio import datetime import json import logging from dataclasses import dataclass from functools import cached_property, partial -from typing import Any, cast +from typing import Any, AsyncIterator, Dict, List, Literal, cast from decouple import config from langchain_core.agents import AgentActionMessageLog @@ -33,7 +32,7 @@ from app.memory.memory_stores.redis import RedisMemoryStore from app.memory.message import MessageType from app.tools import get_tools -from app.utils.callbacks import CustomAsyncIteratorCallbackHandler +from app.utils.callbacks import AsyncCallbackHandler from app.utils.prisma import prisma from prisma.enums import LLMProvider from prisma.models import Agent @@ -127,24 +126,119 @@ async def execute_tool( ) -class LLMAgent(AgentBase): - _streaming_callback: CustomAsyncIteratorCallbackHandler +class MemoryCallbackHandler(AsyncCallbackHandler): + def __init__(self, *args, memory: BufferMemory, **kwargs): + super().__init__(*args, **kwargs) + self.memory = memory - @property - def streaming_callback(self): - return self._streaming_callback + async def on_tool_end( + self, + output: str, + ): + return await self.memory.aadd_message( + message=BaseMessage( + type=MessageType.TOOL_RESULT, + content=output, + ) + ) + + async def on_tool_start(self, serialized: Dict[str, Any]): + return await self.memory.aadd_message( + message=BaseMessage( + type=MessageType.TOOL_CALL, + content=serialized, + ) + ) + + async def on_tool_error(self, error: BaseException): + return await self.memory.aadd_message( + message=BaseMessage( + type=MessageType.TOOL_RESULT, + content=str(error), + ) + ) + + async def on_llm_end(self, response: str): + return await self.memory.aadd_message( + message=BaseMessage( + type=MessageType.AI, + content=response, + ) + ) - # TODO: call all callbacks in the list, don't distinguish between them - def _set_streaming_callback( - self, callbacks: list[CustomAsyncIteratorCallbackHandler] + async def on_human_message(self, input: str): + return await self.memory.aadd_message( + message=BaseMessage( + type=MessageType.HUMAN, + content=input, + ) + ) + + async def on_llm_new_token(self, token: str): + pass + + async def on_llm_start(self, prompt: str): + pass + + async def on_agent_finish(self, content: str): + pass + + async def aiter(self) -> AsyncIterator[str]: + pass + + async def on_chain_end(self): + pass + + async def on_chain_start(self): + pass + + async def on_llm_error(self, response: str): + pass + + +class CallbackManager: + def __init__(self, callbacks: List[AsyncCallbackHandler]): + self.callbacks = callbacks + + async def on_event( + self, + event: Literal[ + "on_llm_start", + "on_llm_end", + "on_tool_start", + "on_tool_end", + "on_tool_error", + "on_llm_new_token", + "on_chain_start", + "on_chain_end", + "on_llm_error", + "on_human_message", + 
"on_agent_finish", + ], + *args, + **kwargs, ): - for callback in callbacks: - if isinstance(callback, CustomAsyncIteratorCallbackHandler): - self._streaming_callback = callback - break + for callback in self.callbacks: + if hasattr(callback, event): + await getattr(callback, event)(*args, **kwargs) - if not self._streaming_callback: - raise Exception("Streaming Callback not found") + +class LLMAgent(AgentBase): + def __init__(self, *args, **kwargs): + super().__init__( + session_id=kwargs.get("session_id"), + agent_data=kwargs.get("agent_data"), + output_schema=kwargs.get("output_schema"), + callbacks=kwargs.get("callbacks"), + enable_streaming=kwargs.get("enable_streaming"), + llm_data=kwargs.get("llm_data"), + ) + self._callback_manager = CallbackManager( + callbacks=[ + MemoryCallbackHandler(memory=self.memory), + *self.callbacks, + ] + ) @cached_property def tools(self): @@ -214,10 +308,10 @@ async def _stream_text_by_lines(self, output: str): output_by_lines = output.split("\n") if len(output_by_lines) > 1: for line in output_by_lines: - await self.streaming_callback.on_llm_new_token(line) - await self.streaming_callback.on_llm_new_token("\n") + await self._callback_manager.on_event("on_llm_new_token", line) + await self._callback_manager.on_event("on_llm_new_token", "\n") else: - await self.streaming_callback.on_llm_new_token(output_by_lines[0]) + await self._callback_manager.on_event("on_llm_new_token", output) async def get_agent(self): if self._supports_tool_calling: @@ -237,7 +331,6 @@ def __init__( super().__init__( **kwargs, ) - self._streaming_callback = None self._intermediate_steps = [] NON_STREAMING_TOOL_PROVIDERS = [ @@ -278,7 +371,11 @@ async def _execute_tools( if tool_call_res.return_direct: if self.enable_streaming: await self._stream_text_by_lines(tool_call_res.result) - self.streaming_callback.done.set() + # self._callback_manager.on_event( + # "on_llm_end", response=tool_call_res.result + # ) + await self._callback_manager.on_event("on_chain_end") + return tool_call_res.result self.messages = messages @@ -341,7 +438,9 @@ async def _process_stream_response(self, res: CustomStreamWrapper): if new_message.content: output += new_message.content if self._can_stream_directly: - await self.streaming_callback.on_llm_new_token(new_message.content) + await self._callback_manager.on_event( + "on_llm_new_token", token=new_message.content + ) chunks.append(chunk) model_response = stream_chunk_builder(chunks=chunks) @@ -361,7 +460,9 @@ async def _process_model_response(self, res: ModelResponse): new_messages.append(new_message.dict()) if new_message.content and self._can_stream_directly: - await self.streaming_callback.on_llm_new_token(new_message.content) + await self._callback_manager.on_event( + "on_llm_new_token", token=new_message.content + ) return (tool_calls, new_messages, new_message.content) @@ -369,7 +470,7 @@ async def _acompletion(self, depth: int = 0, **kwargs) -> Any: logger.info(f"Calling LLM with kwargs: {kwargs}") if kwargs.get("stream"): - await self.streaming_callback.on_llm_start() + await self._callback_manager.on_event("on_llm_start", prompt=self.prompt) # TODO: Remove this when Groq and Bedrock supports streaming with tools if ( @@ -390,25 +491,13 @@ async def _acompletion(self, depth: int = 0, **kwargs) -> Any: tool_calls, new_messages, output = result if output: - await self.memory.aadd_message( - message=BaseMessage( - type=MessageType.AI, - content=output, - ) - ) + await self._callback_manager.on_event("on_llm_end", response=output) if tool_calls: - 
await asyncio.gather( - *[ - self.memory.aadd_message( - message=BaseMessage( - type=MessageType.TOOL_CALL, - content=tool_call.json(), - ) - ) - for tool_call in tool_calls - ] - ) + for tool_call in tool_calls: + await self._callback_manager.on_event( + "on_tool_start", serialized=tool_call.json() + ) self.messages = new_messages @@ -429,12 +518,11 @@ async def _acompletion(self, depth: int = 0, **kwargs) -> Any: if not self._can_stream_directly and self.enable_streaming: await self._stream_text_by_lines(output) - if self.enable_streaming: - self.streaming_callback.done.set() + await self._callback_manager.on_event("on_chain_end") return output - async def ainvoke(self, input, *_, **kwargs): + async def ainvoke(self, input, **kwargs): self.input = input self.messages = [ { @@ -447,15 +535,12 @@ async def ainvoke(self, input, *_, **kwargs): }, ] - await self.memory.aadd_message( - message=BaseMessage( - type=MessageType.HUMAN, - content=self.input, - ) - ) + await self._callback_manager.on_event("on_human_message", input=self.input) - if self.enable_streaming: - self._set_streaming_callback(kwargs.get("config", {}).get("callbacks", [])) + # for callback in self.callbacks: + # if hasattr(callback, "on_human_message"): + # print("my_hey_callback", callback, callback.on_human_message) + # await callback.on_human_message(input=self.input) output = await self._acompletion( model=self.llm_data.model, @@ -477,15 +562,6 @@ async def ainvoke(self, input, *_, **kwargs): class AgentExecutorOpenAIFunc(LLMAgent): """Agent Executor that binded with OpenAI Function Calling""" - def __init__( - self, - **kwargs, - ): - super().__init__( - **kwargs, - ) - self._streaming_callback = None - @property def messages_function_calling(self): return [ @@ -518,11 +594,6 @@ async def ainvoke(self, input, *_, **kwargs): self.input = input tool_results = [] - if self.enable_streaming: - self._set_streaming_callback( - kwargs.get("config", {}).get("callbacks", []) - ) - if len(self.tools) > 0: openai_llm = await prisma.llm.find_first( where={ @@ -576,10 +647,14 @@ async def ainvoke(self, input, *_, **kwargs): if tool_call_res.return_direct: if self.enable_streaming: await self._stream_text_by_lines(tool_call_res.result) - self.streaming_callback.done.set() + + # for callback in self.callbacks: + # await callback.on_llm_end(response=tool_call_res.result) output = tool_call_res.result + await self._callback_manager.on_event("on_chain_end") + return { "intermediate_steps": tool_results, "input": self.input, @@ -605,36 +680,44 @@ async def ainvoke(self, input, *_, **kwargs): ) if self.enable_streaming: - await self.streaming_callback.on_llm_start() + for callback in self.callbacks: + await callback.on_llm_start(prompt=self.prompt) + second_res = cast(CustomStreamWrapper, second_res) async for chunk in second_res: token = chunk.choices[0].delta.content if token: output += token - await self.streaming_callback.on_llm_new_token(token) + print("token", token) - self.streaming_callback.done.set() + await self._callback_manager.on_event( + "on_llm_new_token", token=token + ) else: second_res = cast(ModelResponse, second_res) output = second_res.choices[0].message.content + await self._callback_manager.on_event("on_chain_end") + return { "intermediate_steps": tool_results, "input": self.input, "output": output, } finally: - await self.memory.aadd_message( - message=BaseMessage( - type=MessageType.HUMAN, - content=self.input, - ) - ) - - await self.memory.aadd_message( - message=BaseMessage( - type=MessageType.AI, - 
content=output,
-                )
-            )
+            await self._callback_manager.on_event("on_human_message", input=self.input)
+            await self._callback_manager.on_event("on_llm_end", response=output)
+
+            # await self.memory.aadd_message(
+            #     message=BaseMessage(
+            #         type=MessageType.HUMAN,
+            #         content=self.input,
+            #     )
+            # )
+            # await self.memory.aadd_message(
+            #     message=BaseMessage(
+            #         type=MessageType.AI,
+            #         content=output,
+            #     )
+            # )
diff --git a/libs/superagent/app/api/agents.py b/libs/superagent/app/api/agents.py
index 20e7019e8..136d7da6a 100644
--- a/libs/superagent/app/api/agents.py
+++ b/libs/superagent/app/api/agents.py
@@ -52,7 +52,7 @@
 )
 from app.utils.analytics import track_agent_invocation
 from app.utils.api import get_current_api_user, handle_exception
-from app.utils.callbacks import CostCalcAsyncHandler, CustomAsyncIteratorCallbackHandler
+from app.utils.callbacks import CostCalcCallback, CustomAsyncIteratorCallbackHandler
 from app.utils.helpers import stream_dict_keys
 from app.utils.llm import LLM_MAPPING, LLM_PROVIDER_MAPPING
 from app.utils.prisma import prisma
@@ -459,7 +459,7 @@ async def invoke(
     if not model and metadata.get("model"):
         model = metadata.get("model")
 
-    costCallback = CostCalcAsyncHandler(model=model)
+    costCallback = CostCalcCallback(model=model)
 
     monitoring_callbacks = [costCallback]
 
@@ -559,7 +559,7 @@ async def send_message(
         input = body.input
         enable_streaming = body.enableStreaming
         output_schema = body.outputSchema or agent_data.outputSchema
-        cost_callback = CostCalcAsyncHandler(model=model)
+        cost_callback = CostCalcCallback(model=model)
         streaming_callback = CustomAsyncIteratorCallbackHandler()
 
         agent_base = AgentFactory(
diff --git a/libs/superagent/app/api/workflows.py b/libs/superagent/app/api/workflows.py
index cef7ad71c..f2289b9e4 100644
--- a/libs/superagent/app/api/workflows.py
+++ b/libs/superagent/app/api/workflows.py
@@ -1,10 +1,9 @@
 import asyncio
 import json
 import logging
-from typing import AsyncIterable
+from typing import AsyncIterable, List
 
 import segment.analytics as analytics
-from agentops.langchain_callback_handler import AsyncLangchainCallbackHandler
 from decouple import config
 from fastapi import APIRouter, Depends
 from fastapi.responses import StreamingResponse
@@ -25,14 +24,13 @@
 from app.utils.analytics import track_agent_invocation
 from app.utils.api import get_current_api_user, handle_exception
 from app.utils.callbacks import (
-    CostCalcAsyncHandler,
+    CostCalcCallback,
     CustomAsyncIteratorCallbackHandler,
-    get_session_tracker_handler,
 )
 from app.utils.helpers import stream_dict_keys
 from app.utils.llm import LLM_MAPPING
 from app.utils.prisma import prisma
-from app.workflows.base import WorkflowBase
+from app.workflows.base import WorkflowBase, WorkflowStep
 
 SEGMENT_WRITE_KEY = config("SEGMENT_WRITE_KEY", None)
 
@@ -192,7 +190,7 @@ async def invoke(
     output_schemas = body.outputSchemas
     last_output_schema = body.outputSchema
 
-    workflow_steps = []
+    workflow_steps: List[WorkflowStep] = []
     for idx, step in enumerate(workflow_data.steps):
         agent_data = await prisma.agent.find_unique_or_raise(
             where={"id": step.agentId},
@@ -216,58 +214,52 @@ async def invoke(
         elif last_output_schema and idx == len(workflow_data.steps) - 1:
             output_schema = last_output_schema
 
-        item = {
-            "callbacks": {
-                "cost_calc": CostCalcAsyncHandler(model=llm_model),
-            },
-            "agent_name": agent_data.name,
-            "output_schema": output_schema,
-            "agent_data": agent_data,
-        }
-        session_tracker_handler = get_session_tracker_handler(
-            workflow_data.id, agent_data.id, session_id,
api_user.id + workflow_step = WorkflowStep( + agent_data=agent_data, + output_schema=output_schema, + callbacks=[ + CostCalcCallback(model=llm_model), + ], ) - if session_tracker_handler: - item["callbacks"]["session_tracker"] = session_tracker_handler - - if enable_streaming: - item["callbacks"]["streaming"] = CustomAsyncIteratorCallbackHandler() - - workflow_steps.append(item) - workflow_callbacks = [] + # logging_handlers = get_logging_handlers( + # workflow_id=workflow_id, + # agent_id=agent_data.id, + # session_id=session_id, + # user_id=api_user.id, + # ) - for s in workflow_steps: - callbacks = [] - for _, v in s["callbacks"].items(): - callbacks.append(v) - workflow_callbacks.append(callbacks) + # workflow_step.callbacks.extend(logging_handlers) - agentops_api_key = config("AGENTOPS_API_KEY", default=None) - agentops_org_key = config("AGENTOPS_ORG_KEY", default=None) + if enable_streaming: + workflow_step.callbacks.append(CustomAsyncIteratorCallbackHandler()) - agentops_handler = AsyncLangchainCallbackHandler( - api_key=agentops_api_key, org_key=agentops_org_key, tags=[session_id] - ) + workflow_steps.append(workflow_step) workflow = WorkflowBase( workflow_steps=workflow_steps, enable_streaming=enable_streaming, - callbacks=workflow_callbacks, - constructor_callbacks=[agentops_handler], session_id=session_id, ) def track_invocation(output): for index, workflow_step in enumerate(workflow_steps): workflow_step_result = output.get("steps")[index] - cost_callback = workflow_step["callbacks"]["cost_calc"] - agent = workflow_steps[index]["agent_data"] + for callback in workflow_step.callbacks: + if isinstance(callback, CostCalcCallback): + cost_callback = callback + break + + if not cost_callback: + logger.warning( + f"Cost callback not found for step {workflow_step.agent_data.name}" + ) + continue track_agent_invocation( { "workflow_id": workflow_id, - "agent": agent, + "agent": workflow_step.agent_data, "user_id": api_user.id, "session_id": session_id, **workflow_step_result, @@ -282,17 +274,19 @@ async def send_message() -> AsyncIterable[str]: try: task = asyncio.ensure_future(workflow.arun(input)) for workflow_step in workflow_steps: - output_schema = workflow_step["output_schema"] - # we are not streaming token by token if output schema is set schema_tokens = "" - async for token in workflow_step["callbacks"]["streaming"].aiter(): + for callback in workflow_step.callbacks: + if isinstance(callback, CustomAsyncIteratorCallbackHandler): + streaming_callback = callback + break + + async for token in streaming_callback.aiter(): if not output_schema: - agent_name = workflow_step["agent_name"] async for val in stream_dict_keys( { - "id": agent_name, + "id": workflow_step.agent_data.name, "data": token, } ): @@ -313,10 +307,9 @@ async def send_message() -> AsyncIterable[str]: # stream line by line to prevent streaming large data in one go for line in json.dumps(parsed_res).split("\n"): - agent_name = workflow_step["agent_name"] async for val in stream_dict_keys( { - "id": agent_name, + "id": workflow_step.agent_data.name, "data": line, } ): @@ -344,7 +337,7 @@ async def send_message() -> AsyncIterable[str]: { "event": "function_call", "data": { - "step_name": workflow_step["agent_name"], + "step_name": workflow_step.agent_data.name, "function": function, "args": json.dumps(args), "response": tool_response, @@ -358,11 +351,10 @@ async def send_message() -> AsyncIterable[str]: if SEGMENT_WRITE_KEY: for workflow_step in workflow_steps: - agent = workflow_step["agent_data"] 
track_agent_invocation( { "workflow_id": workflow_id, - "agent": agent, + "agent": workflow_step.agent_data, "user_id": api_user.id, "session_id": session_id, "error": str(error), @@ -371,25 +363,26 @@ async def send_message() -> AsyncIterable[str]: ) logger.exception(f"Error in send_message: {error}") - finally: - for workflow_step in workflow_steps: - workflow_step["callbacks"]["streaming"].done.set() + # finally: + # for workflow_step in workflow_steps: + # for callback in workflow_step.callbacks: + # if isinstance(callback, CustomAsyncIteratorCallbackHandler): + # callback.done.set() generator = send_message() return StreamingResponse(generator, media_type="text/event-stream") logger.info("Streaming not enabled. Invoking workflow synchronously...") - output = await workflow.arun( - input, - ) + output = await workflow.arun(input=input) if SEGMENT_WRITE_KEY: track_invocation(output) # End session - agentops_handler.ao_client.end_session( - "Success", end_state_reason="Workflow completed" - ) + # TODO: + # agentops_handler.ao_client.end_session( + # "Success", end_state_reason="Workflow completed" + # ) return {"success": True, "data": output} diff --git a/libs/superagent/app/utils/callbacks.py b/libs/superagent/app/utils/callbacks.py index 21bc63770..b7b482cc9 100644 --- a/libs/superagent/app/utils/callbacks.py +++ b/libs/superagent/app/utils/callbacks.py @@ -2,18 +2,56 @@ import asyncio import logging -from typing import Any, AsyncIterator, List, Literal, Tuple, Union, cast +from abc import ABC, abstractmethod +from typing import Any, AsyncIterator, Literal, Tuple, Union, cast +from agentops.langchain_callback_handler import AsyncLangchainCallbackHandler from decouple import config -from langchain.callbacks.base import AsyncCallbackHandler from langchain.schema.agent import AgentFinish -from langchain.schema.output import LLMResult from langfuse import Langfuse from litellm import cost_per_token, token_counter logger = logging.getLogger(__name__) +class AsyncCallbackHandler(ABC): + @abstractmethod + async def on_agent_finish(self, content: str): + pass + + @abstractmethod + async def on_llm_start(self, prompt: str): + pass + + @abstractmethod + async def on_llm_new_token(self, token: str): + pass + + @abstractmethod + async def on_llm_end(self, response: str): + pass + + @abstractmethod + async def on_chain_start(self): + pass + + @abstractmethod + async def on_chain_end(self): + pass + + @abstractmethod + async def on_llm_error(self, response: str): + pass + + @abstractmethod + async def on_human_message(self, input: str): + pass + + @abstractmethod + async def aiter(self) -> AsyncIterator[str]: + pass + + class CustomAsyncIteratorCallbackHandler(AsyncCallbackHandler): """Callback handler that returns an async iterator.""" @@ -58,14 +96,17 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: # noqa except asyncio.QueueFull: continue - async def on_llm_end(self, response, **kwargs: Any) -> None: # noqa - # TODO: - # This should be removed when Langchain has merged - # https://github.com/langchain-ai/langchain/pull/9536 - for gen_list in response.generations: - for gen in gen_list: - if gen.message.content != "": - self.done.set() + # async def on_llm_end(self, response, **kwargs: Any) -> None: # noqa + # # TODO: + # # This should be removed when Langchain has merged + # # https://github.com/langchain-ai/langchain/pull/9536 + # for gen_list in response.generations: + # for gen in gen_list: + # if gen.message.content != "": + # self.done.set() + + async def 
on_chain_end(self) -> None:
+        self.done.set()
 
     async def on_llm_error(self, *args: Any, **kwargs: Any) -> None:  # noqa
         self.done.set()
@@ -96,8 +137,20 @@ async def aiter(self) -> AsyncIterator[str]:
             yield token_or_done
 
+    async def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
+        pass
+
+    async def on_human_message(self, input: str) -> None:
+        pass
+
+    async def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
+        pass
+
+    async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
+        pass
 
 
-class CostCalcAsyncHandler(AsyncCallbackHandler):
+class CostCalcCallback(AsyncCallbackHandler):
     """Callback handler that calculates the cost of the prompt and completion."""
 
     def __init__(self, model):
@@ -109,11 +162,11 @@ def __init__(self, model):
         self.prompt_tokens_cost_usd: float = 0.0
         self.completion_tokens_cost_usd: float = 0.0
 
-    def on_llm_start(self, _, prompts: List[str], **kwargs: Any) -> None:  # noqa
-        self.prompt = prompts[0]
+    async def on_llm_start(self, prompt: str) -> None:  # noqa
+        self.prompt = prompt
 
-    def on_llm_end(self, llm_result: LLMResult, **kwargs: Any) -> None:  # noqa
-        self.completion = llm_result.generations[0][0].message.content
+    async def on_llm_end(self, response: str) -> None:  # noqa
+        self.completion = response
 
         completion_tokens = self._calculate_tokens(self.completion)
         prompt_tokens = self._calculate_tokens(self.prompt)
@@ -139,12 +192,34 @@ def _calculate_cost_per_token(
             completion_tokens=completion_tokens,
         )
 
+    async def on_agent_finish(self, content: str):
+        pass
+
+    async def on_llm_new_token(self, token: str):
+        pass
+
+    async def on_chain_start(self):
+        pass
+
+    async def on_chain_end(self):
+        pass
+
+    async def on_llm_error(self, response: str):
+        pass
+
+    async def on_human_message(self, input: str):
+        pass
+
+    async def aiter(self):
+        pass
+
 
-def get_session_tracker_handler(
-    workflow_id,
-    agent_id,
-    session_id,
-    user_id,
+def get_langfuse_handler(
+    *,
+    workflow_id: str,
+    agent_id: str,
+    session_id: str,
+    user_id: str,
 ):
     langfuse_secret_key = config("LANGFUSE_SECRET_KEY", "")
     langfuse_public_key = config("LANGFUSE_PUBLIC_KEY", "")
@@ -169,3 +244,44 @@ def get_session_tracker_handler(
         return langfuse_handler
 
     return None
+
+
+def get_agentops_handler(
+    *,
+    session_id: str,
+):
+    agentops_api_key = config("AGENTOPS_API_KEY", default=None)
+    agentops_org_key = config("AGENTOPS_ORG_KEY", default=None)
+
+    if not agentops_api_key or not agentops_org_key:
+        return None
+
+    return AsyncLangchainCallbackHandler(
+        api_key=agentops_api_key, org_key=agentops_org_key, tags=[session_id]
+    )
+
+
+def get_logging_handlers(
+    *,
+    workflow_id: str,
+    agent_id: str,
+    session_id: str,
+    user_id: str,
+):
+    langfuse_handler = get_langfuse_handler(
+        workflow_id=workflow_id,
+        agent_id=agent_id,
+        session_id=session_id,
+        user_id=user_id,
+    )
+    agentops_handler = get_agentops_handler(
+        session_id=session_id,
+    )
+
+    callback_handlers = []
+    if langfuse_handler:
+        callback_handlers.append(langfuse_handler)
+    if agentops_handler:
+        callback_handlers.append(agentops_handler)
+
+    return callback_handlers
diff --git a/libs/superagent/app/workflows/base.py b/libs/superagent/app/workflows/base.py
index 3ddc63eb6..ce3061611 100644
--- a/libs/superagent/app/workflows/base.py
+++ b/libs/superagent/app/workflows/base.py
@@ -1,60 +1,57 @@
 import logging
+from dataclasses import dataclass
 from typing import Any, List
 
 from agentops.langchain_callback_handler import (
     AsyncCallbackHandler,
-    LangchainCallbackHandler,
 )
 from langchain.output_parsers.json import SimpleJsonOutputParser
 
 from app.agents.base import AgentFactory
-from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
+from prisma.models import Agent
 
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class WorkflowStep:
+    agent_data: Agent
+    output_schema: dict[str, str]
+    callbacks: List[AsyncCallbackHandler]
+
+
 class WorkflowBase:
     def __init__(
         self,
-        workflow_steps: list[Any],
-        callbacks: List[CustomAsyncIteratorCallbackHandler],
+        workflow_steps: list[WorkflowStep],
         session_id: str,
-        constructor_callbacks: List[
-            AsyncCallbackHandler | LangchainCallbackHandler
-        ] = None,
         enable_streaming: bool = False,
     ):
         self.workflow_steps = workflow_steps
         self.enable_streaming = enable_streaming
         self.session_id = session_id
-        self.constructor_callbacks = constructor_callbacks
-        self.callbacks = callbacks
 
     async def arun(self, input: Any):
         previous_output = input
         steps_output = []
 
-        for stepIndex, step in enumerate(self.workflow_steps):
-            agent_data = step["agent_data"]
+        for _, step in enumerate(self.workflow_steps):
             input = previous_output
-            output_schema = step["output_schema"]
             agent_base = AgentFactory(
                 enable_streaming=self.enable_streaming,
-                callbacks=self.constructor_callbacks,
+                callbacks=step.callbacks,
                 session_id=self.session_id,
-                agent_data=agent_data,
-                output_schema=output_schema,
+                agent_data=step.agent_data,
+                output_schema=step.output_schema,
             )
             agent = await agent_base.get_agent()
 
             agent_response = await agent.ainvoke(
                 input=input,
-                config={
-                    "callbacks": self.callbacks[stepIndex],
-                },
+                config={"callbacks": step.callbacks},
             )
 
-            if output_schema:
+            if step.output_schema:
                 # TODO: throw error if output is not valid
                 parser = SimpleJsonOutputParser()
                 try: