Callback Manager #990

Draft
wants to merge 1 commit into base: feat/memory-llm-agent
251 changes: 167 additions & 84 deletions libs/superagent/app/agents/llm.py
@@ -1,10 +1,9 @@
import asyncio
import datetime
import json
import logging
from dataclasses import dataclass
from functools import cached_property, partial
-from typing import Any, cast
+from typing import Any, AsyncIterator, Dict, List, Literal, cast

from decouple import config
from langchain_core.agents import AgentActionMessageLog
@@ -33,7 +32,7 @@
from app.memory.memory_stores.redis import RedisMemoryStore
from app.memory.message import MessageType
from app.tools import get_tools
-from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
+from app.utils.callbacks import AsyncCallbackHandler
from app.utils.prisma import prisma
from prisma.enums import LLMProvider
from prisma.models import Agent
@@ -127,24 +126,119 @@ async def execute_tool(
)


-class LLMAgent(AgentBase):
-    _streaming_callback: CustomAsyncIteratorCallbackHandler
+class MemoryCallbackHandler(AsyncCallbackHandler):
+    def __init__(self, *args, memory: BufferMemory, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.memory = memory

-    @property
-    def streaming_callback(self):
-        return self._streaming_callback
+    async def on_tool_end(
+        self,
+        output: str,
+    ):
+        return await self.memory.aadd_message(
+            message=BaseMessage(
+                type=MessageType.TOOL_RESULT,
+                content=output,
+            )
+        )

+    async def on_tool_start(self, serialized: Dict[str, Any]):
+        return await self.memory.aadd_message(
+            message=BaseMessage(
+                type=MessageType.TOOL_CALL,
+                content=serialized,
+            )
+        )

+    async def on_tool_error(self, error: BaseException):
+        return await self.memory.aadd_message(
+            message=BaseMessage(
+                type=MessageType.TOOL_RESULT,
+                content=str(error),
+            )
+        )

+    async def on_llm_end(self, response: str):
+        return await self.memory.aadd_message(
+            message=BaseMessage(
+                type=MessageType.AI,
+                content=response,
+            )
+        )

-    # TODO: call all callbacks in the list, don't distinguish between them
-    def _set_streaming_callback(
-        self, callbacks: list[CustomAsyncIteratorCallbackHandler]
-    ):
+    async def on_human_message(self, input: str):
+        return await self.memory.aadd_message(
+            message=BaseMessage(
+                type=MessageType.HUMAN,
+                content=input,
+            )
+        )

+    async def on_llm_new_token(self, token: str):
+        pass

+    async def on_llm_start(self, prompt: str):
+        pass

+    async def on_agent_finish(self, content: str):
+        pass

+    async def aiter(self) -> AsyncIterator[str]:
+        pass

+    async def on_chain_end(self):
+        pass

+    async def on_chain_start(self):
+        pass

+    async def on_llm_error(self, response: str):
+        pass


+class CallbackManager:
+    def __init__(self, callbacks: List[AsyncCallbackHandler]):
+        self.callbacks = callbacks

+    async def on_event(
+        self,
+        event: Literal[
+            "on_llm_start",
+            "on_llm_end",
+            "on_tool_start",
+            "on_tool_end",
+            "on_tool_error",
+            "on_llm_new_token",
+            "on_chain_start",
+            "on_chain_end",
+            "on_llm_error",
+            "on_human_message",
+            "on_agent_finish",
+        ],
+        *args,
+        **kwargs,
+    ):
-        for callback in callbacks:
-            if isinstance(callback, CustomAsyncIteratorCallbackHandler):
-                self._streaming_callback = callback
-                break
+        for callback in self.callbacks:
+            if hasattr(callback, event):
+                await getattr(callback, event)(*args, **kwargs)

-        if not self._streaming_callback:
-            raise Exception("Streaming Callback not found")

+class LLMAgent(AgentBase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(
+            session_id=kwargs.get("session_id"),
+            agent_data=kwargs.get("agent_data"),
+            output_schema=kwargs.get("output_schema"),
+            callbacks=kwargs.get("callbacks"),
+            enable_streaming=kwargs.get("enable_streaming"),
+            llm_data=kwargs.get("llm_data"),
+        )
+        self._callback_manager = CallbackManager(
+            callbacks=[
+                MemoryCallbackHandler(memory=self.memory),
+                *self.callbacks,
+            ]
+        )

@cached_property
def tools(self):
@@ -214,10 +308,10 @@ async def _stream_text_by_lines(self, output: str):
output_by_lines = output.split("\n")
if len(output_by_lines) > 1:
for line in output_by_lines:
-                await self.streaming_callback.on_llm_new_token(line)
-                await self.streaming_callback.on_llm_new_token("\n")
+                await self._callback_manager.on_event("on_llm_new_token", line)
+                await self._callback_manager.on_event("on_llm_new_token", "\n")
else:
-            await self.streaming_callback.on_llm_new_token(output_by_lines[0])
+            await self._callback_manager.on_event("on_llm_new_token", output)

async def get_agent(self):
if self._supports_tool_calling:
@@ -237,7 +331,6 @@ def __init__(
super().__init__(
**kwargs,
)
-        self._streaming_callback = None
self._intermediate_steps = []

NON_STREAMING_TOOL_PROVIDERS = [
@@ -278,7 +371,11 @@ async def _execute_tools(
if tool_call_res.return_direct:
if self.enable_streaming:
await self._stream_text_by_lines(tool_call_res.result)
-                self.streaming_callback.done.set()
+                # self._callback_manager.on_event(
+                #     "on_llm_end", response=tool_call_res.result
+                # )
+                await self._callback_manager.on_event("on_chain_end")

return tool_call_res.result

self.messages = messages
@@ -341,7 +438,9 @@ async def _process_stream_response(self, res: CustomStreamWrapper):
if new_message.content:
output += new_message.content
if self._can_stream_directly:
-                    await self.streaming_callback.on_llm_new_token(new_message.content)
+                    await self._callback_manager.on_event(
+                        "on_llm_new_token", token=new_message.content
+                    )
chunks.append(chunk)

model_response = stream_chunk_builder(chunks=chunks)
@@ -361,15 +460,17 @@ async def _process_model_response(self, res: ModelResponse):
new_messages.append(new_message.dict())

if new_message.content and self._can_stream_directly:
-            await self.streaming_callback.on_llm_new_token(new_message.content)
+            await self._callback_manager.on_event(
+                "on_llm_new_token", token=new_message.content
+            )

return (tool_calls, new_messages, new_message.content)

async def _acompletion(self, depth: int = 0, **kwargs) -> Any:
logger.info(f"Calling LLM with kwargs: {kwargs}")

if kwargs.get("stream"):
-            await self.streaming_callback.on_llm_start()
+            await self._callback_manager.on_event("on_llm_start", prompt=self.prompt)

# TODO: Remove this when Groq and Bedrock supports streaming with tools
if (
@@ -390,25 +491,13 @@ async def _acompletion(self, depth: int = 0, **kwargs) -> Any:
tool_calls, new_messages, output = result

if output:
-            await self.memory.aadd_message(
-                message=BaseMessage(
-                    type=MessageType.AI,
-                    content=output,
-                )
-            )
+            await self._callback_manager.on_event("on_llm_end", response=output)

if tool_calls:
-            await asyncio.gather(
-                *[
-                    self.memory.aadd_message(
-                        message=BaseMessage(
-                            type=MessageType.TOOL_CALL,
-                            content=tool_call.json(),
-                        )
-                    )
-                    for tool_call in tool_calls
-                ]
-            )
+            for tool_call in tool_calls:
+                await self._callback_manager.on_event(
+                    "on_tool_start", serialized=tool_call.json()
+                )

self.messages = new_messages

@@ -429,12 +518,11 @@ async def _acompletion(self, depth: int = 0, **kwargs) -> Any:
if not self._can_stream_directly and self.enable_streaming:
await self._stream_text_by_lines(output)

-        if self.enable_streaming:
-            self.streaming_callback.done.set()
+        await self._callback_manager.on_event("on_chain_end")

return output

-    async def ainvoke(self, input, *_, **kwargs):
+    async def ainvoke(self, input, **kwargs):
self.input = input
self.messages = [
{
@@ -447,15 +535,12 @@ async def ainvoke(self, input, *_, **kwargs):
},
]

-        await self.memory.aadd_message(
-            message=BaseMessage(
-                type=MessageType.HUMAN,
-                content=self.input,
-            )
-        )
+        await self._callback_manager.on_event("on_human_message", input=self.input)

-        if self.enable_streaming:
-            self._set_streaming_callback(kwargs.get("config", {}).get("callbacks", []))
+        # for callback in self.callbacks:
+        #     if hasattr(callback, "on_human_message"):
+        #         print("my_hey_callback", callback, callback.on_human_message)
+        #         await callback.on_human_message(input=self.input)

output = await self._acompletion(
model=self.llm_data.model,
Expand All @@ -477,15 +562,6 @@ async def ainvoke(self, input, *_, **kwargs):
class AgentExecutorOpenAIFunc(LLMAgent):
"""Agent Executor that binded with OpenAI Function Calling"""

-    def __init__(
-        self,
-        **kwargs,
-    ):
-        super().__init__(
-            **kwargs,
-        )
-        self._streaming_callback = None

@property
def messages_function_calling(self):
return [
@@ -518,11 +594,6 @@ async def ainvoke(self, input, *_, **kwargs):
self.input = input
tool_results = []

-        if self.enable_streaming:
-            self._set_streaming_callback(
-                kwargs.get("config", {}).get("callbacks", [])
-            )

if len(self.tools) > 0:
openai_llm = await prisma.llm.find_first(
where={
@@ -576,10 +647,14 @@ async def ainvoke(self, input, *_, **kwargs):
if tool_call_res.return_direct:
if self.enable_streaming:
await self._stream_text_by_lines(tool_call_res.result)
-                    self.streaming_callback.done.set()
+                    # for callback in self.callbacks:
+                    #     await callback.on_llm_end(response=tool_call_res.result)

                output = tool_call_res.result

+                await self._callback_manager.on_event("on_chain_end")

return {
"intermediate_steps": tool_results,
"input": self.input,
@@ -605,36 +680,44 @@ async def ainvoke(self, input, *_, **kwargs):
)

if self.enable_streaming:
-            await self.streaming_callback.on_llm_start()
+            for callback in self.callbacks:
+                await callback.on_llm_start(prompt=self.prompt)

second_res = cast(CustomStreamWrapper, second_res)

async for chunk in second_res:
token = chunk.choices[0].delta.content
if token:
output += token
-                    await self.streaming_callback.on_llm_new_token(token)
+                    print("token", token)
+                    await self._callback_manager.on_event(
+                        "on_llm_new_token", token=token
+                    )

-            self.streaming_callback.done.set()
else:
second_res = cast(ModelResponse, second_res)
output = second_res.choices[0].message.content

await self._callback_manager.on_event("on_chain_end")

return {
"intermediate_steps": tool_results,
"input": self.input,
"output": output,
}
finally:
-            await self.memory.aadd_message(
-                message=BaseMessage(
-                    type=MessageType.HUMAN,
-                    content=self.input,
-                )
-            )
-
-            await self.memory.aadd_message(
-                message=BaseMessage(
-                    type=MessageType.AI,
-                    content=output,
-                )
-            )
+            await self._callback_manager.on_event("on_human_message", input=self.input)
+            await self._callback_manager.on_event("on_llm_end", response=output)

# await self.memory.aadd_message(
# message=BaseMessage(
# type=MessageType.HUMAN,
# content=self.input,
# )
# )
# await self.memory.aadd_message(
# message=BaseMessage(
# type=MessageType.AI,
# content=output,
# )
# )
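
---

Editor's note (not part of the diff): the new `CallbackManager.on_event` fans a named event out to every registered handler that implements a method of that name, using a `hasattr`/`getattr` check. Below is a minimal, self-contained sketch of that dispatch pattern, runnable with only the Python standard library. `PrintHandler` and `CollectingHandler` are hypothetical stand-ins for handlers like `MemoryCallbackHandler`; they are not part of this PR.

```python
import asyncio
from typing import Any, List


class PrintHandler:
    # Implements only the events it cares about; the manager skips the rest.
    async def on_llm_new_token(self, token: str) -> None:
        print(token, end="", flush=True)

    async def on_chain_end(self) -> None:
        print()  # newline once the run is finished


class CollectingHandler:
    # Accumulates tokens so the full response can be inspected afterwards,
    # loosely analogous to how MemoryCallbackHandler persists messages.
    def __init__(self) -> None:
        self.tokens: List[str] = []

    async def on_llm_new_token(self, token: str) -> None:
        self.tokens.append(token)


class CallbackManager:
    def __init__(self, callbacks: List[Any]) -> None:
        self.callbacks = callbacks

    async def on_event(self, event: str, *args: Any, **kwargs: Any) -> None:
        # Dispatch to every handler that defines the event; silently skip
        # handlers that don't, mirroring the hasattr/getattr check in the PR.
        for callback in self.callbacks:
            if hasattr(callback, event):
                await getattr(callback, event)(*args, **kwargs)


async def main() -> None:
    collector = CollectingHandler()
    manager = CallbackManager(callbacks=[PrintHandler(), collector])
    for token in ["Hello", ", ", "world"]:
        await manager.on_event("on_llm_new_token", token=token)
    await manager.on_event("on_chain_end")
    assert collector.tokens == ["Hello", ", ", "world"]


if __name__ == "__main__":
    asyncio.run(main())
```

One consequence of dispatching by event name is that a misspelled event string matches no handler and fails silently rather than raising, which is worth keeping in mind even with the `Literal` event list in the PR.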