Commit 5f615cd

use json repair for llm client

zhuzhongshu123 committed Jan 21, 2025
1 parent cdf0ea3 commit 5f615cd

Showing 5 changed files with 18 additions and 109 deletions.
28 changes: 0 additions & 28 deletions kag/common/llm/ollama_client.py
@@ -10,13 +10,10 @@
 # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 # or implied.

-import json
-
 import logging
 from ollama import Client

 from kag.interface import LLMClient
-from tenacity import retry, stop_after_attempt


 # logging.basicConfig(level=logging.DEBUG)
@@ -83,28 +80,3 @@ def __call__(self, prompt, image=None):
         """

         return self.sync_request(prompt, image)
-
-    @retry(stop=stop_after_attempt(3))
-    def call_with_json_parse(self, prompt):
-        """
-        Calls the model and attempts to parse the response into JSON format.
-        Parameters:
-            prompt (str): The prompt provided to the model.
-        Returns:
-            Union[dict, str]: If the response is valid JSON, returns the parsed dictionary; otherwise, returns the original response.
-        """
-
-        rsp = self.sync_request(prompt)
-        _end = rsp.rfind("```")
-        _start = rsp.find("```json")
-        if _end != -1 and _start != -1:
-            json_str = rsp[_start + len("```json") : _end].strip()
-        else:
-            json_str = rsp
-        try:
-            json_result = json.loads(json_str)
-        except:
-            return rsp
-        return json_result
63 changes: 11 additions & 52 deletions kag/common/llm/openai_client.py
@@ -11,12 +11,10 @@
 # or implied.


-import json
 from openai import OpenAI, AzureOpenAI
 import logging

 from kag.interface import LLMClient
-from tenacity import retry, stop_after_attempt
 from typing import Callable

 logging.getLogger("openai").setLevel(logging.ERROR)
@@ -25,9 +23,9 @@

 AzureADTokenProvider = Callable[[], str]

+
 @LLMClient.register("maas")
 @LLMClient.register("openai")
-
 class OpenAIClient(LLMClient):
     """
     A client class for interacting with the OpenAI API.
@@ -113,32 +111,9 @@ def __call__(self, prompt: str, image_url: str = None):
         rsp = response.choices[0].message.content
         return rsp

-    @retry(stop=stop_after_attempt(3))
-    def call_with_json_parse(self, prompt):
-        """
-        Calls the model and attempts to parse the response into JSON format.
-        Parameters:
-            prompt (str): The prompt provided to the model.
-
-        Returns:
-            Union[dict, str]: If the response is valid JSON, returns the parsed dictionary; otherwise, returns the original response.
-        """
-        # Call the model and attempt to parse the response into JSON format
-        rsp = self(prompt)
-        _end = rsp.rfind("```")
-        _start = rsp.find("```json")
-        if _end != -1 and _start != -1:
-            json_str = rsp[_start + len("```json") : _end].strip()
-        else:
-            json_str = rsp
-        try:
-            json_result = json.loads(json_str)
-        except:
-            return rsp
-        return json_result

 @LLMClient.register("azure_openai")
-class AzureOpenAIClient (LLMClient):
+class AzureOpenAIClient(LLMClient):
     def __init__(
         self,
         api_key: str,
@@ -180,7 +155,15 @@ def __init__(
         self.api_version = api_version
         self.azure_ad_token = azure_ad_token
         self.azure_ad_token_provider = azure_ad_token_provider
-        self.client = AzureOpenAI(api_key=self.api_key, base_url=self.base_url,azure_deployment=self.azure_deployment ,model=self.model,api_version=self.api_version, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider)
+        self.client = AzureOpenAI(
+            api_key=self.api_key,
+            base_url=self.base_url,
+            azure_deployment=self.azure_deployment,
+            model=self.model,
+            api_version=self.api_version,
+            azure_ad_token=self.azure_ad_token,
+            azure_ad_token_provider=self.azure_ad_token_provider,
+        )
         self.check()

     def __call__(self, prompt: str, image_url: str = None):
@@ -229,27 +212,3 @@ def __call__(self, prompt: str, image_url: str = None):
         )
         rsp = response.choices[0].message.content
         return rsp
-
-    @retry(stop=stop_after_attempt(3))
-    def call_with_json_parse(self, prompt):
-        """
-        Calls the model and attempts to parse the response into JSON format.
-        Parameters:
-            prompt (str): The prompt provided to the model.
-        Returns:
-            Union[dict, str]: If the response is valid JSON, returns the parsed dictionary; otherwise, returns the original response.
-        """
-        # Call the model and attempt to parse the response into JSON format
-        rsp = self(prompt)
-        _end = rsp.rfind("```")
-        _start = rsp.find("```json")
-        if _end != -1 and _start != -1:
-            json_str = rsp[_start + len("```json") : _end].strip()
-        else:
-            json_str = rsp
-        try:
-            json_result = json.loads(json_str)
-        except:
-            return rsp
-        return json_result
27 changes: 0 additions & 27 deletions kag/common/llm/vllm_client.py
@@ -15,7 +15,6 @@
 import logging
 import requests
 from kag.interface import LLMClient
-from tenacity import retry, stop_after_attempt


 # logging.basicConfig(level=logging.DEBUG)
@@ -89,29 +88,3 @@ def __call__(self, prompt):

         content = [{"role": "user", "content": prompt}]
         return self.sync_request(content)
-
-    @retry(stop=stop_after_attempt(3))
-    def call_with_json_parse(self, prompt):
-        """
-        Calls the model and attempts to parse the response into JSON format.
-        Parameters:
-            prompt (str): The prompt provided to the model.
-        Returns:
-            Union[dict, str]: If the response is valid JSON, returns the parsed dictionary; otherwise, returns the original response.
-        """
-
-        content = [{"role": "user", "content": prompt}]
-        rsp = self.sync_request(content)
-        _end = rsp.rfind("```")
-        _start = rsp.find("```json")
-        if _end != -1 and _start != -1:
-            json_str = rsp[_start + len("```json") : _end].strip()
-        else:
-            json_str = rsp
-        try:
-            json_result = json.loads(json_str)
-        except:
-            return rsp
-        return json_result
8 changes: 6 additions & 2 deletions kag/interface/common/llm_client.py
@@ -10,7 +10,10 @@
 # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 # or implied.

-import json
+try:
+    from json_repair import loads
+except:
+    from json import loads
 from typing import Union, Dict, List, Any
 import logging
 import traceback
@@ -67,7 +70,7 @@ def call_with_json_parse(self, prompt: Union[str, dict, list]):
         else:
             json_str = res
         try:
-            json_result = json.loads(json_str)
+            json_result = loads(json_str)
         except:
             return res
         return json_result
@@ -108,6 +111,7 @@ def invoke(
             logger.debug(f"Result: {result}")
         except Exception as e:
             import traceback
+
             logger.info(f"Error {e} during invocation: {traceback.format_exc()}")
         if with_except:
             raise RuntimeError(
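The substance of this commit is the llm_client.py change above: call_with_json_parse now parses model output with json_repair.loads instead of json.loads, so almost-valid JSON (single quotes, trailing commas, unclosed braces) is recovered rather than silently falling back to the raw string. Below is a minimal standalone sketch of that behavior, not part of the diff, assuming json_repair is installed as added in requirements.txt; the response string is made up for illustration.

from json_repair import loads  # the same `loads` now imported in llm_client.py

raw = """Here is the result:
```json
{'name': 'KAG', 'tags': ['llm', 'kg',], }
```"""

# Pull out the ```json fenced block, as call_with_json_parse does.
_start = raw.find("```json")
_end = raw.rfind("```")
if _start != -1 and _end != -1:
    json_str = raw[_start + len("```json") : _end].strip()
else:
    json_str = raw

# json.loads would reject the single quotes and trailing commas here;
# json_repair.loads is expected to return a plain dict instead.
print(loads(json_str))  # e.g. {'name': 'KAG', 'tags': ['llm', 'kg']}

With the stdlib parser the same input would raise json.JSONDecodeError and call_with_json_parse would return the unparsed response, which is the failure mode the removed per-client implementations shared.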
1 change: 1 addition & 0 deletions requirements.txt
@@ -45,3 +45,4 @@ zodb
 matplotlib
 PyPDF2
 ruamel.yaml
+json_repair
