-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
28 changed files
with
257 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Environment Setup | ||
|
||
## ML-Agent-Bench Docker Setup | ||
|
||
To run the ML-Agent-Bench Docker container, you can use the following command: | ||
|
||
```bash | ||
docker pull public.ecr.aws/i5g0m1f6/ml-bench | ||
docker run -it public.ecr.aws/i5g0m1f6/ml-bench /bin/bash | ||
``` | ||
|
||
This will pull the latest ML-Agent-Bench Docker image and run it in an interactive shell. The container includes all the necessary dependencies to run the ML-Agent-Bench codebase. | ||
|
||
For ML-Agent-Bench in OpenDevin, please refer to the [OpenDevin setup guide](https://github.com/OpenDevin/OpenDevin/blob/main/evaluation/ml_bench/README.md). |
Submodule open_devin
added at
35c4c9
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# OpenAI Calling | ||
|
||
To reproduce OpenAI's performance on this task, use the following script: | ||
```bash | ||
bash script/openai/run.sh | ||
``` | ||
|
||
## Parameter Settings | ||
|
||
You need to change the parameter settings in `script/openai/run.sh`: | ||
|
||
- `type`: Choose from `quarter` or `full`. | ||
- `model`: Model name. | ||
- `input_file`: File path of the dataset. | ||
- `answer_file`: Original answer in JSON format from GPT. | ||
- `parsing_file`: Post-process the output of GPT in JSONL format to obtain executable code segments. | ||
- `readme_type`: Choose from `oracle_segment` and `readme`. | ||
- `oracle_segment`: The code paragraph in the README that is most relevant to the task. | ||
- `readme`: The entire text of the README in the repository where the task is located. | ||
- `engine_name`: Choose from `gpt-35-turbo-16k` and `gpt-4-32`. | ||
- `n_turn`: Number of executable codes GPT returns (5 times in the paper experiment). | ||
- `openai_key`: Your OpenAI API key. |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import openai | ||
import yaml | ||
import os | ||
|
||
|
||
def call_GPT(function_prompt,model_name,function_type,function): | ||
if function_type == "auto": | ||
with open("./config/config_azure.yml", "r") as yaml_file: | ||
config = yaml.safe_load(yaml_file) | ||
openai.api_base = config["api_base"] | ||
openai.api_type = config["api_type"] | ||
openai.api_version = config["api_version"] | ||
#openai.api_proxy = config["api_proxy"] | ||
openai.api_key = config["openai_keys"][model_name][0]["api_key"] | ||
try: | ||
res = openai.ChatCompletion.create( | ||
engine=model_name, | ||
messages=[ | ||
{"role": "user", | ||
"content": function_prompt} | ||
], | ||
functions = [function], | ||
function_call = "auto" , | ||
) | ||
return res | ||
except Exception as e: | ||
print("An exception occurred:", e) | ||
elif function_type == "none": | ||
with open("./config/config_azure.yml", "r") as yaml_file: | ||
config = yaml.safe_load(yaml_file) | ||
openai.api_base = config["api_base"] | ||
openai.api_type = config["api_type"] | ||
openai.api_version = config["api_version"] | ||
#openai.api_proxy = config["api_proxy"] | ||
openai.api_key = config["openai_keys"][model_name][0]["api_key"] | ||
try: | ||
res = openai.ChatCompletion.create( | ||
engine=model_name, | ||
messages=[ | ||
{"role": "user", | ||
"content": function_prompt} | ||
] | ||
) | ||
return res | ||
except Exception as e: | ||
print("An exception occurred:", e) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import openai | ||
import yaml | ||
import os | ||
|
||
def call_GPT(function_prompt,model_name,function_type,function): | ||
if function_type == "auto": | ||
with open("./config/config_openai.yml", "r") as yaml_file: | ||
config = yaml.safe_load(yaml_file) | ||
openai.api_base = config["api_base"] | ||
openai.api_proxy = config["api_proxy"] | ||
openai.api_key = config["openai_keys"][model_name][0]["api_key"] | ||
try: | ||
res = openai.ChatCompletion.create( | ||
model = model_name, | ||
messages = [ | ||
{"role": "user", | ||
"content": function_prompt} | ||
], | ||
functions = [function], | ||
function_call = "auto" , | ||
) | ||
return res | ||
except Exception as e: | ||
print("An exception occurred:", e) | ||
elif function_type == "none": | ||
with open("./config/config_openai.yml", "r") as yaml_file: | ||
config = yaml.safe_load(yaml_file) | ||
openai.api_base = config["api_base"] | ||
#openai.api_proxy = config["api_proxy"] | ||
openai.api_key = config["openai_keys"][model_name][0]["api_key"] | ||
try: | ||
res = openai.ChatCompletion.create( | ||
model=model_name, | ||
messages=[ | ||
{"role": "user", | ||
"content": function_prompt} | ||
] | ||
) | ||
return res | ||
except Exception as e: | ||
print("An exception occurred:", e) | ||
|
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import os | ||
|
||
def write_indexfile(repo_name,root_directory): | ||
directory_path = root_directory | ||
all_contents = [os.path.relpath(os.path.join(root, item), directory_path) for root, _, items in os.walk(directory_path) for item in items] | ||
output_file = 'directory_contents.txt' | ||
with open(repo_name+"_index.txt", 'w') as file: | ||
file.write('\n'.join(all_contents) + '\n') | ||
print(f'All directory contents written') | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import argparse | ||
|
||
def Get_args(): | ||
parser = argparse.ArgumentParser(description="Please choose a model,api_type,function_call to use agent.") | ||
parser.add_argument('--model_name', type=str, required=True, help="Model name") | ||
parser.add_argument('--api_type', type=str, required=True, help="Api type") | ||
parser.add_argument('--function_type', type=str, required=True, help="Function type:auto or none") | ||
args = parser.parse_args() | ||
model_name = args.model_name | ||
api_type =args.api_type | ||
function_type = args.function_type | ||
return model_name,api_type,function_type |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import re | ||
from tools.read_yml import read_yaml_file | ||
import json | ||
|
||
def get_keywords(query,model_name,api_type,function_type): | ||
if api_type == "openai": | ||
from tools.call_openai import call_GPT | ||
elif api_type == "azure": | ||
from tools.call_azure import call_GPT | ||
function_file = "./functions/step1_function.yml" | ||
function_prompt, function = read_yaml_file(function_file) | ||
function_prompt = function_prompt.format(query) | ||
response = call_GPT(function_prompt,model_name,function_type,function) | ||
print(response) | ||
function_call_message = response["choices"][0]["message"]["function_call"] | ||
function_call_json = json.loads(json.dumps(function_call_message.to_dict())) | ||
res_keywords = json.loads(function_call_json["arguments"])["keywords"] | ||
keywords = res_keywords.split(', ') | ||
return keywords | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import yaml | ||
|
||
def read_yaml_file(yaml_file_path): | ||
try: | ||
with open(yaml_file_path, 'r') as file: | ||
data = yaml.safe_load(file) | ||
|
||
function_prompt = data.get('function_prompt', '') | ||
function = data.get('function', {}) | ||
|
||
return function_prompt, function | ||
except Exception as e: | ||
return None, None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import os | ||
|
||
def find_readme_files(directory): | ||
readme_files = [] | ||
for root, dirs, files in os.walk(directory): | ||
for file in files: | ||
if file.lower() == "readme.md": | ||
readme_files.append(os.path.join(root, file)) | ||
readme_files = sorted(readme_files, key=len) | ||
return readme_files |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from tools import keywords | ||
import requests | ||
|
||
def search_github_repositories_by_keywords(keywords): | ||
query_string = '+'.join(keywords) | ||
|
||
print(query_string) | ||
api_url = f"https://api.github.com/search/repositories?q={query_string}&page=1&per_page=10" | ||
|
||
response = requests.get(api_url) | ||
|
||
if response.status_code == 200: | ||
search_results = response.json() | ||
|
||
repo_urls = [repo["html_url"] for repo in search_results["items"]] | ||
return repo_urls | ||
else: | ||
return [] | ||
|
||
def get_repo_urls(query,model_name,api_type,function_type): | ||
keywds = keywords.get_keywords(query,model_name,api_type,function_type) | ||
print(keywds) | ||
repo_urls = search_github_repositories_by_keywords(keywds) | ||
return repo_urls | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import os | ||
import requests | ||
|
||
def get_repo_description(repo_url): | ||
parts = repo_url.strip('/').split('/') | ||
if len(parts) != 5 or parts[2] != 'github.com': | ||
return "Invalid GitHub repo URL" | ||
username, repository_name = parts[3], parts[4] | ||
|
||
api_url = f"https://api.github.com/repos/{username}/{repository_name}" | ||
|
||
response = requests.get(api_url) | ||
|
||
if response.status_code == 200: | ||
repo_info = response.json() | ||
repo_description = repo_info["description"] | ||
return repo_description | ||
else: | ||
return "Unable to get warehouse information" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import os | ||
|
||
def get_repo_name(repo_url): | ||
repo_name = os.path.basename(repo_url.rstrip('/')) | ||
return repo_name |