diff --git a/backend/api/routers/chat.py b/backend/api/routers/chat.py
index b514ef69..4239d7b7 100644
--- a/backend/api/routers/chat.py
+++ b/backend/api/routers/chat.py
@@ -566,6 +566,8 @@ async def reply(response: AskResponse):
                 create_time=current_time,
                 update_time=current_time
             )
+            if use_team:
+                new_conv.source_id = config.openai_web.team_account_id
             conversation = BaseConversation(**new_conv.model_dump(exclude_unset=True))
             session.add(conversation)
 
diff --git a/backend/api/sources/openai_web.py b/backend/api/sources/openai_web.py
index eb847578..21f6dcdc 100644
--- a/backend/api/sources/openai_web.py
+++ b/backend/api/sources/openai_web.py
@@ -2,7 +2,6 @@
 import json
 import uuid
 from mimetypes import guess_type
-from typing import AsyncGenerator
 
 import websockets
 import base64
@@ -11,10 +10,11 @@
 import httpx
 from fastapi.encoders import jsonable_encoder
 import aiohttp
+from httpx import AsyncClient
 from pydantic import ValidationError
 
 from api.conf import Config, Credentials
-from api.enums import OpenaiWebChatModels, ChatSourceTypes
+from api.enums import OpenaiWebChatModels
 from api.exceptions import InvalidParamsException, OpenaiWebException, ResourceNotFoundException
 from api.file_provider import FileProvider
 from api.models.doc import OpenaiWebChatMessageMetadata, OpenaiWebConversationHistoryDocument, \
@@ -229,7 +229,7 @@ async def _receive_from_websocket(wss_url):
 class OpenaiWebChatManager(metaclass=SingletonMeta):
     def __init__(self):
         self.semaphore = asyncio.Semaphore(1)
-        self.session = None
+        self.session: AsyncClient | None = None
         self.reset_session()
 
     def is_busy(self):
@@ -306,7 +306,8 @@ async def clear_conversations(self, use_team: bool = False):
         response = await self.session.patch(url, json={"is_visible": False}, headers=req_headers(use_team))
         await _check_response(response)
 
-    async def complete(self, model: OpenaiWebChatModels, text_content: str, use_team: bool, conversation_id: uuid.UUID = None,
+    async def complete(self, model: OpenaiWebChatModels, text_content: str, use_team: bool,
+                       conversation_id: uuid.UUID = None,
                        parent_message_id: uuid.UUID = None, plugin_ids: list[str] = None,
                        attachments: list[OpenaiWebChatMessageMetadataAttachment] = None,
@@ -369,15 +370,12 @@ async def complete(self, model: OpenaiWebChatModels, text_content: str, use_team
             completion_request["arkose_token"] = None
         data_json = json.dumps(jsonable_encoder(completion_request))
 
-        async with self.session.stream(
-                method="POST",
-                url=f"{config.openai_web.chatgpt_base_url}conversation",
-                data=data_json,
-                timeout=timeout,
-                headers=req_headers(use_team) | {
-                    "referer": "https://chat.openai.com/" + (f"c/{conversation_id}" if conversation_id else "")
-                }
-        ) as response:
+        async with self.session.stream(method="POST", url=f"{config.openai_web.chatgpt_base_url}conversation",
+                                       data=data_json, timeout=timeout,
+                                       headers=req_headers(use_team) | {
+                                           "referer": "https://chat.openai.com/" + (
+                                               f"c/{conversation_id}" if conversation_id else "")
+                                       }) as response:
             await _check_response(response)
 
             async for line in response.aiter_lines():
@@ -428,7 +426,7 @@ async def generate_conversation_title(self, conversation_id: str, message_id: st
         response = await self.session.post(
             url,
             json={"message_id": message_id},
-            hearders=req_headers(use_team)
+            headers=req_headers(use_team)
         )
         await _check_response(response)
         result = response.json()