Added support for Groq #11

Open · wants to merge 1 commit into main
5 changes: 3 additions & 2 deletions README.md
@@ -43,7 +43,7 @@ conda create -n ai_scientist python=3.11
 conda activate ai_scientist
 
 # LLM APIs
-pip install anthropic aider-chat backoff openai
+pip install anthropic aider-chat backoff openai groq
 # Viz
 pip install matplotlib pypdf pymupdf4llm
 # Install pdflatex
@@ -55,7 +55,7 @@ pip install torch numpy transformers datasets tiktoken wandb tqdm
 
 We use the following environment variables for the different API providers for different models:
 
-`OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `DEEPSEEK_API_KEY`, `OPENROUTER_API_KEY`
+`OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `DEEPSEEK_API_KEY`, `OPENROUTER_API_KEY`, `GROQ_API_KEY`
 
 Our code can also optionally use a Semantic Scholar API Key (`S2_API_KEY`) for higher throughput [if you have one](https://www.semanticscholar.org/product/api), though in principle it should work without it.
 
@@ -115,6 +115,7 @@ conda activate ai_scientist
 # Run the paper generation.
 python launch_scientist.py --model "gpt-4o-2024-05-13" --experiment nanoGPT_lite --num-ideas 2
 python launch_scientist.py --model "claude-3-5-sonnet-20240620" --experiment nanoGPT_lite --num-ideas 2
+python launch_scientist.py --model "llama3-70b-8192" --experiment nanoGPT_lite --num-ideas 2
 ```
 
 ## Getting an LLM Generated Paper Review
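Before running the pipeline end to end, the new Groq path can be smoke-tested on its own. A minimal sketch, not part of the PR; it assumes `pip install groq` has run and `GROQ_API_KEY` is exported, and the prompt is illustrative:

```python
# Standalone sanity check for the Groq setup added above.
import os

from groq import Groq

client = Groq(api_key=os.environ["GROQ_API_KEY"])
response = client.chat.completions.create(
    model="llama3-70b-8192",  # the model this PR wires up
    messages=[{"role": "user", "content": "Reply with the single word: ok"}],
)
print(response.choices[0].message.content)
```

If this prints a reply, the package and environment variable are set up correctly for the scripts changed below.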
6 changes: 6 additions & 0 deletions ai_scientist/generate_ideas.py
@@ -468,6 +468,7 @@ def check_idea_novelty(
             "gpt-4o-2024-05-13",
             "deepseek-coder-v2-0724",
             "llama3.1-405b",
+            "llama3-70b-8192",
         ],
         help="Model to use for AI Scientist.",
     )
@@ -496,6 +497,11 @@ def check_idea_novelty(
         print(f"Using OpenAI API with model {args.model}.")
         client_model = "gpt-4o-2024-05-13"
         client = openai.OpenAI()
+    elif args.model == "llama3-70b-8192":
+        from groq import Groq
+        print(f"Using Groq API with {args.model}.")
+        client_model = "llama3-70b-8192"
+        client = Groq(api_key=os.environ["GROQ_API_KEY"])
     elif args.model == "deepseek-coder-v2-0724":
         import openai
 
52 changes: 32 additions & 20 deletions ai_scientist/llm.py
@@ -1,7 +1,8 @@
 import backoff
 import openai
 import json
-
+import os
+from groq import Groq
 
 # Get N responses from a single message, used for ensembling.
 @backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
@@ -57,23 +58,29 @@ def get_batch_responses_from_llm(
         new_msg_history = [
             new_msg_history + [{"role": "assistant", "content": c}] for c in content
         ]
-    elif model == "llama-3-1-405b-instruct":
+    elif model in ["llama-3-1-405b-instruct", "llama3-70b-8192"]:
         new_msg_history = msg_history + [{"role": "user", "content": msg}]
-        response = client.chat.completions.create(
-            model="meta-llama/llama-3.1-405b-instruct",
-            messages=[
-                {"role": "system", "content": system_message},
-                *new_msg_history,
-            ],
-            temperature=temperature,
-            max_tokens=3000,
-            n=n_responses,
-            stop=None,
-        )
-        content = [r.message.content for r in response.choices]
-        new_msg_history = [
-            new_msg_history + [{"role": "assistant", "content": c}] for c in content
-        ]
+        if model == "llama-3-1-405b-instruct":
+            model_name = "meta-llama/llama-3.1-405b-instruct"
+        else:
+            model_name = "llama3-70b-8192"
+            client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
+        content = []
+        histories = []
+        for _ in range(n_responses):
+            response = client.chat.completions.create(
+                model=model_name,
+                messages=[
+                    {"role": "system", "content": system_message},
+                    *new_msg_history,
+                ],
+                temperature=temperature,
+                max_tokens=3000,
+                stop=None,
+            )
+            content.append(response.choices[0].message.content)
+            histories.append(new_msg_history + [{"role": "assistant", "content": content[-1]}])
+        new_msg_history = histories
     elif model == "claude-3-5-sonnet-20240620":
         content, new_msg_history = [], []
         for _ in range(n_responses):
@@ -89,7 +96,6 @@ def get_batch_responses_from_llm(
             content.append(c)
             new_msg_history.append(hist)
     else:
-        # TODO: This is only supported for GPT-4 in our reviewer pipeline.
         raise ValueError(f"Model {model} not supported.")
 
     if print_debug:
@@ -184,10 +190,16 @@ def get_response_from_llm(
         )
         content = response.choices[0].message.content
         new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
-    elif model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct"]:
+    elif model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct", "llama3-70b-8192"]:
         new_msg_history = msg_history + [{"role": "user", "content": msg}]
+        if model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct"]:
+            model_name = "meta-llama/llama-3.1-405b-instruct"
+        else:
+            model_name = "llama3-70b-8192"
+            client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
+
         response = client.chat.completions.create(
-            model="meta-llama/llama-3.1-405b-instruct",
+            model=model_name,
             messages=[
                 {"role": "system", "content": system_message},
                 *new_msg_history,
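One design note on the batch branch above: the other providers return `n_responses` choices from a single call via `n=n_responses`, while the Groq branch loops and requests one completion at a time; Groq's chat endpoint, at the time of this PR, appears to accept only a single choice per request. A standalone sketch of that pattern, with the per-sample history forking the other branches use (names are illustrative; `client` is assumed to be a `groq.Groq` instance as constructed in llm.py):

```python
# Sequential ensembling: one request per sample, each with its own
# forked message history (a sketch mirroring the diff above).
def sample_n(client, system_message, history, n_responses=3, temperature=0.75):
    contents, histories = [], []
    for _ in range(n_responses):
        response = client.chat.completions.create(
            model="llama3-70b-8192",
            messages=[{"role": "system", "content": system_message}, *history],
            temperature=temperature,
            max_tokens=3000,
        )
        text = response.choices[0].message.content
        contents.append(text)
        histories.append(history + [{"role": "assistant", "content": text}])
    return contents, histories
```

This trades one round trip for `n_responses` of them, so batch calls will be proportionally slower on Groq than on providers that honor `n`.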
9 changes: 9 additions & 0 deletions ai_scientist/perform_writeup.py
@@ -528,6 +528,7 @@ def perform_writeup(
             "gpt-4o-2024-05-13",
             "deepseek-coder-v2-0724",
             "llama3.1-405b",
+            "llama3-70b-8192",
         ],
         help="Model to use for AI Scientist.",
     )
@@ -538,6 +539,12 @@ def perform_writeup(
         print(f"Using Anthropic API with model {args.model}.")
         client_model = "claude-3-5-sonnet-20240620"
         client = anthropic.Anthropic()
+    elif args.model == "llama3-70b-8192":
+        from groq import Groq
+
+        print(f"Using Groq API with {args.model}.")
+        client_model = "llama3-70b-8192"
+        client = Groq(api_key=os.environ["GROQ_API_KEY"])
     elif args.model == "gpt-4o-2024-05-13" or args.model == "hybrid":
         import openai
 
@@ -586,6 +593,8 @@ def perform_writeup(
         main_model = Model("deepseek/deepseek-coder")
     elif args.model == "llama3.1-405b":
         main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
+    elif args.model == "llama3-70b-8192":
+        main_model = Model("groq/llama3-70b-8192")
     else:
         main_model = Model(model)
     coder = Coder.create(
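The writeup stage drives edits through aider rather than calling the chat client directly, which is why the model id gains a `groq/` provider prefix here. A hedged sketch of what that wiring amounts to in isolation (the file name and instruction are placeholders, not from the PR; assumes aider-chat is installed and `GROQ_API_KEY` is exported):

```python
# How the aider side of this diff gets used, in isolation (a sketch).
from aider.coders import Coder
from aider.io import InputOutput
from aider.models import Model

main_model = Model("groq/llama3-70b-8192")  # provider prefix routes to Groq
coder = Coder.create(
    main_model=main_model,
    fnames=["latex/template.tex"],  # placeholder file
    io=InputOutput(yes=True),
)
coder.run("Fill in the Abstract section.")  # placeholder instruction
```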
10 changes: 10 additions & 0 deletions experimental/launch_oe_scientist.py
@@ -42,6 +42,7 @@ def parse_arguments():
             "gpt-4o-2024-05-13",
             "deepseek-coder-v2-0724",
             "llama3.1-405b",
+            "llama3-70b-8192",
         ],
         help="Model to use for AI Scientist.",
     )
@@ -189,6 +190,8 @@ def do_idea(
         main_model = Model("deepseek/deepseek-coder")
     elif model == "llama3.1-405b":
         main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
+    elif model == "llama3-70b-8192":
+        main_model = Model("groq/llama3-70b-8192")
     else:
         main_model = Model(model)
     coder = Coder.create(
@@ -225,6 +228,8 @@ def do_idea(
         main_model = Model("deepseek/deepseek-coder")
     elif model == "llama3.1-405b":
         main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
+    elif model == "llama3-70b-8192":
+        main_model = Model("groq/llama3-70b-8192")
     else:
         main_model = Model(model)
     coder = Coder.create(
@@ -348,6 +353,11 @@ def do_idea(
             api_key=os.environ["OPENROUTER_API_KEY"],
             base_url="https://openrouter.ai/api/v1",
         )
+    elif args.model == "llama3-70b-8192":
+        from groq import Groq
+        print(f"Using Groq API with {args.model}.")
+        client_model = "llama3-70b-8192"
+        client = Groq(api_key=os.environ["GROQ_API_KEY"])
     else:
         raise ValueError(f"Model {args.model} not supported.")
 
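With this file, the same five-line Groq client construction now appears in three entry points (`generate_ideas.py`, `perform_writeup.py`, and this launcher). A small shared helper would keep the key handling in one place; a refactoring sketch, not part of the PR:

```python
# Hypothetical shared helper (not in the PR): build the Groq client once,
# failing loudly when the key is missing instead of with a bare KeyError.
import os

from groq import Groq

def make_groq_client() -> Groq:
    key = os.environ.get("GROQ_API_KEY")
    if key is None:
        raise RuntimeError("Set GROQ_API_KEY to use llama3-70b-8192 via Groq.")
    return Groq(api_key=key)
```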