Skip to content

Commit

Permalink
Autofix fix malformed json output from llm models
Browse files Browse the repository at this point in the history
  • Loading branch information
CTY-git committed Nov 4, 2024
1 parent ef4fef5 commit ba0e76c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
2 changes: 2 additions & 0 deletions patchwork/steps/PR/PR.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def __handle_modified_code_files(self):
if self.inputs.get("modified_code_files") is not None:
return

self.inputs["modified_code_files"] = []

input_modified_files = self.inputs.get("modified_files")
if input_modified_files is None or len(input_modified_files) < 1:
return
Expand Down
19 changes: 13 additions & 6 deletions patchwork/steps/SimplifiedLLM/SimplifiedLLM.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from functools import partial

from json_repair import repair_json
from patchwork.common.client.llm.utils import example_json_to_schema
from patchwork.common.utils.utils import RetryData, exclude_none_dict, retry
from patchwork.logger import logger
Expand All @@ -13,11 +13,18 @@
from patchwork.steps.SimplifiedLLM.typed import SimplifiedLLMInputs


def json_loads(s: str) -> dict:
def json_loads(json_str: str) -> dict:
try:
return json.loads(json_str, strict=False)
except json.JSONDecodeError as e:
logger.debug(f"Json to decode: \n{json_str}\nError: \n{e}")

try:
return json.loads(s, strict=False)
except json.JSONDecodeError:
return dict()
json_str = repair_json(json_str, skip_json_loads=True)
return json.loads(json_str, strict=False)
except json.JSONDecodeError as e:
logger.debug(f"Json to decode: \n{json_str}\nError: \n{e}")
raise e


class SimplifiedLLM(Step):
Expand Down Expand Up @@ -49,7 +56,7 @@ def __retry_unit(self, prepare_prompt_outputs, call_llm_inputs, retry_data: Retr
json_responses = []
for response in call_llm_outputs.get("openai_responses"):
try:
json_response = json.loads(response, strict=False)
json_response = json_loads(response)
json_responses.append(json_response)
except json.JSONDecodeError as e:
logger.error(f"Json to decode: \n{response}\nError: \n{e}")
Expand Down

0 comments on commit ba0e76c

Please sign in to comment.