Skip to content

Commit

Permalink
fix: Add retry limits and improve error handling for JSON extraction
Browse files — browse the repository at this point in the history
- Add MAX_JSON_RETRIES constant to limit retries
- Enhance error messages in extract_json_between_markers
- Implement retry limits in generate_ideas.py and perform_writeup.py
- Prevent infinite loops when JSON extraction fails

Fixes SakanaAI#154

Co-Authored-By: Erkin Alp Güney <[email protected]>
Authored by devin-ai-integration[bot]; committed with erkinalp on Dec 18, 2024
1 parent c19f0f8 commit 6909fdd
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 58 deletions.
111 changes: 60 additions & 51 deletions ai_scientist/generate_ideas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import backoff
import requests

from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS, MAX_JSON_RETRIES

S2_API_KEY = os.getenv("S2_API_KEY")

Expand Down Expand Up @@ -112,56 +112,65 @@ def generate_ideas(
for _ in range(max_num_generations):
print()
print(f"Generating idea {_ + 1}/{max_num_generations}")
try:
prev_ideas_string = "\n\n".join(idea_str_archive)

msg_history = []
print(f"Iteration 1/{num_reflections}")
text, msg_history = get_response_from_llm(
idea_first_prompt.format(
task_description=prompt["task_description"],
code=code,
prev_ideas_string=prev_ideas_string,
num_reflections=num_reflections,
),
client=client,
model=model,
system_message=idea_system_prompt,
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"
print(json_output)

# Iteratively improve task.
if num_reflections > 1:
for j in range(num_reflections - 1):
print(f"Iteration {j + 2}/{num_reflections}")
text, msg_history = get_response_from_llm(
idea_reflection_prompt.format(
current_round=j + 2, num_reflections=num_reflections
),
client=client,
model=model,
system_message=idea_system_prompt,
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert (
json_output is not None
), "Failed to extract JSON from LLM output"
print(json_output)

if "I am done" in text:
print(f"Idea generation converged after {j + 2} iterations.")
break

idea_str_archive.append(json.dumps(json_output))
except Exception as e:
print(f"Failed to generate idea: {e}")
continue
retry_count = 0
while retry_count < MAX_JSON_RETRIES:
try:
prev_ideas_string = "\n\n".join(idea_str_archive)

msg_history = []
print(f"Iteration 1/{num_reflections} (Attempt {retry_count + 1}/{MAX_JSON_RETRIES})")
text, msg_history = get_response_from_llm(
idea_first_prompt.format(
task_description=prompt["task_description"],
code=code,
prev_ideas_string=prev_ideas_string,
num_reflections=num_reflections,
),
client=client,
model=model,
system_message=idea_system_prompt,
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
if json_output is None:
retry_count += 1
continue
print(json_output)

# Iteratively improve task.
if num_reflections > 1:
for j in range(num_reflections - 1):
print(f"Iteration {j + 2}/{num_reflections}")
text, msg_history = get_response_from_llm(
idea_reflection_prompt.format(
current_round=j + 2, num_reflections=num_reflections
),
client=client,
model=model,
system_message=idea_system_prompt,
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
if json_output is None:
retry_count += 1
continue
print(json_output)

if "I am done" in text:
print(f"Idea generation converged after {j + 2} iterations.")
break

idea_str_archive.append(json.dumps(json_output))
break
except Exception as e:
print(f"Failed to generate idea: {e}")
retry_count += 1
if retry_count >= MAX_JSON_RETRIES:
print(f"Max retries ({MAX_JSON_RETRIES}) reached, skipping idea")
break
continue

## SAVE IDEAS
ideas = []
Expand Down
9 changes: 7 additions & 2 deletions ai_scientist/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import openai

MAX_NUM_TOKENS = 4096
MAX_JSON_RETRIES = 3

AVAILABLE_LLMS = [
"claude-3-5-sonnet-20240620",
Expand Down Expand Up @@ -272,7 +273,10 @@ def extract_json_between_markers(llm_output):
try:
parsed_json = json.loads(json_string)
return parsed_json
except json.JSONDecodeError:
except json.JSONDecodeError as e:
# Provide detailed error message
error_msg = f"JSON parse error: {str(e)}\nContent: {json_string[:100]}..."
print(error_msg)
# Attempt to fix common JSON issues
try:
# Remove invalid control characters
Expand All @@ -282,7 +286,8 @@ def extract_json_between_markers(llm_output):
except json.JSONDecodeError:
continue # Try next match

return None # No valid JSON found
print("No valid JSON found in LLM output")
return None


def create_client(model):
Expand Down
22 changes: 17 additions & 5 deletions ai_scientist/perform_writeup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Optional, Tuple

from ai_scientist.generate_ideas import search_for_papers
from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS, MAX_JSON_RETRIES


# GENERATE LATEX
Expand Down Expand Up @@ -312,8 +312,14 @@ def get_citation_aider_prompt(
return None, True

## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"
retry_count = 0
while retry_count < MAX_JSON_RETRIES:
json_output = extract_json_between_markers(text)
if json_output is not None:
break
retry_count += 1
if retry_count >= MAX_JSON_RETRIES:
raise ValueError("Failed to extract JSON after max retries")
query = json_output["Query"]
papers = search_for_papers(query)
except Exception as e:
Expand Down Expand Up @@ -354,8 +360,14 @@ def get_citation_aider_prompt(
print("Do not add any.")
return None, False
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"
retry_count = 0
while retry_count < MAX_JSON_RETRIES:
json_output = extract_json_between_markers(text)
if json_output is not None:
break
retry_count += 1
if retry_count >= MAX_JSON_RETRIES:
raise ValueError("Failed to extract JSON after max retries")
desc = json_output["Description"]
selected_papers = json_output["Selected"]
selected_papers = str(selected_papers)
Expand Down

0 comments on commit 6909fdd

Please sign in to comment.