Skip to content

Commit

Permalink
Merge branch 'cz/ui-uplift' of https://github.com/langflow-ai/langflow
Browse files Browse the repository at this point in the history
…into cz/ui-uplift
  • Loading branch information
Cristhianzl committed Oct 24, 2024
2 parents d93af48 + a92c903 commit 3c1a1c5
Show file tree
Hide file tree
Showing 16 changed files with 186 additions and 54 deletions.
2 changes: 2 additions & 0 deletions src/backend/base/langflow/base/tools/component_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def build_description(component: Component, output: Output) -> str:

def _build_output_function(component: Component, output_method: Callable):
def output_function(*args, **kwargs):
# set the component with the arguments
# set functionality was updatedto handle list of components and other values separately
component.set(*args, **kwargs)
return output_method()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def __init__(self, **kwargs) -> None:
self.__inputs = inputs
self.__config = config
self._reset_all_output_values()
if FEATURE_FLAGS.add_toolkit_output and hasattr(self, "_append_tool_output"):
self._append_tool_output()
super().__init__(**config)
if (FEATURE_FLAGS.add_toolkit_output) and hasattr(self, "_append_tool_output") and self.add_tool_output:
self._append_tool_output()
if hasattr(self, "_trace_type"):
self.trace_type = self._trace_type
if not hasattr(self, "trace_type"):
Expand Down Expand Up @@ -442,7 +442,8 @@ def _process_connection_or_parameter(self, key, value) -> None:

def _process_connection_or_parameters(self, key, value) -> None:
# if value is a list of components, we need to process each component
if isinstance(value, list):
# Note this update make sure it is not a list str | int | float | bool | type(None)
if isinstance(value, list) and not any(isinstance(val, str | int | float | bool | type(None)) for val in value):
for val in value:
self._process_connection_or_parameter(key, val)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class CustomComponent(BaseComponent):
is_input: bool | None = None
"""The input state of the component. Defaults to None.
If True, the component must have a field named 'input_value'."""
add_tool_output: bool | None = False
"""Indicates whether the component will be treated as a tool. Defaults to False."""
is_output: bool | None = None
"""The output state of the component. Defaults to None.
If True, the component must have a field named 'input_value'."""
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/schema/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from langflow.schema.data import Data
from langflow.schema.message import Message
from langflow.schema.schema import recursive_serialize_or_str
from langflow.schema.serialize import recursive_serialize_or_str


class ArtifactType(str, Enum):
Expand Down
2 changes: 2 additions & 0 deletions src/backend/base/langflow/schema/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from loguru import logger
from pydantic import BaseModel, model_serializer, model_validator

from langflow.schema.serialize import recursive_serialize_or_str
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER

if TYPE_CHECKING:
Expand Down Expand Up @@ -199,6 +200,7 @@ def __str__(self) -> str:
# return a JSON string representation of the Data atributes
try:
data = {k: v.to_json() if hasattr(v, "to_json") else v for k, v in self.data.items()}
data = recursive_serialize_or_str(data)
return json.dumps(data, indent=4)
except Exception: # noqa: BLE001
logger.opt(exception=True).debug("Error converting Data to JSON")
Expand Down
40 changes: 2 additions & 38 deletions src/backend/base/langflow/schema/schema.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from collections.abc import AsyncIterator, Generator, Iterator
from collections.abc import Generator
from enum import Enum
from typing import Literal

from loguru import logger
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import TypedDict

from langflow.schema.data import Data
from langflow.schema.message import Message
from langflow.schema.serialize import recursive_serialize_or_str

INPUT_FIELD_NAME = "input_value"

Expand Down Expand Up @@ -113,38 +112,3 @@ def build_output_logs(vertex, result) -> dict:
outputs |= {name: OutputValue(message=message, type=_type).model_dump()}

return outputs


