Skip to content

Commit

Permalink
Add ability to specify persona in API request (#2302)
Browse files Browse the repository at this point in the history
* persona

* all prepared excluding configuration

* more sensical model structure

* update tstream

* type updates

* rm

* quick and simple updates

* minor updates

* te

* ensure typing + naming

* remove old todo + rebase update

* remove unnecessary check
  • Loading branch information
pablonyx authored Sep 16, 2024
1 parent df464fc commit 2dd3870
Show file tree
Hide file tree
Showing 14 changed files with 283 additions and 64 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""add nullable to persona id in Chat Session
Revision ID: c99d76fcd298
Revises: 5c7fdadae813
Create Date: 2024-07-09 19:27:01.579697
"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "c99d76fcd298"
down_revision = "5c7fdadae813"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.alter_column(
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
)


def downgrade() -> None:
op.alter_column(
"chat_session",
"persona_id",
existing_type=sa.INTEGER(),
nullable=False,
)
30 changes: 17 additions & 13 deletions backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,9 +675,11 @@ def stream_chat_message_objects(
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
dedupe_docs=(
retrieval_options.dedupe_docs
if retrieval_options
else False
),
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
Expand Down Expand Up @@ -786,16 +788,18 @@ def stream_chat_message_objects(
if message_specific_citations
else None,
error=None,
tool_calls=[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else [],
tool_calls=(
[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else []
),
)

logger.debug("Committing messages")
Expand Down
27 changes: 18 additions & 9 deletions backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional
from typing import TypeVar

from fastapi import HTTPException
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
Expand Down Expand Up @@ -153,15 +154,23 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non

with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
persona = cast(
Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)
if new_message_request.persona_config:
raise HTTPException(
status_code=403,
detail="Slack bot does not support persona config",
)

elif new_message_request.persona_id:
persona = cast(
Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)

llm, _ = get_llms_for_persona(persona)

# In cases of threads, split the available tokens between docs and thread context
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/db/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def create_chat_session(
db_session: Session,
description: str,
user_id: UUID | None,
persona_id: int,
persona_id: int | None, # Can be none if temporary persona is used
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
one_shot: bool = False,
Expand Down
16 changes: 16 additions & 0 deletions backend/danswer/db/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from sqlalchemy.orm import Session

from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import DocumentSet
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import SearchSettings
from danswer.db.models import Tool as ToolModel
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
Expand Down Expand Up @@ -103,6 +105,20 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())


def fetch_existing_doc_sets(
db_session: Session, doc_ids: list[int]
) -> list[DocumentSet]:
return list(
db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all()
)


def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
return list(
db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all()
)


def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,
Expand Down
6 changes: 3 additions & 3 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,9 @@ class ChatSession(Base):

id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
description: Mapped[str] = mapped_column(Text)
# One-shot direct answering, currently the two types of chats are not mixed
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
Expand Down Expand Up @@ -900,7 +902,6 @@ class ChatSession(Base):
prompt_override: Mapped[PromptOverride | None] = mapped_column(
PydanticType(PromptOverride), nullable=True
)

time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
Expand All @@ -909,7 +910,6 @@ class ChatSession(Base):
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

user: Mapped[User] = relationship("User", back_populates="chat_sessions")
folder: Mapped["ChatFolder"] = relationship(
"ChatFolder", back_populates="chat_sessions"
Expand Down
8 changes: 5 additions & 3 deletions backend/danswer/db/persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,13 +563,15 @@ def validate_persona_tools(tools: list[Tool]) -> None:
)


def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
prompts = db_session.scalars(
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
).all()

return prompts
return list(prompts)


def get_prompt_by_id(
Expand Down
45 changes: 28 additions & 17 deletions backend/danswer/one_shot_answer/answer_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.chat import update_search_docs_table_with_relevance
from danswer.db.engine import get_session_context_manager
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.persona import get_prompt_by_id
from danswer.llm.answering.answer import Answer
Expand Down Expand Up @@ -60,7 +61,7 @@
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time

from ee.danswer.server.query_and_chat.utils import create_temporary_persona

logger = setup_logger()

Expand Down Expand Up @@ -118,7 +119,17 @@ def stream_answer_objects(
one_shot=True,
danswerbot_flow=danswerbot_flow,
)
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)

temporary_persona: Persona | None = None
if query_req.persona_config is not None:
new_persona = create_temporary_persona(
db_session=db_session, persona_config=query_req.persona_config, user=user
)
temporary_persona = new_persona

persona = temporary_persona if temporary_persona else chat_session.persona

llm, fast_llm = get_llms_for_persona(persona=persona)

llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
Expand Down Expand Up @@ -153,11 +164,11 @@ def stream_answer_objects(
prompt_id=query_req.prompt_id, user=None, db_session=db_session
)
if prompt is None:
if not chat_session.persona.prompts:
if not persona.prompts:
raise RuntimeError(
"Persona does not have any prompts - this should never happen"
)
prompt = chat_session.persona.prompts[0]
prompt = persona.prompts[0]

# Create the first User query message
new_user_message = create_new_chat_message(
Expand All @@ -174,9 +185,7 @@ def stream_answer_objects(
prompt_config = PromptConfig.from_model(prompt)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
chat_session.persona.num_chunks
if chat_session.persona.num_chunks is not None
else default_num_chunks
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
),
max_tokens=max_document_tokens,
)
Expand All @@ -187,16 +196,16 @@ def stream_answer_objects(
evaluation_type=LLMEvaluationType.SKIP
if DISABLE_LLM_DOC_RELEVANCE
else query_req.evaluation_type,
persona=chat_session.persona,
persona=persona,
retrieval_options=query_req.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
bypass_acl=bypass_acl,
chunks_above=query_req.chunks_above,
chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc,
bypass_acl=bypass_acl,
)

answer_config = AnswerStyleConfig(
Expand All @@ -209,23 +218,23 @@ def stream_answer_objects(
question=query_msg.message,
answer_style_config=answer_config,
prompt_config=PromptConfig.from_model(prompt),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)),
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
force_use=True,
tool_name=search_tool.name,
args={"query": rephrased_query},
tools=[search_tool] if search_tool else [],
force_use_tool=(
ForceUseTool(
tool_name=search_tool.name,
args={"query": rephrased_query},
force_use=True,
)
),
# for now, don't use tool calling for this flow, as we haven't
# tested quotes with tool calling too much yet
skip_explicit_tool_calling=True,
return_contexts=query_req.return_contexts,
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
)

# won't be any ImageGenerationDisplay responses since that tool is never passed in

for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse):
Expand Down Expand Up @@ -261,6 +270,7 @@ def stream_answer_objects(
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
)

yield initial_response

elif packet.id == SEARCH_DOC_CONTENT_ID:
Expand All @@ -287,6 +297,7 @@ def stream_answer_objects(
relevance_summary=evaluation_response,
)
yield evaluation_response

else:
yield packet

Expand Down
Loading

0 comments on commit 2dd3870

Please sign in to comment.