Skip to content

Commit

Permalink
Update qianfan_llm_load.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-TG001 committed Oct 29, 2024
1 parent dda0f14 commit 6fbfa32
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions python/qianfan/dataset/stress_test/qianfan_llm_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,16 @@ def _remove_access_token_url_parameter(url: str) -> str:

class _InnerResponseProcessRet:
def __init__(
self, request_meta: Dict, last_resp: Optional[QfResponse], merged_result: str
self,
request_meta: Dict,
last_resp: Optional[QfResponse],
merged_result: str,
res_choices: Dict,
):
self.request_meta = request_meta
self.last_resp = last_resp
self.merged_result = merged_result
self.res_choices = res_choices


class QianfanCustomHttpSession(CustomHttpSession):
Expand Down Expand Up @@ -281,6 +286,8 @@ def _request_internal(
res["request"]["url"] = _remove_access_token_url_parameter(
res["request"]["url"]
)
if res.get("body", {}).get("choices", None) is not None:
res["body"]["choices"] = processed_resp.res_choices

self._write_result(res)

Expand Down Expand Up @@ -439,6 +446,7 @@ def _process_responses(
) -> _InnerResponseProcessRet:
last_resp: Optional[QfResponse] = None
merged_query = ""
res_choices = {}
first_flag, all_empty = True, True
clear_history = False

Expand Down Expand Up @@ -497,16 +505,24 @@ def _process_responses(

if len(resp.body["choices"]) == 0:
break

stream_json = resp.body["choices"][0]
clear_history = stream_json.get("need_clear_history", False)
index = stream_json.get("index", "")
if "delta" in stream_json:
content = stream_json["delta"].get("content", "")
merged_query += content
if index not in res_choices:
if "delta" in stream_json:
res_choices[index] = {
"index": index,
"is_truncated": stream_json["is_truncated"],
"content": stream_json["delta"]["content"],
"need_clear_history": stream_json["need_clear_history"],
}
else:
res_choices[index]["content"] += stream_json["delta"]["content"]
else:
self.exc = Exception("ERROR CODE 结果无法解析")
break

if len(content) != 0:
all_empty = False

Expand All @@ -521,8 +537,9 @@ def _process_responses(
not self.is_v2 and not last_resp["body"]["is_end"]
):
self.exc = Exception("NOT 200 OR is_end is False")

return _InnerResponseProcessRet(request_meta, last_resp, merged_query)
return _InnerResponseProcessRet(
request_meta, last_resp, merged_query, res_choices
)

def _get_request(self, context: Dict, **kwargs: Any) -> Iterator[QfResponse]:
if "messages" in kwargs:
Expand Down Expand Up @@ -674,7 +691,7 @@ def _process_responses(
)
break

return _InnerResponseProcessRet(request_meta, last_resp, merged_query)
return _InnerResponseProcessRet(request_meta, last_resp, merged_query, {})

def _get_request(self, context: Dict, **kwargs: Any) -> Iterator[QfResponse]:
if "prompt" in kwargs:
Expand Down

0 comments on commit 6fbfa32

Please sign in to comment.