diff --git a/ai_scientist/generate_ideas.py b/ai_scientist/generate_ideas.py index d5f4f97..cf6179b 100644 --- a/ai_scientist/generate_ideas.py +++ b/ai_scientist/generate_ideas.py @@ -130,9 +130,18 @@ def generate_ideas( msg_history=msg_history, ) ## PARSE OUTPUT - json_output = extract_json_between_markers(text) - assert json_output is not None, "Failed to extract JSON from LLM output" - print(json_output) + try: + json_output = extract_json_between_markers(text) + if json_output is None: + print("Failed to extract JSON from LLM output") + continue + print(json_output) + except ValueError as e: + print(f"Error extracting JSON: {e}") + continue + except Exception as e: + print(f"Unexpected error while extracting JSON: {e}") + continue # Iteratively improve task. if num_reflections > 1: @@ -148,11 +157,18 @@ def generate_ideas( msg_history=msg_history, ) ## PARSE OUTPUT - json_output = extract_json_between_markers(text) - assert ( - json_output is not None - ), "Failed to extract JSON from LLM output" - print(json_output) + try: + json_output = extract_json_between_markers(text) + if json_output is None: + print("Failed to extract JSON from LLM output") + continue + print(json_output) + except ValueError as e: + print(f"Error extracting JSON: {e}") + continue + except Exception as e: + print(f"Unexpected error while extracting JSON: {e}") + continue if "I am done" in text: print(f"Idea generation converged after {j + 2} iterations.") @@ -229,9 +245,18 @@ def generate_next_idea( msg_history=msg_history, ) ## PARSE OUTPUT - json_output = extract_json_between_markers(text) - assert json_output is not None, "Failed to extract JSON from LLM output" - print(json_output) + try: + json_output = extract_json_between_markers(text) + if json_output is None: + print("Failed to extract JSON from LLM output") + continue + print(json_output) + except ValueError as e: + print(f"Error extracting JSON: {e}") + continue + except Exception as e: + print(f"Unexpected error while extracting JSON: {e}") + continue # Iteratively improve task. if num_reflections > 1: @@ -247,11 +272,18 @@ def generate_next_idea( msg_history=msg_history, ) ## PARSE OUTPUT - json_output = extract_json_between_markers(text) - assert ( - json_output is not None - ), "Failed to extract JSON from LLM output" - print(json_output) + try: + json_output = extract_json_between_markers(text) + if json_output is None: + print("Failed to extract JSON from LLM output") + continue + print(json_output) + except ValueError as e: + print(f"Error extracting JSON: {e}") + continue + except Exception as e: + print(f"Unexpected error while extracting JSON: {e}") + continue if "I am done" in text: print( @@ -409,29 +441,42 @@ def check_idea_novelty( break ## PARSE OUTPUT - json_output = extract_json_between_markers(text) - assert json_output is not None, "Failed to extract JSON from LLM output" - - ## SEARCH FOR PAPERS - query = json_output["Query"] - papers = search_for_papers(query, result_limit=10) - if papers is None: - papers_str = "No papers found." - - paper_strings = [] - for i, paper in enumerate(papers): - paper_strings.append( - """{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format( - i=i, - title=paper["title"], - authors=paper["authors"], - venue=paper["venue"], - year=paper["year"], - cites=paper["citationCount"], - abstract=paper["abstract"], + try: + json_output = extract_json_between_markers(text) + if json_output is None: + print("Failed to extract JSON from LLM output") + continue + + ## SEARCH FOR PAPERS + query = json_output["Query"] + papers = search_for_papers(query, result_limit=10) + if papers is None: + papers_str = "No papers found." + + paper_strings = [] + for i, paper in enumerate(papers): + paper_strings.append( + """{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format( + i=i, + title=paper["title"], + authors=paper["authors"], + venue=paper["venue"], + year=paper["year"], + cites=paper["citationCount"], + abstract=paper["abstract"], + ) ) - ) - papers_str = "\n\n".join(paper_strings) + papers_str = "\n\n".join(paper_strings) + + except ValueError as e: + print(f"Error extracting JSON: {e}") + continue + except KeyError as e: + print(f"Missing required field in JSON: {e}") + continue + except Exception as e: + print(f"Unexpected error while extracting JSON: {e}") + continue except Exception as e: print(f"Error: {e}") diff --git a/ai_scientist/perform_writeup.py b/ai_scientist/perform_writeup.py index 7dc9eeb..d0ddee9 100644 --- a/ai_scientist/perform_writeup.py +++ b/ai_scientist/perform_writeup.py @@ -312,10 +312,20 @@ def get_citation_aider_prompt( return None, True ## PARSE OUTPUT - json_output = extract_json_between_markers(text) - assert json_output is not None, "Failed to extract JSON from LLM output" - query = json_output["Query"] - papers = search_for_papers(query) + try: + json_output = extract_json_between_markers(text) + if json_output is None: + print("Failed to extract JSON from LLM output") + return None, False + query = json_output["Query"] + papers = search_for_papers(query) + except ValueError as e: + print(f"Error extracting JSON: {e}") + return None, False + except KeyError as e: + print(f"Missing required field in JSON: {e}") + return None, False + except Exception as e: print(f"Error: {e}") return None, False @@ -354,21 +364,31 @@ def get_citation_aider_prompt( print("Do not add any.") return None, False ## PARSE OUTPUT - json_output = extract_json_between_markers(text) - assert json_output is not None, "Failed to extract JSON from LLM output" - desc = json_output["Description"] - selected_papers = json_output["Selected"] - selected_papers = str(selected_papers) - - # convert to list - if selected_papers != "[]": - selected_papers = list(map(int, selected_papers.strip("[]").split(","))) - assert all( - [0 <= i < len(papers) for i in selected_papers] - ), "Invalid paper index" - bibtexs = [papers[i]["citationStyles"]["bibtex"] for i in selected_papers] - bibtex_string = "\n".join(bibtexs) - else: + try: + json_output = extract_json_between_markers(text) + if json_output is None: + print("Failed to extract JSON from LLM output") + return None, False + desc = json_output["Description"] + selected_papers = json_output["Selected"] + selected_papers = str(selected_papers) + + # convert to list + if selected_papers != "[]": + selected_papers = list(map(int, selected_papers.strip("[]").split(","))) + assert all( + [0 <= i < len(papers) for i in selected_papers] + ), "Invalid paper index" + bibtexs = [papers[i]["citationStyles"]["bibtex"] for i in selected_papers] + bibtex_string = "\n".join(bibtexs) + else: + return None, False + + except ValueError as e: + print(f"Error extracting JSON: {e}") + return None, False + except KeyError as e: + print(f"Missing required field in JSON: {e}") return None, False except Exception as e: