feat: Enhance backend with context management, error handling, and refactored code #4286

Open · wants to merge 33 commits into base: main

Changes from all commits (33 commits)

b4064ab
Add cycle detection and management for graph vertices in run manager
ogabrielluiz Oct 10, 2024
2f3bb69
Refactor: Move AIMLEmbeddingsImpl to a new module path
ogabrielluiz Oct 10, 2024
6e913cd
Add AIMLEmbeddingsImpl class for document and query embeddings using …
ogabrielluiz Oct 10, 2024
8459565
Add agent components for action routing, decision-making, execution, …
ogabrielluiz Oct 10, 2024
a5e3d44
Add AgentContext class for managing agent state and context serializa…
ogabrielluiz Oct 10, 2024
927b359
Add new agent components to the langflow module's init file
ogabrielluiz Oct 10, 2024
3ee2686
Update `apply_on_outputs` to use `_outputs_map` in vertex base class
ogabrielluiz Oct 11, 2024
d43751b
Add _pre_run_setup method to custom component for pre-execution setup
ogabrielluiz Oct 11, 2024
86fc405
Handle non-list action types in decide_action method
ogabrielluiz Oct 11, 2024
25ccac1
Enhance AgentActionRouter with iteration control and context routing …
ogabrielluiz Oct 11, 2024
3e9332d
Fix incorrect variable usage in tool call result message formatting
ogabrielluiz Oct 11, 2024
e312f13
Add AgentActionRouter to module exports in agents package
ogabrielluiz Oct 11, 2024
818443c
Refactor cycle detection logic in graph base class
ogabrielluiz Oct 11, 2024
7405501
Add test for complex agent flow with cyclic graph validation
ogabrielluiz Oct 11, 2024
45bb9f3
Enhance readiness checks in tracing service methods
ogabrielluiz Oct 11, 2024
8efafc7
Add context management to Graph class with dotdict support
ogabrielluiz Oct 16, 2024
431bbe6
Add context management methods to custom component class
ogabrielluiz Oct 16, 2024
5d8a529
Add customizable Agent component with input/output handling and actio…
ogabrielluiz Oct 16, 2024
3c3355b
Handle non-list 'tools' attribute in 'build_context' method
ogabrielluiz Oct 17, 2024
0bf83c2
Convert `get_response` method to asynchronous and update graph proces…
ogabrielluiz Oct 17, 2024
b92300b
Add async test for Agent component in graph cycle tests
ogabrielluiz Oct 17, 2024
4861084
Refactor Agent Flow JSON: Simplify input types and update agent compo…
ogabrielluiz Oct 17, 2024
7143208
[autofix.ci] apply automated fixes
autofix-ci[bot] Oct 21, 2024
c556074
Add Agent import to init, improve error handling, and clean up imports
ogabrielluiz Oct 23, 2024
438ac6b
Refactor agent component imports for improved modularity and organiza…
ogabrielluiz Oct 23, 2024
5024e6b
Remove agent components and update `__init__.py` exports
ogabrielluiz Oct 25, 2024
d5f03a2
remove agent flow
ogabrielluiz Oct 25, 2024
c48aea4
Add iteration control and default route options to ConditionalRouter …
ogabrielluiz Oct 25, 2024
cda8540
Refactor graph tests to include new components and update iteration l…
ogabrielluiz Oct 25, 2024
70fa7f5
Refactor conditional router to return message consistently and use it…
ogabrielluiz Oct 25, 2024
4ad2e26
readd agent flow starter project
ogabrielluiz Oct 25, 2024
8f40330
Add return type annotations to methods in langsmith.py
ogabrielluiz Oct 25, 2024
ee72767
Remove unnecessary `@override` decorator and add `# noqa: ARG002` com…
ogabrielluiz Oct 25, 2024
109 changes: 109 additions & 0 deletions src/backend/base/langflow/base/agents/context.py
@@ -0,0 +1,109 @@
from datetime import datetime, timezone
from typing import Any

from langchain_core.language_models import BaseLanguageModel, BaseLLM
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic import BaseModel, Field, field_validator, model_serializer

from langflow.field_typing import LanguageModel
from langflow.schema.data import Data


class AgentContext(BaseModel):
tools: dict[str, Any]
llm: Any
context: str = ""
iteration: int = 0
max_iterations: int = 5
thought: str = ""
last_action: Any = None
last_action_result: Any = None
final_answer: Any = ""
context_history: list[tuple[str, str, str]] = Field(default_factory=list)

@model_serializer(mode="plain")
def serialize_agent_context(self):
serialized_llm = self.llm.to_json() if hasattr(self.llm, "to_json") else str(self.llm)
serialized_tools = {k: v.to_json() if hasattr(v, "to_json") else str(v) for k, v in self.tools.items()}
return {
"tools": serialized_tools,
"llm": serialized_llm,
"context": self.context,
"iteration": self.iteration,
"max_iterations": self.max_iterations,
"thought": self.thought,
"last_action": self.last_action.to_json()
if hasattr(self.last_action, "to_json")
else str(self.last_action),
"action_result": self.last_action_result.to_json()
if hasattr(self.last_action_result, "to_json")
else str(self.last_action_result),
"final_answer": self.final_answer,
"context_history": self.context_history,
}

@field_validator("llm", mode="before")
@classmethod
def validate_llm(cls, v) -> LanguageModel:
if not isinstance(v, BaseLLM | BaseChatModel | BaseLanguageModel):
msg = "llm must be an instance of LanguageModel"
raise TypeError(msg)
return v

def to_data_repr(self):
data_objs = []
for name, val, time_str in self.context_history:
content = val.content if hasattr(val, "content") else val
data_objs.append(Data(name=name, value=content, timestamp=time_str))

sorted_data_objs = sorted(data_objs, key=lambda x: datetime.fromisoformat(x.timestamp), reverse=True)

sorted_data_objs.append(
Data(
name="Formatted Context",
value=self.get_full_context(),
)
)
return sorted_data_objs

def _build_tools_context(self):
tool_context = ""
for tool_name, tool_obj in self.tools.items():
tool_context += f"{tool_name}: {tool_obj.description}\n"
return tool_context

def _build_init_context(self):
return f"""
{self.context}

"""

def model_post_init(self, _context: Any) -> None:
if hasattr(self.llm, "bind_tools"):
self.llm = self.llm.bind_tools(self.tools.values())
if self.context:
self.update_context("Initial Context", self.context)

def update_context(self, key: str, value: str):
self.context_history.insert(0, (key, value, datetime.now(tz=timezone.utc).astimezone().isoformat()))

def _serialize_context_history_tuple(self, context_history_tuple: tuple[str, str, str]) -> str:
name, value, _ = context_history_tuple
if hasattr(value, "content"):
value = value.content
elif hasattr(value, "log"):
value = value.log
return f"{name}: {value}"

def get_full_context(self) -> str:
context_history_reversed = self.context_history[::-1]
context_formatted = "\n".join(
[
self._serialize_context_history_tuple(context_history_tuple)
for context_history_tuple in context_history_reversed
]
)
return f"""
Context:
{context_formatted}
"""
@@ -30,14 +30,15 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
try:
result_data = future.result()
if len(result_data["data"]) != 1:
msg = "Expected one embedding"
msg = f"Expected one embedding, got {len(result_data['data'])}"
raise ValueError(msg)
embeddings[index] = result_data["data"][0]["embedding"]
except (
httpx.HTTPStatusError,
httpx.RequestError,
json.JSONDecodeError,
KeyError,
ValueError,
):
logger.exception("Error occurred")
raise
2 changes: 1 addition & 1 deletion src/backend/base/langflow/components/embeddings/aiml.py
@@ -1,6 +1,6 @@
from langflow.base.embeddings.aiml_embeddings import AIMLEmbeddingsImpl
from langflow.base.embeddings.model import LCEmbeddingsModel
from langflow.base.models.aiml_constants import AIML_EMBEDDING_MODELS
from langflow.components.embeddings.util import AIMLEmbeddingsImpl
from langflow.field_typing import Embeddings
from langflow.inputs.inputs import DropdownInput
from langflow.io import SecretStrInput

This file was deleted.

@@ -1,5 +1,5 @@
from langflow.custom import Component
from langflow.io import BoolInput, DropdownInput, MessageInput, MessageTextInput, Output
from langflow.io import BoolInput, DropdownInput, IntInput, MessageInput, MessageTextInput, Output
from langflow.schema.message import Message


@@ -9,6 +9,10 @@ class ConditionalRouterComponent(Component):
icon = "equal"
name = "ConditionalRouter"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__iteration_updated = False

inputs = [
MessageTextInput(
name="input_text",
@@ -40,13 +44,30 @@ class ConditionalRouterComponent(Component):
display_name="Message",
info="The message to pass through either route.",
),
IntInput(
name="max_iterations",
display_name="Max Iterations",
info="The maximum number of iterations for the conditional router.",
value=10,
),
DropdownInput(
name="default_route",
display_name="Default Route",
options=["true_result", "false_result"],
info="The default route to take when max iterations are reached.",
value="false_result",
advanced=True,
),
]

outputs = [
Output(display_name="True Route", name="true_result", method="true_response"),
Output(display_name="False Route", name="false_result", method="false_response"),
]

def _pre_run_setup(self):
self.__iteration_updated = False

def evaluate_condition(self, input_text: str, match_text: str, operator: str, *, case_sensitive: bool) -> bool:
if not case_sensitive:
input_text = input_text.lower()
@@ -64,22 +85,35 @@ def evaluate_condition(self, input_text: str, match_text: str, operator: str, *,
return input_text.endswith(match_text)
return False

def iterate_and_stop_once(self, route_to_stop: str):
if not self.__iteration_updated:
_id = self._id.lower()
self.update_ctx({f"{_id}_iteration": self.ctx.get(f"{_id}_iteration", 0) + 1})
self.__iteration_updated = True
_id = self._id.lower()
if self.ctx.get(f"{_id}_iteration", 0) >= self.max_iterations and route_to_stop == self.default_route:
# We need to stop the other route
route_to_stop = "true_result" if route_to_stop == "false_result" else "false_result"
self.stop(route_to_stop)

def true_response(self) -> Message:
result = self.evaluate_condition(
self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive
)
if result:
self.status = self.message
self.iterate_and_stop_once("false_result")
return self.message
self.stop("true_result")
return None # type: ignore[return-value]
self.iterate_and_stop_once("true_result")
return self.message

def false_response(self) -> Message:
result = self.evaluate_condition(
self.input_text, self.match_text, self.operator, case_sensitive=self.case_sensitive
)
if not result:
self.status = self.message
self.iterate_and_stop_once("true_result")
return self.message
self.stop("false_result")
return None # type: ignore[return-value]
self.iterate_and_stop_once("false_result")
return self.message
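
A standalone sketch of the routing decision introduced above (illustrative only, not langflow code). Each pass through the router bumps a per-component "<id>_iteration" counter in the shared graph context; once max_iterations is reached, the branch opposite default_route is the one that gets stopped, so the default route always passes the message on.

def route_to_stop(condition_met: bool, iteration: int, max_iterations: int, default_route: str) -> str:
    """Mirror of iterate_and_stop_once: return the output name that gets stopped."""
    # Normally the branch that did not match the condition is stopped.
    losing_route = "false_result" if condition_met else "true_result"
    # Once the iteration budget is spent, never stop the default route:
    # flip to the opposite branch so the default route keeps flowing.
    if iteration >= max_iterations and losing_route == default_route:
        return "true_result" if losing_route == "false_result" else "false_result"
    return losing_route

assert route_to_stop(True, 3, 10, "false_result") == "false_result"  # normal case
assert route_to_stop(True, 10, 10, "false_result") == "true_result"  # iteration budget exhausted
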
50 changes: 50 additions & 0 deletions src/backend/base/langflow/custom/custom_component/component.py
@@ -60,6 +60,7 @@ class Component(CustomComponent):
_output_logs: dict[str, Log] = {}
_current_output: str = ""
_metadata: dict = {}
_ctx: dict = {}

def __init__(self, **kwargs) -> None:
# if key starts with _ it is a config
@@ -107,6 +108,53 @@ def __init__(self, **kwargs) -> None:
self.set_class_code()
self._set_output_required_inputs()

@property
def ctx(self):
if not hasattr(self, "graph") or self.graph is None:
msg = "Graph not found. Please build the graph first."
raise ValueError(msg)
return self.graph.context

def add_to_ctx(self, key: str, value: Any, *, overwrite: bool = False) -> None:
"""Add a key-value pair to the context.

Args:
key (str): The key to add.
value (Any): The value to associate with the key.
overwrite (bool, optional): Whether to overwrite the existing value. Defaults to False.

Raises:
ValueError: If the graph is not built.
"""
if not hasattr(self, "graph") or self.graph is None:
msg = "Graph not found. Please build the graph first."
raise ValueError(msg)
if key in self.graph.context and not overwrite:
msg = f"Key {key} already exists in context. Set overwrite=True to overwrite."
raise ValueError(msg)
self.graph.context.update({key: value})

def update_ctx(self, value_dict: dict[str, Any]) -> None:
"""Update the context with a dictionary of values.

Args:
value_dict (dict[str, Any]): The dictionary of values to update.

Raises:
ValueError: If the graph is not built.
"""
if not hasattr(self, "graph") or self.graph is None:
msg = "Graph not found. Please build the graph first."
raise ValueError(msg)
if not isinstance(value_dict, dict):
msg = "Value dict must be a dictionary"
raise TypeError(msg)

self.graph.context.update(value_dict)

def _pre_run_setup(self):
pass

def set_event_manager(self, event_manager: EventManager | None = None) -> None:
self._event_manager = event_manager

@@ -700,6 +748,8 @@ async def build_results(
async def _build_results(self):
_results = {}
_artifacts = {}
if hasattr(self, "_pre_run_setup"):
self._pre_run_setup()
if hasattr(self, "outputs"):
for output in self._outputs_map.values():
# Build the output if it's connected to some other vertex
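
To illustrate the new shared-context API, a hypothetical component sketch (CounterComponent, its output, and the run_count key are made up for illustration; it assumes the component is attached to a built graph, since ctx simply proxies self.graph.context). add_to_ctx seeds a key and raises if the key exists unless overwrite=True, update_ctx merges a dict into the shared context, and _pre_run_setup is invoked by _build_results before the outputs are built.

from langflow.custom import Component
from langflow.io import Output
from langflow.schema.message import Message


class CounterComponent(Component):
    """Hypothetical component exercising ctx, add_to_ctx, and update_ctx."""

    display_name = "Counter"
    outputs = [Output(display_name="Count", name="count", method="build_count")]

    def _pre_run_setup(self):
        # Called before any output method of this component is built.
        if "run_count" not in self.ctx:
            self.add_to_ctx("run_count", 0)

    def build_count(self) -> Message:
        # update_ctx merges a dict into graph.context, visible to every component.
        self.update_ctx({"run_count": self.ctx["run_count"] + 1})
        return Message(text=f"runs so far: {self.ctx['run_count']}")
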