Skip to content

Commit

Permalink
Merge branch 'main' into cz/mvp-components-flags
Browse files Browse the repository at this point in the history
  • Loading branch information
Cristhianzl authored Oct 27, 2024
2 parents 5a0cd92 + c8bdcf3 commit a5e81b5
Show file tree
Hide file tree
Showing 32 changed files with 900 additions and 696 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ dependencies = [
"yfinance>=0.2.40",
"langchain-google-community~=2.0.1",
"wolframalpha>=5.1.3",
"astra-assistants~=2.2.2",
"astra-assistants[tools]~=2.2.5",
"composio-langchain==0.5.9",
"spider-client>=0.0.27",
"nltk>=3.9.1",
Expand Down
5 changes: 3 additions & 2 deletions src/backend/base/langflow/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from langflow.graph.graph.base import Graph
from langflow.graph.utils import log_vertex_build
from langflow.schema.schema import OutputValue
from langflow.services.cache.utils import CacheMiss
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service, get_session, get_telemetry_service
from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload
Expand Down Expand Up @@ -493,7 +494,7 @@ async def build_vertex(
error_message = None
try:
cache = await chat_service.get_cache(flow_id_str)
if not cache:
if isinstance(cache, CacheMiss):
# If there's no cache
logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}")
graph: Graph = await build_graph_from_db(
Expand Down Expand Up @@ -621,7 +622,7 @@ async def _stream_vertex(flow_id: str, vertex_id: str, chat_service: ChatService
yield str(StreamData(event="error", data={"error": str(exc)}))
return

if not cache:
if isinstance(cache, CacheMiss):
# If there's no cache
msg = f"No cache found for {flow_id}."
logger.error(msg)
Expand Down
4 changes: 3 additions & 1 deletion src/backend/base/langflow/base/astra_assistants/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
from astra_assistants import OpenAIWithDefaultKey, patch
from astra_assistants.tools.tool_interface import ToolInterface

from langflow.services.cache.utils import CacheMiss

client_lock = threading.Lock()
client = None


def get_patched_openai_client(shared_component_cache):
os.environ["ASTRA_ASSISTANTS_QUIET"] = "true"
client = shared_component_cache.get("client")
if client is None:
if isinstance(client, CacheMiss):
client = patch(OpenAIWithDefaultKey())
shared_component_cache.set("client", client)
return client
Expand Down
41 changes: 40 additions & 1 deletion src/backend/base/langflow/components/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.inputs.inputs import HandleInput
from langflow.io import DictInput, DropdownInput, IntInput, SecretStrInput, StrInput
from langflow.io import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput


class HuggingFaceEndpointsComponent(LCModelComponent):
Expand All @@ -20,6 +20,45 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
inputs = [
*LCModelComponent._base_inputs,
StrInput(name="model_id", display_name="Model ID", value="openai-community/gpt2"),
IntInput(
name="max_new_tokens", display_name="Max New Tokens", value=512, info="Maximum number of generated tokens"
),
IntInput(
name="top_k",
display_name="Top K",
advanced=True,
info="The number of highest probability vocabulary tokens to keep for top-k-filtering",
),
FloatInput(
name="top_p",
display_name="Top P",
value=0.95,
advanced=True,
info=(
"If set to < 1, only the smallest set of most probable tokens with "
"probabilities that add up to `top_p` or higher are kept for generation"
),
),
FloatInput(
name="typical_p",
display_name="Typical P",
value=0.95,
advanced=True,
info="Typical Decoding mass.",
),
FloatInput(
name="temperature",
display_name="Temperature",
value=0.8,
advanced=True,
info="The value used to module the logits distribution",
),
FloatInput(
name="repetition_penalty",
display_name="Repetition Penalty",
info="The parameter for repetition penalty. 1.0 means no penalty.",
advanced=True,
),
StrInput(
name="inference_endpoint",
display_name="Inference Endpoint",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def add_inputs_to_build_config(self, inputs_vertex: list[Vertex], build_config:
async def generate_results(self) -> list[Data]:
tweaks: dict = {}
for field in self._attributes:
if field != "flow_name":
if field != "flow_name" and "|" in field:
[node, name] = field.split("|")
if node not in tweaks:
tweaks[node] = {}
Expand Down
55 changes: 28 additions & 27 deletions src/backend/base/langflow/components/vectorstores/astradb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

import orjson
from astrapy.admin import parse_api_endpoint
from loguru import logger

Expand Down Expand Up @@ -116,6 +117,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
display_name="Metric",
info="Optional distance metric for vector comparisons in the vector store.",
options=["cosine", "dot_product", "euclidean"],
value="cosine",
advanced=True,
),
IntInput(
Expand Down Expand Up @@ -145,8 +147,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
DropdownInput(
name="setup_mode",
display_name="Setup Mode",
info="Configuration mode for setting up the vector store, with options like 'Sync', 'Async', or 'Off'.",
options=["Sync", "Async", "Off"],
info="Configuration mode for setting up the vector store, with options like 'Sync' or 'Off'.",
options=["Sync", "Off"],
advanced=True,
value="Sync",
),
Expand All @@ -160,18 +162,21 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
name="metadata_indexing_include",
display_name="Metadata Indexing Include",
info="Optional list of metadata fields to include in the indexing.",
is_list=True,
advanced=True,
),
StrInput(
name="metadata_indexing_exclude",
display_name="Metadata Indexing Exclude",
info="Optional list of metadata fields to exclude from the indexing.",
is_list=True,
advanced=True,
),
StrInput(
name="collection_indexing_policy",
display_name="Collection Indexing Policy",
info="Optional dictionary defining the indexing policy for the collection.",
info='Optional JSON string for the "indexing" field of the collection. '
"See https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option",
advanced=True,
),
IntInput(
Expand Down Expand Up @@ -401,31 +406,27 @@ def build_vector_store(self, vectorize_options=None):
),
}

vector_store_kwargs = {
**embedding_dict,
"collection_name": self.collection_name,
"token": self.token,
"api_endpoint": self.api_endpoint,
"namespace": self.namespace or None,
"environment": parse_api_endpoint(self.api_endpoint).environment,
"metric": self.metric or None,
"batch_size": self.batch_size or None,
"bulk_insert_batch_concurrency": self.bulk_insert_batch_concurrency or None,
"bulk_insert_overwrite_concurrency": self.bulk_insert_overwrite_concurrency or None,
"bulk_delete_concurrency": self.bulk_delete_concurrency or None,
"setup_mode": setup_mode_value,
"pre_delete_collection": self.pre_delete_collection or False,
}

if self.metadata_indexing_include:
vector_store_kwargs["metadata_indexing_include"] = self.metadata_indexing_include
elif self.metadata_indexing_exclude:
vector_store_kwargs["metadata_indexing_exclude"] = self.metadata_indexing_exclude
elif self.collection_indexing_policy:
vector_store_kwargs["collection_indexing_policy"] = self.collection_indexing_policy

try:
vector_store = AstraDBVectorStore(**vector_store_kwargs)
vector_store = AstraDBVectorStore(
collection_name=self.collection_name,
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace or None,
environment=parse_api_endpoint(self.api_endpoint).environment,
metric=self.metric,
batch_size=self.batch_size or None,
bulk_insert_batch_concurrency=self.bulk_insert_batch_concurrency or None,
bulk_insert_overwrite_concurrency=self.bulk_insert_overwrite_concurrency or None,
bulk_delete_concurrency=self.bulk_delete_concurrency or None,
setup_mode=setup_mode_value,
pre_delete_collection=self.pre_delete_collection,
metadata_indexing_include=[s for s in self.metadata_indexing_include if s],
metadata_indexing_exclude=[s for s in self.metadata_indexing_exclude if s],
collection_indexing_policy=orjson.dumps(self.collection_indexing_policy)
if self.collection_indexing_policy
else None,
**embedding_dict,
)
except Exception as e:
msg = f"Error initializing AstraDBVectorStore: {e}"
raise ValueError(msg) from e
Expand Down
15 changes: 8 additions & 7 deletions src/backend/base/langflow/custom/custom_component/component.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import ast
import asyncio
import inspect
from copy import deepcopy
from textwrap import dedent
Expand Down Expand Up @@ -506,11 +507,10 @@ def __call__(self, **kwargs):
async def _run(self):
# Resolve callable inputs
for key, _input in self._inputs.items():
if callable(_input.value):
result = _input.value()
if inspect.iscoroutine(result):
result = await result
self._inputs[key].value = result
if asyncio.iscoroutinefunction(_input.value):
self._inputs[key].value = await _input.value()
elif callable(_input.value):
self._inputs[key].value = await asyncio.to_thread(_input.value)

self.set_attributes({})

Expand Down Expand Up @@ -718,10 +718,11 @@ async def _build_results(self):
_results[output.name] = output.value
result = output.value
else:
result = method()
# If the method is asynchronous, we need to await it
if inspect.iscoroutinefunction(method):
result = await result
result = await method()
else:
result = await asyncio.to_thread(method)
if (
self._vertex is not None
and isinstance(result, Message)
Expand Down
54 changes: 1 addition & 53 deletions src/backend/base/langflow/graph/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,58 +707,6 @@ async def _run(

return vertex_outputs

def run(
self,
inputs: list[dict[str, str]],
*,
input_components: list[list[str]] | None = None,
types: list[InputType | None] | None = None,
outputs: list[str] | None = None,
session_id: str | None = None,
stream: bool = False,
fallback_to_env_vars: bool = False,
) -> list[RunOutputs]:
"""Run the graph with the given inputs and return the outputs.
Args:
inputs (Dict[str, str]): A dictionary of input values.
input_components (Optional[list[str]]): A list of input components.
types (Optional[list[str]]): A list of types.
outputs (Optional[list[str]]): A list of output components.
session_id (Optional[str]): The session ID.
stream (bool): Whether to stream the outputs.
fallback_to_env_vars (bool): Whether to fallback to environment variables.
Returns:
List[RunOutputs]: A list of RunOutputs objects representing the outputs.
"""
# run the async function in a sync way
# this could be used in a FastAPI endpoint
# so we should take care of the event loop
coro = self.arun(
inputs=inputs,
inputs_components=input_components,
types=types,
outputs=outputs,
session_id=session_id,
stream=stream,
fallback_to_env_vars=fallback_to_env_vars,
)

try:
# Attempt to get the running event loop; if none, an exception is raised
loop = asyncio.get_running_loop()
except RuntimeError:
# If there's no running event loop, use asyncio.run
return asyncio.run(coro)

# If the event loop is closed, use asyncio.run
if loop.is_closed():
return asyncio.run(coro)

# If there's an existing, open event loop, use it to run the async function
return loop.run_until_complete(coro)

async def arun(
self,
inputs: list[dict[str, str]],
Expand Down Expand Up @@ -1356,7 +1304,7 @@ async def build_vertex(
if get_cache is not None:
cached_result = await get_cache(key=vertex.id)
else:
cached_result = None
cached_result = CacheMiss()
if isinstance(cached_result, CacheMiss):
should_build = True
else:
Expand Down
5 changes: 1 addition & 4 deletions src/backend/base/langflow/graph/vertex/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,10 +814,7 @@ async def build(
# Run steps
for step in self.steps:
if step not in self.steps_ran:
if inspect.iscoroutinefunction(step):
await step(user_id=user_id, event_manager=event_manager, **kwargs)
else:
step(user_id=user_id, event_manager=event_manager, **kwargs)
await step(user_id=user_id, event_manager=event_manager, **kwargs)
self.steps_ran.append(step)

self.finalize_build()
Expand Down
8 changes: 1 addition & 7 deletions src/backend/base/langflow/initial_setup/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import shutil
import time
from collections import defaultdict
from collections.abc import Awaitable
from copy import deepcopy
from datetime import datetime, timezone
from pathlib import Path
Expand Down Expand Up @@ -600,12 +599,7 @@ def find_existing_flow(session, flow_id, flow_endpoint_name):
return None


async def create_or_update_starter_projects(get_all_components_coro: Awaitable[dict]) -> None:
try:
all_types_dict = await get_all_components_coro
except Exception:
logger.exception("Error loading components")
raise
def create_or_update_starter_projects(all_types_dict: dict) -> None:
with session_scope() as session:
new_folder = create_starter_folder(session)
starter_projects = load_starter_projects()
Expand Down

Large diffs are not rendered by default.

Loading

0 comments on commit a5e81b5

Please sign in to comment.