def recursive_serialize_or_str(obj):
try:
if isinstance(obj, str):
return obj
if isinstance(obj, dict):
return {k: recursive_serialize_or_str(v) for k, v in obj.items()}
if isinstance(obj, list):
return [recursive_serialize_or_str(v) for v in obj]
if isinstance(obj, BaseModel | BaseModelV1):
if hasattr(obj, "model_dump"):
obj_dict = obj.model_dump()
elif hasattr(obj, "dict"):
obj_dict = obj.dict()
return {k: recursive_serialize_or_str(v) for k, v in obj_dict.items()}

if isinstance(obj, AsyncIterator | Generator | Iterator):
# contain memory addresses
# without consuming the iterator
# return list(obj) consumes the iterator
# return f"{obj}" this generates '<generator object BaseChatModel.stream at 0x33e9ec770>'
# it is not useful
return "Unconsumed Stream"
if hasattr(obj, "dict"):
return {k: recursive_serialize_or_str(v) for k, v in obj.dict().items()}
if hasattr(obj, "model_dump"):
return {k: recursive_serialize_or_str(v) for k, v in obj.model_dump().items()}
if isinstance(obj, type) and issubclass(obj, BaseModel):
# This a type BaseModel and not an instance of it
return repr(obj)
return str(obj)
except Exception: # noqa: BLE001
logger.debug(f"Cannot serialize object {obj}")
return str(obj)
43 changes: 43 additions & 0 deletions src/backend/base/langflow/schema/serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from collections.abc import AsyncIterator, Generator, Iterator
from datetime import datetime

from loguru import logger
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1


def recursive_serialize_or_str(obj):
try:
if isinstance(obj, str):
return obj
if isinstance(obj, datetime):
return obj.isoformat()
if isinstance(obj, dict):
return {k: recursive_serialize_or_str(v) for k, v in obj.items()}
if isinstance(obj, list):
return [recursive_serialize_or_str(v) for v in obj]
if isinstance(obj, BaseModel | BaseModelV1):
if hasattr(obj, "model_dump"):
obj_dict = obj.model_dump()
elif hasattr(obj, "dict"):
obj_dict = obj.dict()
return {k: recursive_serialize_or_str(v) for k, v in obj_dict.items()}

if isinstance(obj, AsyncIterator | Generator | Iterator):
# contain memory addresses
# without consuming the iterator
# return list(obj) consumes the iterator
# return f"{obj}" this generates '<generator object BaseChatModel.stream at 0x33e9ec770>'
# it is not useful
return "Unconsumed Stream"
if hasattr(obj, "dict"):
return {k: recursive_serialize_or_str(v) for k, v in obj.dict().items()}
if hasattr(obj, "model_dump"):
return {k: recursive_serialize_or_str(v) for k, v in obj.model_dump().items()}
if isinstance(obj, type) and issubclass(obj, BaseModel):
# This a type BaseModel and not an instance of it
return repr(obj)
return str(obj)
except Exception: # noqa: BLE001
logger.debug(f"Cannot serialize object {obj}")
return str(obj)
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from langflow.custom import Component
from langflow.inputs.inputs import MessageTextInput, StrInput


@pytest.fixture
def setup_component():
# Create a sample component for testing
component = Component()
# Define inputs for the component
component.inputs = [
MessageTextInput(name="list_message_input", is_list=True), # Input for a mock component
StrInput(name="mixed_input"), # Input for a mixed list
]
return component


def test_set_with_mixed_list_input(setup_component):
component = setup_component
# Create a mock component to include in the list
mock_component = Component()
message_input_1 = "message data1"
message_input_2 = "message data2"
data = {"mixed_input": [message_input_1, message_input_2], "list_message_input": [message_input_1, mock_component]}
component.set(**data)

# Assert that the mixed input was set correctly
assert hasattr(component, "mixed_input")
assert len(component.mixed_input) == 2
assert component.mixed_input[0] == message_input_1
assert component.mixed_input[1] == message_input_2
assert component.list_message_input[0] == message_input_1
assert component.list_message_input[1] == mock_component


def test_set_with_message_text_input_list(setup_component):
component = setup_component
# Create a list of MessageTextInput instances
message_input_1 = "message data1"
message_input_2 = "message data2"
data = {"mixed_input": [message_input_1, message_input_2], "list_message_input": [message_input_1, message_input_2]}
# Set a list containing MessageTextInput instances
component.set(**data)

