From 5f615cdcf183659359010d6f4ab897507f3fb183 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=92=9F=E4=B9=A6?= Date: Tue, 21 Jan 2025 10:57:52 +0800 Subject: [PATCH] use json repair for llm client --- kag/common/llm/ollama_client.py | 28 ------------- kag/common/llm/openai_client.py | 63 ++++++------------------------ kag/common/llm/vllm_client.py | 27 ------------- kag/interface/common/llm_client.py | 8 +++- requirements.txt | 1 + 5 files changed, 18 insertions(+), 109 deletions(-) diff --git a/kag/common/llm/ollama_client.py b/kag/common/llm/ollama_client.py index 7791e792..6d77346c 100644 --- a/kag/common/llm/ollama_client.py +++ b/kag/common/llm/ollama_client.py @@ -10,13 +10,10 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. -import json - import logging from ollama import Client from kag.interface import LLMClient -from tenacity import retry, stop_after_attempt # logging.basicConfig(level=logging.DEBUG) @@ -83,28 +80,3 @@ def __call__(self, prompt, image=None): """ return self.sync_request(prompt, image) - - @retry(stop=stop_after_attempt(3)) - def call_with_json_parse(self, prompt): - """ - Calls the model and attempts to parse the response into JSON format. - - Parameters: - prompt (str): The prompt provided to the model. - - Returns: - Union[dict, str]: If the response is valid JSON, returns the parsed dictionary; otherwise, returns the original response. - """ - - rsp = self.sync_request(prompt) - _end = rsp.rfind("```") - _start = rsp.find("```json") - if _end != -1 and _start != -1: - json_str = rsp[_start + len("```json") : _end].strip() - else: - json_str = rsp - try: - json_result = json.loads(json_str) - except: - return rsp - return json_result diff --git a/kag/common/llm/openai_client.py b/kag/common/llm/openai_client.py index e4af7e2e..e8bc537b 100644 --- a/kag/common/llm/openai_client.py +++ b/kag/common/llm/openai_client.py @@ -11,12 +11,10 @@ # or implied. -import json from openai import OpenAI, AzureOpenAI import logging from kag.interface import LLMClient -from tenacity import retry, stop_after_attempt from typing import Callable logging.getLogger("openai").setLevel(logging.ERROR) @@ -25,9 +23,9 @@ AzureADTokenProvider = Callable[[], str] + @LLMClient.register("maas") @LLMClient.register("openai") - class OpenAIClient(LLMClient): """ A client class for interacting with the OpenAI API. @@ -113,32 +111,9 @@ def __call__(self, prompt: str, image_url: str = None): rsp = response.choices[0].message.content return rsp - @retry(stop=stop_after_attempt(3)) - def call_with_json_parse(self, prompt): - """ - Calls the model and attempts to parse the response into JSON format. - - Parameters: - prompt (str): The prompt provided to the model. - Returns: - Union[dict, str]: If the response is valid JSON, returns the parsed dictionary; otherwise, returns the original response. - """ - # Call the model and attempt to parse the response into JSON format - rsp = self(prompt) - _end = rsp.rfind("```") - _start = rsp.find("```json") - if _end != -1 and _start != -1: - json_str = rsp[_start + len("```json") : _end].strip() - else: - json_str = rsp - try: - json_result = json.loads(json_str) - except: - return rsp - return json_result @LLMClient.register("azure_openai") -class AzureOpenAIClient (LLMClient): +class AzureOpenAIClient(LLMClient): def __init__( self, api_key: str, @@ -180,7 +155,15 @@ def __init__( self.api_version = api_version self.azure_ad_token = azure_ad_token self.azure_ad_token_provider = azure_ad_token_provider - self.client = AzureOpenAI(api_key=self.api_key, base_url=self.base_url,azure_deployment=self.azure_deployment ,model=self.model,api_version=self.api_version, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider) + self.client = AzureOpenAI( + api_key=self.api_key, + base_url=self.base_url, + azure_deployment=self.azure_deployment, + model=self.model, + api_version=self.api_version, + azure_ad_token=self.azure_ad_token, + azure_ad_token_provider=self.azure_ad_token_provider, + ) self.check() def __call__(self, prompt: str, image_url: str = None): @@ -229,27 +212,3 @@ def __call__(self, prompt: str, image_url: str = None): ) rsp = response.choices[0].message.content return rsp - @retry(stop=stop_after_attempt(3)) - def call_with_json_parse(self, prompt): - """ - Calls the model and attempts to parse the response into JSON format. - - Parameters: - prompt (str): The prompt provided to the model. - - Returns: - Union[dict, str]: If the response is valid JSON, returns the parsed dictionary; otherwise, returns the original response. - """ - # Call the model and attempt to parse the response into JSON format - rsp = self(prompt) - _end = rsp.rfind("```") - _start = rsp.find("```json") - if _end != -1 and _start != -1: - json_str = rsp[_start + len("```json") : _end].strip() - else: - json_str = rsp - try: - json_result = json.loads(json_str) - except: - return rsp - return json_result \ No newline at end of file diff --git a/kag/common/llm/vllm_client.py b/kag/common/llm/vllm_client.py index 47a37b19..36868c5c 100644 --- a/kag/common/llm/vllm_client.py +++ b/kag/common/llm/vllm_client.py @@ -15,7 +15,6 @@ import logging import requests from kag.interface import LLMClient -from tenacity import retry, stop_after_attempt # logging.basicConfig(level=logging.DEBUG) @@ -89,29 +88,3 @@ def __call__(self, prompt): content = [{"role": "user", "content": prompt}] return self.sync_request(content) - - @retry(stop=stop_after_attempt(3)) - def call_with_json_parse(self, prompt): - """ - Calls the model and attempts to parse the response into JSON format. - - Parameters: - prompt (str): The prompt provided to the model. - - Returns: - Union[dict, str]: If the response is valid JSON, returns the parsed dictionary; otherwise, returns the original response. - """ - - content = [{"role": "user", "content": prompt}] - rsp = self.sync_request(content) - _end = rsp.rfind("```") - _start = rsp.find("```json") - if _end != -1 and _start != -1: - json_str = rsp[_start + len("```json") : _end].strip() - else: - json_str = rsp - try: - json_result = json.loads(json_str) - except: - return rsp - return json_result diff --git a/kag/interface/common/llm_client.py b/kag/interface/common/llm_client.py index aba82756..4a644b07 100644 --- a/kag/interface/common/llm_client.py +++ b/kag/interface/common/llm_client.py @@ -10,7 +10,10 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. -import json +try: + from json_repair import loads +except: + from json import loads from typing import Union, Dict, List, Any import logging import traceback @@ -67,7 +70,7 @@ def call_with_json_parse(self, prompt: Union[str, dict, list]): else: json_str = res try: - json_result = json.loads(json_str) + json_result = loads(json_str) except: return res return json_result @@ -108,6 +111,7 @@ def invoke( logger.debug(f"Result: {result}") except Exception as e: import traceback + logger.info(f"Error {e} during invocation: {traceback.format_exc()}") if with_except: raise RuntimeError( diff --git a/requirements.txt b/requirements.txt index 7c72d036..ffb90fe2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,3 +45,4 @@ zodb matplotlib PyPDF2 ruamel.yaml +json_repair