Support deepseek-r1 deep-thinking mode in regular chat and knowledge-base chat #5214

Open · wants to merge 2 commits into base: master
@@ -115,6 +115,7 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput):
class OpenAIBaseOutput(BaseModel):
    id: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    model: Optional[str] = None
    object: Literal[
        "chat.completion", "chat.completion.chunk"
@@ -150,6 +151,7 @@ def model_dump(self) -> dict:
            {
                "delta": {
                    "content": self.content,
                    "reasoning_content": self.reasoning_content,
                    "tool_calls": self.tool_calls,
                },
                "role": self.role,
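With `reasoning_content` added to the output schema, a serialized streaming chunk now carries the thinking tokens alongside the answer tokens. A sketch of the expected shape (field values are illustrative only, not from this PR):

```python
# Illustrative model_dump() output for one streaming chunk (values made up):
{
    "id": "chat3f2a...",
    "object": "chat.completion.chunk",
    "model": "deepseek-r1",
    "choices": [
        {
            "delta": {
                "content": "",  # answer tokens (empty while the model is thinking)
                "reasoning_content": "First, recall that ...",  # deep-thinking tokens
                "tool_calls": None,
            },
            "role": "assistant",
        }
    ],
}
```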
152 changes: 152 additions & 0 deletions libs/chatchat-server/chatchat/server/chat/deepseek.py
@@ -0,0 +1,152 @@
import asyncio
import logging
from typing import Any, AsyncIterator, Iterator, List, Optional

from langchain.schema import HumanMessage, AIMessage, SystemMessage, ChatMessage
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.callbacks import CallbackManagerForLLMRun

from langchain_openai import ChatOpenAI

logger = logging.getLogger(__name__)

class DeepseekChatOpenAI(ChatOpenAI):
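    """ChatOpenAI subclass that surfaces DeepSeek-R1 reasoning_content
    deltas alongside the regular answer tokens."""
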
    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        openai_messages = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                openai_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                openai_messages.append({"role": "assistant", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                openai_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, ChatMessage):
                openai_messages.append({"role": msg.role, "content": msg.content})
            else:
                raise ValueError(f"Unsupported message type: {type(msg)}")

        params = {
            "model": self.model_name,
            "messages": openai_messages,
            **self.model_kwargs,
            **kwargs,
            "extra_body": {
                "enable_enhanced_generation": True,
                **(kwargs.get("extra_body", {})),
                **(self.model_kwargs.get("extra_body", {})),
            },
        }
        params = {k: v for k, v in params.items() if v not in (None, {}, [])}

        # Create and process the stream
        async for chunk in await self.async_client.create(stream=True, **params):
            delta = chunk.choices[0].delta
            content = delta.content or ""
            reasoning = delta.model_extra.get("reasoning_content", "") if delta.model_extra else ""
            if content:
                yield ChatGenerationChunk(
                    message=AIMessageChunk(content=content),
                    generation_info={"reasoning_content": reasoning},
                )
            if reasoning:
                # Emit reasoning tokens as empty-content chunks so downstream
                # consumers can read them from additional_kwargs.
                yield ChatGenerationChunk(
                    message=AIMessageChunk(
                        content="",
                        additional_kwargs={"reasoning_content": reasoning},
                    ),
                    generation_info={"reasoning_content": reasoning},
                )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        openai_messages = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                openai_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                openai_messages.append({"role": "assistant", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                openai_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, ChatMessage):
                openai_messages.append({"role": msg.role, "content": msg.content})
            else:
                raise ValueError(f"Unsupported message type: {type(msg)}")

        params = {
            "model": self.model_name,
            "messages": openai_messages,
            **self.model_kwargs,
            **kwargs,
            "extra_body": {
                "enable_enhanced_generation": True,
                **(kwargs.get("extra_body", {})),
                **(self.model_kwargs.get("extra_body", {})),
            },
        }
        params = {k: v for k, v in params.items() if v not in (None, {}, [])}

        # Create and process the stream
        for chunk in self.client.create(stream=True, **params):
            delta = chunk.choices[0].delta
            content = delta.content or ""
            reasoning = delta.model_extra.get("reasoning_content", "") if delta.model_extra else ""
            if content:
                yield ChatGenerationChunk(
                    message=AIMessageChunk(content=content),
                    generation_info={"reasoning_content": reasoning},
                )
            if reasoning:
                yield ChatGenerationChunk(
                    message=AIMessageChunk(
                        content="",
                        additional_kwargs={"reasoning_content": reasoning},
                    ),
                    generation_info={"reasoning_content": reasoning},
                )

    def invoke(
        self,
        messages: Any,
        stop: Optional[Any] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> AIMessage:

        async def _ainvoke():
            combined_content = []
            combined_reasoning = []
            async for chunk in self._astream(messages, stop, run_manager, **kwargs):
                if chunk.message.content:
                    combined_content.append(chunk.message.content)
                # If reasoning is in additional_kwargs, gather that too
                if "reasoning_content" in chunk.message.additional_kwargs:
                    combined_reasoning.append(
                        chunk.message.additional_kwargs["reasoning_content"]
                    )
            return AIMessage(
                content="".join(combined_content),
                additional_kwargs={"reasoning_content": "".join(combined_reasoning)} if combined_reasoning else {},
            )

        # Note: asyncio.run() fails inside an already-running event loop,
        # so this synchronous invoke() is only safe from non-async callers.
        return asyncio.run(_ainvoke())
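For reviewers, a minimal usage sketch of the new class (the endpoint, key, and model name below are placeholders, not values from this PR):

```python
# Hypothetical wiring; base URL, api key, and model name are placeholders.
llm = DeepseekChatOpenAI(
    model_name="deepseek-r1",
    openai_api_base="http://127.0.0.1:9997/v1",
    openai_api_key="EMPTY",
)

for chunk in llm._stream([HumanMessage(content="Why is the sky blue?")]):
    reasoning = chunk.message.additional_kwargs.get("reasoning_content", "")
    if reasoning:
        print(f"[thinking] {reasoning}", end="")  # deep-thinking tokens
    elif chunk.message.content:
        print(chunk.message.content, end="")      # answer tokens
```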
53 changes: 30 additions & 23 deletions libs/chatchat-server/chatchat/server/chat/kb_chat.py
@@ -122,9 +122,6 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
            ).model_dump_json()
            return

        callback = AsyncIteratorCallbackHandler()
        callbacks = [callback]

        # Enable langchain-chatchat to support langfuse
        import os
        langfuse_secret_key = os.environ.get('LANGFUSE_SECRET_KEY')
@@ -142,8 +139,7 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
        llm = get_ChatOpenAI(
            model_name=model,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=callbacks,
            max_tokens=max_tokens
        )
        # TODO: use the API as appropriate
        # # add reranker
@@ -171,12 +167,6 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]:

        chain = chat_prompt | llm

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.ainvoke({"context": context, "question": query}),
            callback.done),
        )

        if len(source_documents) == 0:  # no relevant documents found
            source_documents.append(f"<span style='color:red'>No relevant documents were found; this answer comes from the model's own knowledge!</span>")

@@ -191,20 +181,38 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
                docs=source_documents,
            )
            yield ret.model_dump_json()

            async for token in callback.aiter():
                ret = OpenAIChatOutput(
                    id=f"chat{uuid.uuid4()}",
                    object="chat.completion.chunk",
                    content=token,
                    role="assistant",
                    model=model,
                )

            async for chunk in chain.astream({"context": context, "question": query}):
                if chunk.additional_kwargs.get("reasoning_content"):
                    reasoning_token = chunk.additional_kwargs["reasoning_content"]
                    if reasoning_token:
                        ret = OpenAIChatOutput(
                            id=f"chat{uuid.uuid4()}",
                            object="chat.completion.chunk",
                            reasoning_content=reasoning_token,
                            role="assistant",
                            model=model,
                        )
                # Otherwise, treat it as an answer token
                else:
                    ret = OpenAIChatOutput(
                        id=f"chat{uuid.uuid4()}",
                        object="chat.completion.chunk",
                        content=chunk.content,
                        role="assistant",
                        model=model,
                    )
                yield ret.model_dump_json()
        else:
            answer = ""
            async for token in callback.aiter():
                answer += token
            async for chunk in chain.astream({"context": context, "question": query}):
                if chunk.additional_kwargs.get("reasoning_content"):
                    reasoning_token = chunk.additional_kwargs["reasoning_content"]
                    if reasoning_token:
                        answer += reasoning_token
                else:
                    answer += chunk.content

            ret = OpenAIChatOutput(
                id=f"chat{uuid.uuid4()}",
                object="chat.completion",
@@ -213,7 +221,6 @@ async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
                model=model,
            )
            yield ret.model_dump_json()
            await task
    except asyncio.exceptions.CancelledError:
        logger.warning("streaming progress has been interrupted by user.")
        return
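Downstream, the webui consumes this endpoint through the OpenAI client and branches on the new `reasoning_content` delta field. A rough sketch of that consumption pattern (base URL and model name are placeholders; it mirrors the webui changes later in this diff):

```python
import openai

# Placeholder address/model; the webui builds base_url via api_address().
client = openai.Client(base_url="http://127.0.0.1:7861/chat", api_key="NONE")
stream = client.chat.completions.create(
    model="deepseek-r1",
    messages=[{"role": "user", "content": "What is a knowledge base?"}],
    stream=True,
)
for d in stream:
    delta = d.choices[0].delta
    reasoning = getattr(delta, "reasoning_content", None)
    if reasoning:
        print(f"[thinking] {reasoning}", end="")  # shown in the "Deep Thinking" expander
    elif delta.content:
        print(delta.content, end="")              # shown as the final answer
```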
7 changes: 4 additions & 3 deletions libs/chatchat-server/chatchat/server/utils.py
@@ -26,6 +26,7 @@
from langchain.tools import BaseTool
from langchain_core.embeddings import Embeddings
from langchain_openai.chat_models import ChatOpenAI
from chatchat.server.chat.deepseek import DeepseekChatOpenAI
from langchain_openai.llms import OpenAI
from memoization import cached, CachingAlgorithmFlag

@@ -225,7 +226,7 @@ def get_ChatOpenAI(
    verbose: bool = True,
    local_wrap: bool = False,  # use local wrapped api
    **kwargs: Any,
) -> DeepseekChatOpenAI:
    model_info = get_model_info(model_name)
    params = dict(
        streaming=streaming,
@@ -253,7 +254,7 @@ def get_ChatOpenAI(
            openai_api_key=model_info.get("api_key"),
            openai_proxy=model_info.get("api_proxy"),
        )
        model = DeepseekChatOpenAI(**params)
    except Exception as e:
        logger.exception(f"failed to create ChatOpenAI for model: {model_name}.")
        model = None
@@ -817,7 +818,7 @@ def get_httpx_client(
        default_proxies.update(proxies)

    # construct Client
    # note: newer httpx releases removed the long-deprecated `proxies` argument,
    # which is presumably why it is no longer forwarded here
    kwargs.update(timeout=timeout)

    if use_async:
        return httpx.AsyncClient(**kwargs)
23 changes: 19 additions & 4 deletions libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py
@@ -395,6 +395,7 @@ def on_conv_change():

        chat_box.ai_say("Thinking...")
        text = ""
        reasoning_text = ""
        started = False

        client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE")
@@ -519,10 +520,24 @@ def on_conv_change():
                for img in d.tool_output.get("images", []):
                    chat_box.insert_msg(Image(f"{api.base_url}/media/{img}"), pos=-2)
            else:
                text += d.choices[0].delta.content or ""
                chat_box.update_msg(
                    text.replace("\n", "\n\n"), streaming=True, metadata=metadata
                )
                reasoning_content = getattr(d.choices[0].delta, "reasoning_content", None)
                if reasoning_content:
                    if reasoning_text == "":
                        chat_box.insert_msg(
                            Markdown("...", in_expander=True, title="Deep Thinking", state="running", expanded=True)
                        )
                    reasoning_text += reasoning_content
                    chat_box.update_msg(reasoning_text, streaming=True, state="running")
                    continue
                else:
                    content = getattr(d.choices[0].delta, "content", None)
                    if content:
                        if text == "" and reasoning_text != "":
                            # Once the final answer starts streaming, close the preceding deep-thinking expander
                            chat_box.update_msg(reasoning_text, streaming=False, state="complete")
                            chat_box.insert_msg("")
                        text += content
                        chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True)
        chat_box.update_msg(text, streaming=False, metadata=metadata)
    except Exception as e:
        st.error(e.body)
20 changes: 18 additions & 2 deletions libs/chatchat-server/chatchat/webui_pages/kb_chat.py
@@ -219,6 +219,7 @@ def on_conv_change():
    ])

    text = ""
    reasoning_text = ""
    first = True

    try:
@@ -228,8 +229,23 @@
chat_box.update_msg("", streaming=False)
first = False
continue
text += d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True)
reasoning_content = getattr(d.choices[0].delta, "reasoning_content", None)
if reasoning_content:
if reasoning_text=="":
chat_box.insert_msg(
Markdown("...", in_expander=True, title="深度思考", state="running", expanded=True)
)
reasoning_text += reasoning_content
chat_box.update_msg(reasoning_text, streaming=True, state="running")
continue
else:
content = getattr(d.choices[0].delta, "content", None)
if content:
if text=="" and reasoning_text!="":
chat_box.update_msg(reasoning_text, streaming=False, state="complete")
chat_box.insert_msg("")
text += content
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True)
        chat_box.update_msg(text, streaming=False)
        # TODO: searching raises an error when no API KEY is configured
    except Exception as e: