My docs feature #3805

Closed
wants to merge 2 commits
107 changes: 107 additions & 0 deletions backend/alembic/versions/9aadf32dfeb4_add_user_files.py
@@ -0,0 +1,107 @@
"""add user files
Revision ID: 9aadf32dfeb4
Revises: 8f43500ee275
Create Date: 2025-01-26 16:08:21.551022
"""
from alembic import op
import sqlalchemy as sa
import datetime


# revision identifiers, used by Alembic.
revision = "9aadf32dfeb4"
down_revision = "8f43500ee275"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create user_folder table without parent_id
    op.create_table(
        "user_folder",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
        sa.Column("name", sa.String(length=255), nullable=True),
        sa.Column("description", sa.String(length=255), nullable=True),
        sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
        sa.Column(
            "created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
        ),
    )

    # Create user_file table with folder_id instead of parent_folder_id
    op.create_table(
        "user_file",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
        sa.Column(
            "folder_id",
            sa.Integer(),
            sa.ForeignKey("user_folder.id"),
            nullable=True,
        ),
        sa.Column("token_count", sa.Integer(), nullable=True),
        sa.Column("file_type", sa.String(), nullable=True),
        sa.Column("file_id", sa.String(length=255), nullable=False),
        sa.Column("document_id", sa.String(length=255), nullable=False),
        sa.Column("name", sa.String(length=255), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(),
            default=datetime.datetime.utcnow,
        ),
        sa.Column(
            "cc_pair_id",
            sa.Integer(),
            sa.ForeignKey("connector_credential_pair.id"),
            nullable=True,
            unique=True,
        ),
    )

    # Create persona__user_file table
    op.create_table(
        "persona__user_file",
        sa.Column(
            "persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
        ),
        sa.Column(
            "user_file_id",
            sa.Integer(),
            sa.ForeignKey("user_file.id"),
            primary_key=True,
        ),
    )

    # Create persona__user_folder table
    op.create_table(
        "persona__user_folder",
        sa.Column(
            "persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
        ),
        sa.Column(
            "user_folder_id",
            sa.Integer(),
            sa.ForeignKey("user_folder.id"),
            primary_key=True,
        ),
    )

    op.add_column(
        "connector_credential_pair",
        sa.Column("is_user_file", sa.Boolean(), nullable=True),
    )


def downgrade() -> None:
    # Drop the persona__user_folder table
    op.drop_table("persona__user_folder")
    # Drop the persona__user_file table
    op.drop_table("persona__user_file")
    # Drop the user_file table
    op.drop_table("user_file")
    # Drop the user_folder table
    op.drop_table("user_folder")
    op.drop_column("connector_credential_pair", "is_user_file")
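
As an aside for readers of this migration, the sketch below shows stand-in SQLAlchemy declarative models that the new user_folder and user_file tables could map to. It is illustrative only: the class names, the relationship, and the omission of the user_id and cc_pair_id foreign keys are assumptions, not the PR's actual ORM code.

import datetime

from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func
from sqlalchemy.orm import DeclarativeBase, relationship


class Base(DeclarativeBase):
    pass


class UserFolder(Base):
    # mirrors the user_folder table created above (user_id omitted for brevity)
    __tablename__ = "user_folder"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(255), nullable=True)
    description = Column(String(255), nullable=True)
    display_priority = Column(Integer, nullable=True, default=0)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    files = relationship("UserFile", back_populates="folder")


class UserFile(Base):
    # mirrors the user_file table created above (user_id and cc_pair_id omitted)
    __tablename__ = "user_file"

    id = Column(Integer, primary_key=True, autoincrement=True)
    folder_id = Column(Integer, ForeignKey("user_folder.id"), nullable=True)
    token_count = Column(Integer, nullable=True)
    file_type = Column(String, nullable=True)
    file_id = Column(String(255), nullable=False)
    document_id = Column(String(255), nullable=False)
    name = Column(String(255), nullable=False)
    created_at = Column(DateTime, default=datetime.datetime.utcnow)

    folder = relationship("UserFolder", back_populates="files")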
@@ -319,8 +319,10 @@ def dispatch_separated(
    sep: str = DISPATCH_SEP_CHAR,
) -> list[BaseMessage_Content]:
    num = 1
    accumulated_tokens = ""
    streamed_tokens: list[BaseMessage_Content] = []
    for token in tokens:
        accumulated_tokens += cast(str, token.content)
        content = cast(str, token.content)
        if sep in content:
            sub_question_parts = content.split(sep)
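
The two added lines above make dispatch_separated keep a full copy of the streamed content while it continues to split incoming chunks on the dispatch separator. A self-contained sketch of that pattern follows; the function name, the Iterable-of-str signature, and the separator value are assumptions for illustration, not the PR's helpers.

from typing import Iterable

DISPATCH_SEP = "\x1d"  # assumption: stand-in for DISPATCH_SEP_CHAR


def split_streamed_tokens(
    tokens: Iterable[str], sep: str = DISPATCH_SEP
) -> tuple[str, list[list[str]]]:
    """Return the fully accumulated text plus token chunks grouped by separator."""
    accumulated = ""
    groups: list[list[str]] = [[]]
    for content in tokens:
        accumulated += content  # keep the complete stream, as the diff now does
        if sep in content:
            parts = content.split(sep)
            groups[-1].append(parts[0])  # text before the first separator
            for part in parts[1:]:
                groups.append([part] if part else [])  # each separator opens a group
        else:
            groups[-1].append(content)
    return accumulated, groups


full_text, chunks = split_streamed_tokens(["Question 1?", "\x1dQuestion", " 2?"])
# full_text == "Question 1?\x1dQuestion 2?"; chunks == [["Question 1?", ""], ["Question", " 2?"]]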
19 changes: 17 additions & 2 deletions backend/onyx/chat/process_message.py
@@ -86,6 +86,7 @@
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.utils import load_all_chat_files
from onyx.file_store.utils import load_all_user_files
from onyx.file_store.utils import save_files
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
@@ -262,8 +263,11 @@ def _get_force_search_settings(
    search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)

    if not internet_search_available and not search_tool_available:
        # Does not matter much which tool is set here as force is false and neither tool is available
        return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
        if new_msg_req.force_user_file_search:
            return ForceUseTool(force_use=True, tool_name=SearchTool._NAME)
        else:
            # Does not matter much which tool is set here as force is false and neither tool is available
            return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)

    tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
    # Currently, the internet search tool does not support query override
@@ -279,6 +283,7 @@ def _get_force_search_settings(

    should_force_search = any(
        [
            new_msg_req.force_user_file_search,
            new_msg_req.retrieval_options
            and new_msg_req.retrieval_options.run_search
            == OptionalSearchSetting.ALWAYS,
@@ -538,6 +543,15 @@ def stream_chat_message_objects(
        req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
        latest_query_files = [file for file in files if file.file_id in req_file_ids]

        if not new_msg_req.force_user_file_search:
            user_files = load_all_user_files(
                new_msg_req.user_file_ids,
                new_msg_req.user_folder_ids,
                db_session,
            )

            latest_query_files += user_files

        if user_message:
            attach_files_to_chat_message(
                chat_message=user_message,
@@ -681,6 +695,7 @@ def stream_chat_message_objects(
            user=user,
            llm=llm,
            fast_llm=fast_llm,
            use_file_search=new_msg_req.force_user_file_search,
            search_tool_config=SearchToolConfig(
                answer_style_config=answer_style_config,
                document_pruning_config=document_pruning_config,
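
The force_user_file_search handling above changes what _get_force_search_settings returns when no search tool was constructed. Below is a hedged, simplified sketch of just that branch, using a stand-in dataclass and a placeholder tool name (SearchTool._NAME is not reproduced here).

from dataclasses import dataclass


@dataclass
class ForceUseTool:  # stand-in for the real ForceUseTool model
    force_use: bool
    tool_name: str


SEARCH_TOOL_NAME = "run_search"  # assumption: placeholder for SearchTool._NAME


def forced_tool_when_no_search_tools(force_user_file_search: bool) -> ForceUseTool:
    """Decision when neither internet search nor the search tool is available."""
    if force_user_file_search:
        # user-file search is still forced even though no search tool was built
        return ForceUseTool(force_use=True, tool_name=SEARCH_TOOL_NAME)
    # the tool name hardly matters here because force_use is False
    return ForceUseTool(force_use=False, tool_name=SEARCH_TOOL_NAME)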
2 changes: 1 addition & 1 deletion backend/onyx/configs/chat_configs.py
@@ -3,7 +3,7 @@
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
PERSONAS_YAML = "./onyx/seeding/personas.yaml"

USER_FOLDERS_YAML = "./onyx/seeding/user_folders.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
# We want this to be approximately the number of results we want to show on the first page
1 change: 1 addition & 0 deletions backend/onyx/context/search/models.py
@@ -98,6 +98,7 @@ class BaseFilters(BaseModel):
    document_set: list[str] | None = None
    time_cutoff: datetime | None = None
    tags: list[Tag] | None = None
    user_file_ids: list[int] | None = None


class IndexFilters(BaseFilters):
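
For context, a minimal pydantic sketch of where the new user_file_ids field sits; the surrounding fields are simplified and the IndexFilters field shown is an assumption, not copied from the repository.

from datetime import datetime

from pydantic import BaseModel


class BaseFilters(BaseModel):
    # simplified: the real model also carries source_type, tags, etc.
    document_set: list[str] | None = None
    time_cutoff: datetime | None = None
    user_file_ids: list[int] | None = None


class IndexFilters(BaseFilters):
    access_control_list: list[str] | None = None  # assumption: illustrative subclass field


# restrict retrieval to three specific user files
filters = IndexFilters(user_file_ids=[1, 2, 3])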
9 changes: 9 additions & 0 deletions backend/onyx/context/search/preprocessing/preprocessing.py
@@ -160,7 +160,16 @@ def retrieval_preprocessing(
    user_acl_filters = (
        None if bypass_acl else build_access_filters_for_user(user, db_session)
    )
    user_file_ids = preset_filters.user_file_ids
    if persona and persona.user_files:
        user_file_ids = user_file_ids + [
            file.id
            for file in persona.user_files
            if file.id not in preset_filters.user_file_ids
        ]

    final_filters = IndexFilters(
        user_file_ids=user_file_ids,
        source_type=preset_filters.source_type or predicted_source_filters,
        document_set=preset_filters.document_set,
        time_cutoff=time_filter or predicted_time_cutoff,
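
The persona branch above folds persona-attached user files into the preset filter IDs. A standalone sketch of that merge with assumed names; unlike the diff, it also guards against a None preset list.

def merge_user_file_ids(
    preset_ids: list[int] | None, persona_file_ids: list[int]
) -> list[int]:
    # append persona file IDs that are not already in the preset filter
    preset_ids = preset_ids or []
    return preset_ids + [fid for fid in persona_file_ids if fid not in preset_ids]


assert merge_user_file_ids([1, 2], [2, 3]) == [1, 2, 3]
assert merge_user_file_ids(None, [5]) == [5]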
12 changes: 10 additions & 2 deletions backend/onyx/db/connector_credential_pair.py
@@ -104,6 +104,7 @@ def get_connector_credential_pairs_for_user(
    get_editable: bool = True,
    ids: list[int] | None = None,
    eager_load_connector: bool = False,
    include_user_files: bool = False,
    eager_load_credential: bool = False,
    eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
@@ -126,6 +127,9 @@ def get_connector_credential_pairs_for_user(
    if ids:
        stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))

    if not include_user_files:
        stmt = stmt.where(ConnectorCredentialPair.is_user_file != True)  # noqa: E712

    return list(db_session.scalars(stmt).unique().all())


@@ -153,14 +157,16 @@ def get_connector_credential_pairs_for_user_parallel(


def get_connector_credential_pairs(
    db_session: Session,
    ids: list[int] | None = None,
    db_session: Session, ids: list[int] | None = None, include_user_files: bool = False
) -> list[ConnectorCredentialPair]:
    stmt = select(ConnectorCredentialPair).distinct()

    if ids:
        stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))

    if not include_user_files:
        stmt = stmt.where(ConnectorCredentialPair.is_user_file != True)  # noqa: E712

    return list(db_session.scalars(stmt).all())


@@ -446,6 +452,7 @@ def add_credential_to_connector(
    initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
    last_successful_index_time: datetime | None = None,
    seeding_flow: bool = False,
    is_user_file: bool = False,
) -> StatusResponse:
    connector = fetch_connector_by_id(connector_id, db_session)

@@ -511,6 +518,7 @@ def add_credential_to_connector(
        access_type=access_type,
        auto_sync_options=auto_sync_options,
        last_successful_index_time=last_successful_index_time,
        is_user_file=is_user_file,
    )
    db_session.add(association)
    db_session.flush()  # make sure the association has an id
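
The include_user_files flag above lets callers opt in to seeing user-file connector/credential pairs, which are otherwise filtered out. A self-contained sketch of the same pattern with a simplified stand-in model (not the real ConnectorCredentialPair class):

from sqlalchemy import Boolean, Column, Integer, select
from sqlalchemy.orm import DeclarativeBase, Session


class Base(DeclarativeBase):
    pass


class CCPair(Base):  # assumption: simplified stand-in for ConnectorCredentialPair
    __tablename__ = "cc_pair"

    id = Column(Integer, primary_key=True)
    is_user_file = Column(Boolean, nullable=True)


def fetch_cc_pairs(db_session: Session, include_user_files: bool = False) -> list[CCPair]:
    stmt = select(CCPair).distinct()
    if not include_user_files:
        # hide pairs that back user files unless explicitly requested
        stmt = stmt.where(CCPair.is_user_file != True)  # noqa: E712
    return list(db_session.scalars(stmt).all())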
4 changes: 2 additions & 2 deletions backend/onyx/db/document.py
@@ -274,7 +274,7 @@ def get_document_counts_for_cc_pairs_parallel(
def get_access_info_for_document(
    db_session: Session,
    document_id: str,
) -> tuple[str, list[str | None], bool] | None:
) -> tuple[str, list[str | None], bool, list[int], list[int]] | None:
    """Gets access info for a single document by calling the get_access_info_for_documents function
    and passing a list with a single document ID.
    Args:
@@ -294,7 +294,7 @@ def get_access_info_for_document(
def get_access_info_for_documents(
    db_session: Session,
    document_ids: list[str],
) -> Sequence[tuple[str, list[str | None], bool]]:
) -> Sequence[tuple[str, list[str | None], bool, list[int], list[int]]]:
    """Gets back all relevant access info for the given documents. This includes
    the user_ids for cc pairs that the document is associated with + whether any
    of the associated cc pairs are intending to make the document globally public.
1 change: 0 additions & 1 deletion backend/onyx/db/document_set.py
@@ -605,7 +605,6 @@ def fetch_document_sets_for_document(
    result = fetch_document_sets_for_documents([document_id], db_session)
    if not result:
        return []

    return result[0][1]