# Assert that the mixed input was set correctly
assert hasattr(component, "mixed_input")
assert len(component.list_message_input) == 2
assert component.list_message_input[0] == message_input_1
assert component.list_message_input[1] == message_input_2
19 changes: 17 additions & 2 deletions src/backend/tests/unit/test_custom_component_with_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ def code_component_with_multiple_outputs():
return Component(_code=code)


@pytest.fixture
def code_component_with_multiple_outputs_with_add_tool_output():
code = Path("src/backend/tests/data/component_multiple_outputs.py").read_text(encoding="utf-8")
code = code.replace(
"class MultipleOutputsComponent(Component):",
"class MultipleOutputsComponent(Component):\n add_tool_output = True",
)
return Component(_code=code)


@pytest.fixture
def component(
client, # noqa: ARG001
Expand Down Expand Up @@ -43,9 +53,14 @@ def test_list_flows_return_type(component):
assert isinstance(flows, list)


def test_feature_flags_add_toolkit_output(active_user, code_component_with_multiple_outputs):
def test_feature_flags_add_toolkit_output(
active_user, code_component_with_multiple_outputs, code_component_with_multiple_outputs_with_add_tool_output
):
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
len_outputs = len(frontnd_node_dict["outputs"])
FEATURE_FLAGS.add_toolkit_output = True
frontnd_node_dict, _ = build_custom_component_template(code_component_with_multiple_outputs, active_user.id)
code_component_with_multiple_outputs_with_add_tool_output.add_tool_output = True
frontnd_node_dict, _ = build_custom_component_template(
code_component_with_multiple_outputs_with_add_tool_output, active_user.id
)
assert len(frontnd_node_dict["outputs"]) == len_outputs + 1
2 changes: 1 addition & 1 deletion src/frontend/src/components/dropdownComponent/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ export default function Dropdown({
<ForwardedIconComponent
name="ChevronsUpDown"
className={cn(
"text-placeholder ml-2 h-4 w-4 shrink-0",
"ml-2 h-4 w-4 shrink-0 text-placeholder",
disabled ? "hover:text-placeholder" : "hover:text-foreground",
)}
/>
Expand Down
9 changes: 8 additions & 1 deletion src/frontend/src/contexts/authContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ import {
LANGFLOW_REFRESH_TOKEN,
} from "@/constants/constants";
import { useGetUserData } from "@/controllers/API/queries/auth";
import { useGetGlobalVariablesMutation } from "@/controllers/API/queries/variables/use-get-mutation-global-variables";
import useAuthStore from "@/stores/authStore";
import { createContext, useEffect, useState } from "react";
import Cookies from "universal-cookie";
import { Cookies } from "react-cookie";
import { useStoreStore } from "../stores/storeStore";
import { Users } from "../types/api";
import { AuthContextType } from "../types/contexts/auth";
Expand Down Expand Up @@ -41,6 +42,7 @@ export function AuthProvider({ children }): React.ReactElement {
const setIsAuthenticated = useAuthStore((state) => state.setIsAuthenticated);

const { mutate: mutateLoggedUser } = useGetUserData();
const { mutate: mutateGetGlobalVariables } = useGetGlobalVariablesMutation();

useEffect(() => {
const storedAccessToken = cookies.get(LANGFLOW_ACCESS_TOKEN);
Expand Down Expand Up @@ -86,12 +88,17 @@ export function AuthProvider({ children }): React.ReactElement {
setAccessToken(newAccessToken);
setIsAuthenticated(true);
getUser();
getGlobalVariables();
}

function storeApiKey(apikey: string) {
setApiKey(apikey);
}

function getGlobalVariables() {
mutateGetGlobalVariables({});
}

return (
// !! to convert string to boolean
<AuthContext.Provider
Expand Down
20 changes: 13 additions & 7 deletions src/frontend/src/controllers/API/api.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ import { useCustomApiHeaders } from "@/customization/hooks/use-custom-api-header
import useAuthStore from "@/stores/authStore";
import axios, { AxiosError, AxiosInstance, AxiosRequestConfig } from "axios";
import * as fetchIntercept from "fetch-intercept";
import { useContext, useEffect } from "react";
import { useEffect } from "react";
import { Cookies } from "react-cookie";
import { BuildStatus } from "../../constants/enums";
import { AuthContext } from "../../contexts/authContext";
import useAlertStore from "../../stores/alertStore";
import useFlowStore from "../../stores/flowStore";
import { checkDuplicateRequestAndStoreRequest } from "./helpers/check-duplicate-requests";
Expand All @@ -21,7 +20,14 @@ const cookies = new Cookies();
function ApiInterceptor() {
const autoLogin = useAuthStore((state) => state.autoLogin);
const setErrorData = useAlertStore((state) => state.setErrorData);
let { accessToken, authenticationErrorCount } = useContext(AuthContext);
const accessToken = useAuthStore((state) => state.accessToken);
const authenticationErrorCount = useAuthStore(
(state) => state.authenticationErrorCount,
);
const setAuthenticationErrorCount = useAuthStore(
(state) => state.setAuthenticationErrorCount,
);

const { mutate: mutationLogout } = useLogout();
const { mutate: mutationRenewAccessToken } = useRefreshAccessToken();
const isLoginPage = location.pathname.includes("login");
Expand Down Expand Up @@ -149,10 +155,10 @@ function ApiInterceptor() {
function checkErrorCount() {
if (isLoginPage) return;

authenticationErrorCount = authenticationErrorCount + 1;
setAuthenticationErrorCount(authenticationErrorCount + 1);

if (authenticationErrorCount > 3) {
authenticationErrorCount = 0;
setAuthenticationErrorCount(0);
mutationLogout();
return false;
}
Expand All @@ -169,9 +175,9 @@ function ApiInterceptor() {
}
mutationRenewAccessToken(undefined, {
onSuccess: async () => {
authenticationErrorCount = 0;
setAuthenticationErrorCount(0);
await remakeRequest(error);
authenticationErrorCount = 0;
setAuthenticationErrorCount(0);
},
onError: (error) => {
console.error(error);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { DEFAULT_FOLDER, STARTER_FOLDER_NAME } from "@/constants/constants";
import { FolderType } from "@/pages/MainPage/entities";
import useAuthStore from "@/stores/authStore";
import { useFolderStore } from "@/stores/foldersStore";
import { useTypesStore } from "@/stores/typesStore";
import { useQueryFunctionType } from "@/types/api";
Expand All @@ -18,7 +19,10 @@ export const useGetFoldersQuery: useQueryFunctionType<
const setMyCollectionId = useFolderStore((state) => state.setMyCollectionId);
const setFolders = useFolderStore((state) => state.setFolders);

const isAuthenticated = useAuthStore((state) => state.isAuthenticated);

const getFoldersFn = async (): Promise<FolderType[]> => {
if (!isAuthenticated) return [];
const res = await api.get(`${getURL("FOLDERS")}/`);
const data = res.data;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import useAuthStore from "@/stores/authStore";
import { useGlobalVariablesStore } from "@/stores/globalVariablesStore/globalVariables";
import getUnavailableFields from "@/stores/globalVariablesStore/utils/get-unavailable-fields";
import { useQueryFunctionType } from "@/types/api";
Expand All @@ -20,7 +21,10 @@ export const useGetGlobalVariables: useQueryFunctionType<
(state) => state.setUnavailableFields,
);

const isAuthenticated = useAuthStore((state) => state.isAuthenticated);

const getGlobalVariablesFn = async (): Promise<GlobalVariable[]> => {
if (!isAuthenticated) return [];
const res = await api.get(`${getURL("VARIABLES")}/`);
setGlobalVariablesEntries(res.data.map((entry) => entry.name));
setUnavailableFields(getUnavailableFields(res.data));
Expand Down
Loading

0 comments on commit 3c1a1c5

Please sign in to comment.