Chore: update poetry dependency #65

Merged 2 commits on Oct 6, 2024
21 changes: 13 additions & 8 deletions backend/app/alembic/versions/7bb637385f51_modify_skills_table.py
@@ -5,27 +5,32 @@
Create Date: 2024-09-26 13:41:33.595237

"""
from alembic import op

import sqlalchemy as sa
import sqlmodel.sql.sqltypes

from alembic import op

# revision identifiers, used by Alembic.
revision = '7bb637385f51'
down_revision = 'b5d6291d6db9'
revision = "7bb637385f51"
down_revision = "b5d6291d6db9"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('skill', sa.Column('display_name', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
op.drop_column('skill', 'icon')
op.add_column(
"skill",
sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
)
op.drop_column("skill", "icon")
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('skill', sa.Column('icon', sa.VARCHAR(), autoincrement=False, nullable=True))
op.drop_column('skill', 'display_name')
op.add_column(
"skill", sa.Column("icon", sa.VARCHAR(), autoincrement=False, nullable=True)
)
op.drop_column("skill", "display_name")
# ### end Alembic commands ###
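For reference, the reformatted migration can be applied or rolled back programmatically through Alembic's command API. A minimal sketch, assuming an `alembic.ini` at the backend root (the path is an assumption for illustration, not taken from this PR):

```python
# Minimal sketch: apply / roll back the 7bb637385f51 migration via Alembic's
# command API. The alembic.ini path below is a hypothetical placeholder.
from alembic import command
from alembic.config import Config

cfg = Config("backend/alembic.ini")  # hypothetical path

command.upgrade(cfg, "7bb637385f51")    # adds skill.display_name, drops skill.icon
command.downgrade(cfg, "b5d6291d6db9")  # restores skill.icon, drops skill.display_name
```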
2 changes: 1 addition & 1 deletion backend/app/api/routes/providermodel.py
@@ -4,9 +4,9 @@
from app.curd.models import (
_create_model,
_delete_model,
_update_model,
get_all_models,
get_models_by_provider,
_update_model,
)
from app.models import Models, ModelsBase, ModelsOut

13 changes: 9 additions & 4 deletions backend/app/core/celery_app.py
@@ -1,10 +1,14 @@
import logging
import os

from celery import Celery

from app.core.config import settings
import logging

# Configure basic logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(os.getcwd(), "fastembed_cache")
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
@@ -31,10 +35,11 @@
# Configure Celery logging
celery_app.conf.update(
worker_hijack_root_logger=False,
worker_log_format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
worker_task_log_format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
worker_log_format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
worker_task_log_format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)


@celery_app.task(acks_late=True)
def test_celery(word: str) -> str:
logging.info(f"Test task received: {word}")
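As a quick check that the reformatted logging configuration still takes effect, the `test_celery` task defined above can be enqueued through the standard Celery API. A minimal sketch, assuming a broker and worker are already running with the settings from `app.core.config`:

```python
# Minimal sketch: enqueue the test task and wait for its result.
# Assumes a broker and a Celery worker are running for celery_app.
from app.core.celery_app import test_celery

result = test_celery.delay("ping")  # worker logs "Test task received: ping"
print(result.get(timeout=10))
```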
18 changes: 9 additions & 9 deletions backend/app/core/config.py
@@ -143,18 +143,11 @@ def _enforce_non_default_secrets(self) -> Self:

# Qdrant
QDRANT_SERVICE_API_KEY: str | None = "XMj3HXm5GlBKQLwZuStOlkwZiOWTdd_IwZNDJINFh-w"
# QDRANT_URL: str = "http://localhost:6333"
QDRANT_URL: str = "http://127.0.0.1:6333"
QDRANT_URL: str = "http://localhost:6333"
# QDRANT_URL: str = "http://127.0.0.1:6333"

QDRANT_COLLECTION: str | None = "kb_uploads"

# LangSmith
# USE_LANGSMITH: bool = True
# LANGCHAIN_TRACING_V2: bool = False
# LANGCHAIN_ENDPOINT: str | None = None
# LANGCHAIN_API_KEY: str | None = None
# LANGCHAIN_PROJECT: str | None = None

# Embeddings
# EMBEDDING_MODEL: str = "local" # or another model of your choice
EMBEDDING_MODEL: str = "zhipuai" # or another model of your choice
@@ -179,5 +172,12 @@ def _enforce_non_default_secrets(self) -> Self:

OPENAI_API_KEY: str

# LangSmith
# USE_LANGSMITH: bool = True
# LANGCHAIN_TRACING_V2: bool = False
# LANGCHAIN_ENDPOINT: str | None = None
# LANGCHAIN_API_KEY: str | None = None
# LANGCHAIN_PROJECT: str | None = None


settings = Settings() # type: ignore
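The relocated settings are still plain Pydantic fields, so they are read the same way as before. A minimal usage sketch, assuming the required environment variables (e.g. `OPENAI_API_KEY`) are set; the values in the comments are the defaults visible in this diff:

```python
# Minimal sketch: settings is a singleton created at import time.
from app.core.config import settings

print(settings.QDRANT_URL)         # "http://localhost:6333" by default
print(settings.QDRANT_COLLECTION)  # "kb_uploads"
print(settings.EMBEDDING_MODEL)    # "zhipuai"
```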
2 changes: 1 addition & 1 deletion backend/app/core/graph/build.py
@@ -828,7 +828,7 @@ async def generator(
formatted_output = f"data: {response.model_dump_json()}\n\n"
yield formatted_output
snapshot = await root.aget_state(config)

if snapshot.next:
# Interrupt occurred
message = snapshot.values["messages"][-1]
16 changes: 1 addition & 15 deletions backend/app/core/graph/members.py
@@ -1,8 +1,6 @@
from collections.abc import Mapping, Sequence
from typing import Annotated, Any
from app.core.rag.qdrant import QdrantStore
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field

from langchain.chat_models import init_chat_model
from langchain_core.messages import AIMessage, AnyMessage
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
@@ -42,18 +40,6 @@ def tool(self) -> BaseTool:
raise ValueError("Skill is not managed and no definition provided.")


# class GraphUpload(BaseModel):
# name: str = Field(description="Name of the upload")
# description: str = Field(description="Description of the upload")
# owner_id: int = Field(description="Id of the user that owns this upload")
# upload_id: int = Field(description="Id of the upload")

# @property
# def tool(self) -> BaseTool:
# retriever = QdrantStore().retriever(self.owner_id, self.upload_id)
# return create_retriever_tool(retriever)


class GraphUpload(BaseModel):
name: str = Field(description="Name of the upload")
description: str = Field(description="Description of the upload")
10 changes: 6 additions & 4 deletions backend/app/core/rag/embeddings.py
@@ -1,11 +1,13 @@
import logging
from typing import List
from langchain_core.embeddings import Embeddings
from langchain_openai import OpenAIEmbeddings

import requests
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra
import requests
from langchain_openai import OpenAIEmbeddings

from app.core.config import settings
import logging

logger = logging.getLogger(__name__)

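The import reshuffle above does not change the module's interface; `get_embedding_model` remains the entry point used elsewhere in this PR (for example in `QdrantStore`). A minimal, hedged sketch of that usage:

```python
# Minimal sketch: obtain the configured embedding backend and embed a query.
# get_embedding_model is called this way in qdrant.py; the concrete backend
# depends on settings.EMBEDDING_MODEL ("zhipuai", "local", ...).
from app.core.config import settings
from app.core.rag.embeddings import get_embedding_model

embedding_model = get_embedding_model(settings.EMBEDDING_MODEL)
vector = embedding_model.embed_query("what changed in this PR?")
print(len(vector))  # dimensionality of the embedding
```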
60 changes: 35 additions & 25 deletions backend/app/core/rag/qdrant.py
@@ -1,23 +1,24 @@
from typing import List, Callable
import logging
from typing import Callable, List

from langchain_community.document_loaders import PyMuPDFLoader, WebBaseLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_qdrant import QdrantVectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
from qdrant_client.models import VectorParams, Distance
import pymupdf
from qdrant_client.models import Distance, VectorParams

from app.core.config import settings
from app.core.rag.embeddings import get_embedding_model

import logging

logger = logging.getLogger(__name__)


class QdrantStore:
def __init__(self) -> None:
self.collection_name = settings.QDRANT_COLLECTION
# self.url = "http://localhost:6333"
self.url = settings.QDRANT_URL
self.embedding_model = get_embedding_model(settings.EMBEDDING_MODEL)

@@ -33,7 +34,9 @@ def __init__(self) -> None:
def _initialize_vector_store(self):
try:
collections = self.client.get_collections().collections
if self.collection_name not in [collection.name for collection in collections]:
if self.collection_name not in [
collection.name for collection in collections
]:
logger.info(f"Creating new collection: {self.collection_name}")
self.client.create_collection(
collection_name=self.collection_name,
@@ -113,7 +116,7 @@ def delete(self, upload_id: int, user_id: int) -> bool:
match=rest.MatchValue(value=upload_id),
),
],
)
),
)
logger.info(f"Delete operation result: {result}")
return True
@@ -136,46 +139,53 @@ def update(
callback()

def search(self, user_id: int, upload_ids: List[int], query: str) -> List[Document]:
logger.info(f"Searching with query: '{query}' for user_id: {user_id}, upload_ids: {upload_ids}")

logger.info(
f"Searching with query: '{query}' for user_id: {user_id}, upload_ids: {upload_ids}"
)

query_vector = self.embedding_model.embed_query(query)

filter_condition = {
"must": [
{"key": "metadata.user_id", "match": {"value": user_id}},
{"key": "metadata.upload_id", "match": {"any": upload_ids}}
{"key": "metadata.upload_id", "match": {"any": upload_ids}},
]
}
logger.info(f"Search filter condition: {filter_condition}")

search_results = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
query_filter=filter_condition,
limit=4
limit=4,
)

documents = [Document(page_content=result.payload.get('page_content', ''), metadata=result.payload.get('metadata', {})) for result in search_results]


documents = [
Document(
page_content=result.payload.get("page_content", ""),
metadata=result.payload.get("metadata", {}),
)
for result in search_results
]

logger.info(f"Search results: {len(documents)} documents found")
for doc in documents:
logger.info(f"Document metadata: {doc.metadata}")

return documents

def retriever(self, user_id: int, upload_id: int):
logger.info(f"Creating retriever for user_id: {user_id}, upload_id: {upload_id}")
logger.info(
f"Creating retriever for user_id: {user_id}, upload_id: {upload_id}"
)
filter_condition = {
"must": [
{"key": "metadata.user_id", "match": {"value": user_id}},
{"key": "metadata.upload_id", "match": {"value": upload_id}}
{"key": "metadata.upload_id", "match": {"value": upload_id}},
]
}
retriever = self.vector_store.as_retriever(
search_kwargs={
"filter": filter_condition,
"k": 5
},
search_kwargs={"filter": filter_condition, "k": 5},
search_type="similarity",
)
logger.info(f"Retriever created: {retriever}")
@@ -205,4 +215,4 @@ def debug_retriever(self, user_id: int, upload_id: int, query: str):
def get_collection_info(self):
collection_info = self.client.get_collection(self.collection_name)
logger.info(f"Collection info: {collection_info}")
return collection_info
return collection_info
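The reformatted `search` and `retriever` methods keep their original signatures, so callers are unaffected. A minimal usage sketch; the user and upload ids are placeholder values for illustration:

```python
# Minimal sketch: filtered similarity search and retriever creation.
# user_id / upload_ids below are hypothetical placeholders.
from app.core.rag.qdrant import QdrantStore

store = QdrantStore()

docs = store.search(user_id=1, upload_ids=[10, 11], query="poetry dependency update")
for doc in docs:
    print(doc.metadata, doc.page_content[:80])

retriever = store.retriever(user_id=1, upload_id=10)  # LangChain retriever, k=5
```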
2 changes: 1 addition & 1 deletion backend/app/core/rag/qdrant_retriever.py
@@ -56,4 +56,4 @@ def _get_relevant_documents(
limit=self.k,
)
documents.append(document)
return documents
return documents
2 changes: 1 addition & 1 deletion backend/app/core/rag/rag_test_temp.py
@@ -2,8 +2,8 @@

sys.path.append("./")

from app.core.rag.qdrant import QdrantStore
from app.core.rag.embeddings import get_embedding_model
from app.core.rag.qdrant import QdrantStore

# Initialize QdrantStore and the embedding model
qdrant_store = QdrantStore()
6 changes: 3 additions & 3 deletions backend/app/core/rag/ragtest.py
@@ -1,8 +1,8 @@
import json
import logging

from app.core.rag.qdrant import QdrantStore
from app.core.tools.retriever_tool import create_retriever_tool
import logging
import json
from qdrant_client.http import models as rest

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
4 changes: 2 additions & 2 deletions backend/app/core/tools/math/math.py
@@ -1,8 +1,8 @@
# This is an example showing how to create a simple calculator skill

import numexpr as ne
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import StructuredTool
import numexpr as ne


class MathInput(BaseModel):
@@ -14,7 +14,7 @@ def math_cal(expression: str) -> str:
result = ne.evaluate(expression)
result_str = str(result)
return f"{result_str}"
except Exception as e:
except Exception:

return f"Error evaluating expression: {expression}"

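The tightened exception handler above no longer binds an unused exception variable; behaviour is otherwise unchanged. A minimal sketch of calling the helper directly; the expressions are placeholders:

```python
# Minimal sketch: evaluate arithmetic expressions with the math skill helper.
from app.core.tools.math.math import math_cal

print(math_cal("37593 * 67"))     # "2518731"
print(math_cal("37593**(1/5)"))   # fifth root, returned as a string
print(math_cal("not a formula"))  # "Error evaluating expression: not a formula"
```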
2 changes: 0 additions & 2 deletions backend/app/core/tools/openweather/openweather.py
@@ -58,5 +58,3 @@ def open_weather_qry(
args_schema=WeatherSearchInput,
return_direct=True,
)

