Skip to content

Commit

Permalink
Merge pull request #362 from yongman/add-wss-subprotocol
Browse files Browse the repository at this point in the history
wss: make it more reliable by using subprotocols with ack
  • Loading branch information
moeakwak authored Feb 3, 2024
2 parents b9aa278 + 624cb3b commit ecb4047
Showing 1 changed file with 33 additions and 25 deletions.
58 changes: 33 additions & 25 deletions backend/api/sources/openai_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,33 +189,41 @@ def make_session() -> httpx.AsyncClient:


async def _receive_from_websocket(wss_url):
async with websockets.connect(wss_url) as websocket:
recv_msg_count = 0
async with websockets.connect(wss_url, subprotocols=["json.reliable.webpubsub.azure.v1"]) as websocket:
logger.debug(f"Connected to Websocket {wss_url[:65]}...{wss_url[-10:]}")
try:
while True:
message = await websocket.recv()
message = json.loads(message)
line_data = base64.b64decode(message['body']).decode('utf-8')
if not line_data or line_data is None:
continue
if "data: " in line_data:
line_data = line_data[6:]
if "[DONE]" in line_data:
break
try:
line_data = json.loads(line_data)
except json.decoder.JSONDecodeError:
while True:
message = await websocket.recv()
message = json.loads(message)
if "data" not in message:
continue
sequence_id = message["sequenceId"]
data = base64.b64decode(message['data']['body']).decode('utf-8')
if not data or data is None:
continue
if "data: " in data:
data = data[6:]
if "[DONE]" in data:
# send ack to server
await websocket.send(json.dumps({"type": "sequenceAck", "sequenceId": sequence_id}))
break
try:
data = json.loads(data)
except json.decoder.JSONDecodeError:
continue
if not _check_fields(data):
if "error" in data:
raise OpenaiWebException(data["error"])
else:
logger.warning(f"Field missing. Details: {str(data)}")
continue
if not _check_fields(line_data):
if "error" in line_data:
raise OpenaiWebException(line_data["error"])
else:
logger.warning(f"Field missing. Details: {str(line_data)}")
continue
yield line_data
except websockets.exceptions.ConnectionClosedError:
logger.debug("Connection closed.")
pass
recv_msg_count += 1
# batch ack to server every 10 messages
if recv_msg_count > 10:
await websocket.send(json.dumps({"type": "sequenceAck", "sequenceId": sequence_id}))
recv_msg_count = 0
yield data
logger.debug("Connection closed.")


class OpenaiWebChatManager(metaclass=SingletonMeta):
Expand Down

0 comments on commit ecb4047

Please sign in to comment.