From 3976617769e6e43bca3b82719fd30a49acf21177 Mon Sep 17 00:00:00 2001 From: Grant <50287275+granawkins@users.noreply.github.com> Date: Tue, 26 Mar 2024 10:25:39 +0700 Subject: [PATCH] SWE-Benchmarks (from Huggingface) (#544) --- .gitignore | 1 + benchmarks/arg_parser.py | 8 +++++ benchmarks/benchmark_runner.py | 10 ++++++ benchmarks/run_sample.py | 31 +++++++++++++++-- benchmarks/swe_bench_runner.py | 38 +++++++++++++++++++++ dev-requirements.txt | 1 + mentat/sampler/CHANGELOG.md | 29 ++++++++++++++++ mentat/sampler/README.md | 39 +++++++++++---------- mentat/sampler/__init__.py | 2 +- mentat/sampler/sample.py | 62 ++++++++++++++++++++++++++++++++++ tests/sampler_test.py | 2 +- 11 files changed, 202 insertions(+), 21 deletions(-) create mode 100644 benchmarks/swe_bench_runner.py create mode 100644 mentat/sampler/CHANGELOG.md diff --git a/.gitignore b/.gitignore index d71934260..b1aa130be 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ build benchmark_repos docs/build .DS_Store +benchmarks/benchmarks/swe_bench_samples diff --git a/benchmarks/arg_parser.py b/benchmarks/arg_parser.py index d98c9b40b..c054842c3 100644 --- a/benchmarks/arg_parser.py +++ b/benchmarks/arg_parser.py @@ -64,5 +64,13 @@ def common_benchmark_parser(): action="store_true", help="Evaluate the baseline for the benchmark", ) + parser.add_argument( + "--swe_bench", + nargs="?", + const="dev", + default=None, + type=str, + help="Fetch or load SWE-bench examples from split: dev (default), train or test.", + ) return parser diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index 10675f6cf..74b3a7cef 100755 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -16,6 +16,7 @@ from benchmarks.benchmark_result import BenchmarkResult from benchmarks.benchmark_run import BenchmarkRun from benchmarks.run_sample import run_sample +from benchmarks.swe_bench_runner import SWE_BENCH_SAMPLES_DIR, get_swe_samples from mentat.config import Config from mentat.git_handler import get_git_diff, get_mentat_branch, get_mentat_hexsha from mentat.llm_api_handler import model_context_size, prompt_tokens @@ -314,6 +315,15 @@ def run_benchmarks(user_benchmarks: list[str], directory: str, retries: int = 1) if __name__ == "__main__": parser = common_benchmark_parser() args = parser.parse_args() + if args.swe_bench: + if args.swe_bench not in {"dev", "train", "test"}: + print("Invalid SWE-Bench split.") + exit(1) + # Download and save SWE benchmarks as Samples + samples = get_swe_samples(args.swe_bench, args.max_benchmarks) + sample_titles = [sample.title for sample in samples] + args.benchmarks = sample_titles + args.directory = SWE_BENCH_SAMPLES_DIR / args.swe_bench run_benchmarks( args.benchmarks, args.directory, diff --git a/benchmarks/run_sample.py b/benchmarks/run_sample.py index c440a676e..deafa0e5d 100644 --- a/benchmarks/run_sample.py +++ b/benchmarks/run_sample.py @@ -1,4 +1,6 @@ from pathlib import Path +import subprocess +import re from typing import Any from mentat import Mentat @@ -6,7 +8,7 @@ from mentat.git_handler import get_git_diff from mentat.parsers.git_parser import GitParser from mentat.sampler.sample import Sample -from mentat.sampler.utils import get_active_snapshot_commit, setup_repo +from mentat.sampler.utils import get_active_snapshot_commit, setup_repo, apply_diff_to_repo from mentat.session_context import SESSION_CONTEXT from mentat.utils import convert_string_to_asynciter @@ -45,7 +47,10 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None) -> 
dict[str,
             conversation.add_model_message(content, [], parsed_llm_response)
         else:
             raise SampleError(f"Invalid role found in message_history: {msg['role']}")
-    await mentat.call_mentat_auto_accept(sample.message_prompt)
+    prompt = sample.message_prompt
+    if sample.hint_text:
+        prompt += f"\n{80 * '-'}\nHint Text:\n{sample.hint_text}"
+    await mentat.call_mentat_auto_accept(prompt)
     await mentat.shutdown()
 
     # Get the diff between pre- and post-edit
@@ -54,6 +59,27 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None) -> dict[str,
     message_eval = str(transcript_messages[-1].get("message", ""))
     diff_eval = get_git_diff(commit_active or "HEAD", cwd=cwd)
 
+    test_results = {"passed": 0, "failed": 0, "error": ""}
+    if sample.test_command:
+        if sample.test_patch:
+            apply_diff_to_repo(sample.test_patch, repo)
+        try:
+            output = subprocess.run(
+                sample.test_command,
+                shell=True,
+                capture_output=True,
+                text=True,
+                cwd=cwd,
+            )
+            passed = re.search(r"(\d+) passed", output.stdout)
+            failed = re.search(r"(\d+) failed", output.stdout)
+            if not passed and not failed:
+                raise SampleError(f"Test command failed: {output.stdout}")
+            test_results["passed"] = int(passed.group(1)) if passed else 0
+            test_results["failed"] = int(failed.group(1)) if failed else 0
+        except Exception as e:
+            test_results["error"] = str(e)
+
     return {
         "id": sample.id,
         "message_eval": message_eval,
@@ -64,4 +90,5 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None) -> dict[str,
             "id": sample.id,
             "messages": transcript_messages,
         },
+        "test_results": test_results,
     }
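A note on the test-result parsing above: pytest prints its summary as, e.g., `2 failed, 3 passed in 0.30s`, with failures listed first, so searching for each count separately is more robust than one combined pattern. A standalone sketch of the same logic (the helper name is ours, not part of the patch):

```python
import re

# Sketch only: extract pass/fail counts from pytest-style output.
# Searching for each count separately handles pytest's
# "N failed, M passed" ordering and output where either count is absent.
def parse_pytest_summary(stdout: str) -> dict[str, int]:
    passed = re.search(r"(\d+) passed", stdout)
    failed = re.search(r"(\d+) failed", stdout)
    return {
        "passed": int(passed.group(1)) if passed else 0,
        "failed": int(failed.group(1)) if failed else 0,
    }

assert parse_pytest_summary("==== 2 failed, 3 passed in 0.30s ====") == {"passed": 3, "failed": 2}
```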
+ """ + split_dir = SWE_BENCH_SAMPLES_DIR / split + saved_benchmarks = list(split_dir.glob("*.json")) + if not split_dir.exists() or max_benchmarks and len(saved_benchmarks) < max_benchmarks: + print(f"Downloading {split} split from SWE-Bench...") + split_dir.mkdir(parents=True, exist_ok=True) + dataset = download_swe_benchmarks(split) + samples = [Sample.from_swe_bench(benchmark) for benchmark in dataset] + for sample in samples: + sample.save(split_dir / f"{sample.id}.json") + else: + samples = [Sample.load(fname) for fname in saved_benchmarks] + + if max_benchmarks: + samples = samples[:max_benchmarks] + print(f"Selected {len(samples)} benchmarks from '{split}'") + return samples diff --git a/dev-requirements.txt b/dev-requirements.txt index e850acc0f..401bf3c8e 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,5 +1,6 @@ aiomultiprocess==0.9.0 black==23.9.1 +datasets==2.18.0 gitpython==3.1.41 isort==5.12.0 pip-licenses==4.3.3 diff --git a/mentat/sampler/CHANGELOG.md b/mentat/sampler/CHANGELOG.md new file mode 100644 index 000000000..1df0ee131 --- /dev/null +++ b/mentat/sampler/CHANGELOG.md @@ -0,0 +1,29 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +## [2024-03-22]: Add fields to cover content of SWE-Bench +Fields in common are left with their current name. Missing fields (*) are added. + +Sampler SWE-Bench +- title +- description +- id instance_id +- parent_id +- repo repo +- environment_setup_commit * +- merge_base base_commit +- diff_merge_base +- diff_active +- message_history +- message_prompt problem_statement + hint_text * +- message_edit +- context +- diff_edit patch + test_patch * +- test_command FAIL_TO_PASS + PASS_TO_PASS * +- version +- version +- created_at diff --git a/mentat/sampler/README.md b/mentat/sampler/README.md index 4731fba43..c3562adef 100644 --- a/mentat/sampler/README.md +++ b/mentat/sampler/README.md @@ -12,23 +12,28 @@ In any github-connected repo: ## `Sample` API A `Sample` captures interactions between a developer and any LLM Coding Assistant. It consists of a starting codebase, a user command, and the expected LLM response - text, a git diff, or both. It can also include a list of paths/line-numbers to be included with the prompt, diffs to setup the git environment, and more: -| Field | Req | Type | Description | -|------------------|-----|------------------------|-------------| -| title | | `str` | plaintext by creator | -| description | | `str` | plaintext by creator | -| id | | `uuid` | | -| parent_id | | `uuid` | id of sample immediately before this | -| repo | * | `str` | a url to download the code | -| merge_base | * | `str` | the latest permanent commit | -| diff_merge_base | | `str` | between merge_base and latest commit | -| diff_active | | `str` | between latest commit and active (pre-edit) code | -| args | | `list[str]` | list of `[:-]` | -| message_history | | `list[dict[str, str]]` | list of prior user and assistant messages | -| message_prompt | * | `str` | the sample task | -| message_edit | | `str` | plaintext response returned for sample edit | -| diff_edit | * | `str` | between starting (diff_head) and ending code. | -| test_command | | `str` | discrete pass/fail, e.g. 
diff --git a/mentat/sampler/README.md b/mentat/sampler/README.md
index 4731fba43..c3562adef 100644
--- a/mentat/sampler/README.md
+++ b/mentat/sampler/README.md
@@ -12,23 +12,27 @@ In any github-connected repo:
 ## `Sample` API
 A `Sample` captures interactions between a developer and any LLM Coding Assistant. It consists of a starting codebase, a user command, and the expected LLM response - text, a git diff, or both. It can also include a list of paths/line-numbers to be included with the prompt, diffs to setup the git environment, and more:
-| Field            | Req | Type                   | Description |
-|------------------|-----|------------------------|-------------|
-| title            |     | `str`                  | plaintext by creator |
-| description      |     | `str`                  | plaintext by creator |
-| id               |     | `uuid`                 | |
-| parent_id        |     | `uuid`                 | id of sample immediately before this |
-| repo             | *   | `str`                  | a url to download the code |
-| merge_base       | *   | `str`                  | the latest permanent commit |
-| diff_merge_base  |     | `str`                  | between merge_base and latest commit |
-| diff_active      |     | `str`                  | between latest commit and active (pre-edit) code |
-| args             |     | `list[str]`            | list of `<path>[:<start_line>-<end_line>]` |
-| message_history  |     | `list[dict[str, str]]` | list of prior user and assistant messages |
-| message_prompt   | *   | `str`                  | the sample task |
-| message_edit     |     | `str`                  | plaintext response returned for sample edit |
-| diff_edit        | *   | `str`                  | between starting (diff_head) and ending code. |
-| test_command     |     | `str`                  | discrete pass/fail, e.g. ‘pytest -k diff_active’ |
-| version          |     | `str`                  | current Sample API version |
+| Field                     | Req | Type                   | Description |
+|---------------------------|-----|------------------------|-------------|
+| title                     |     | `str`                  | plaintext by creator |
+| description               |     | `str`                  | plaintext by creator |
+| id                        |     | `uuid`                 | |
+| parent_id                 |     | `uuid`                 | id of sample immediately before this |
+| repo                      | *   | `str`                  | a url to download the code |
+| environment_setup_commit  |     | `str`                  | commit hash to use for environment setup and installation |
+| merge_base                | *   | `str`                  | the latest permanent commit |
+| diff_merge_base           |     | `str`                  | between merge_base and latest commit |
+| diff_active               |     | `str`                  | between latest commit and active (pre-edit) code |
+| context                   |     | `list[str]`            | list of `<path>[:<start_line>-<end_line>]` |
+| message_history           |     | `list[dict[str, str]]` | list of prior user and assistant messages |
+| message_prompt            | *   | `str`                  | the sample task |
+| hint_text                 |     | `str`                  | extra information, e.g. GitHub issue comments |
+| message_edit              |     | `str`                  | plaintext response returned for sample edit |
+| diff_edit                 | *   | `str`                  | between starting (diff_head) and ending code. |
+| test_patch                |     | `str`                  | a patch to test files used to evaluate the sample |
+| test_command              |     | `str`                  | discrete pass/fail, e.g. ‘pytest -k diff_active’ |
+| PASS_TO_PASS              |     | `str`                  | tests expected to pass both before and after the edit |
+| version                   |     | `str`                  | current Sample API version |
 
 Notes:
 - All diffs and code changes follow standard git-diff format (`diff --git a/new_filename...`)
 
diff --git a/mentat/sampler/__init__.py b/mentat/sampler/__init__.py
index d3ec452c3..493f7415d 100644
--- a/mentat/sampler/__init__.py
+++ b/mentat/sampler/__init__.py
@@ -1 +1 @@
-__version__ = "0.2.0"
+__version__ = "0.3.0"
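The `Sample.from_swe_bench` converter added below maps these dataset fields onto the sampler schema. A sketch of the conversion against a made-up record (all field values are invented for illustration; field names follow the SWE-bench dataset card):

```python
from mentat.sampler.sample import Sample

# Hypothetical SWE-bench record; values are placeholders, not real data.
benchmark = {
    "instance_id": "owner__repo-1234",
    "repo": "owner/repo",
    "base_commit": "abc1234",
    "environment_setup_commit": "def5678",
    "problem_statement": "Calling foo() raises TypeError ...",
    "hints_text": "Possibly related to the bar() refactor.",
    "patch": "diff --git a/src/foo.py b/src/foo.py\n...",
    "test_patch": "diff --git a/tests/test_foo.py b/tests/test_foo.py\n...",
    "FAIL_TO_PASS": '["tests/test_foo.py::test_foo"]',
    "PASS_TO_PASS": '["tests/test_foo.py::test_ok"]',
}

sample = Sample.from_swe_bench(benchmark)
assert sample.title == "SWE-bench-owner__repo-1234"
assert sample.test_command == "pytest tests/test_foo.py::test_foo"
assert sample.PASS_TO_PASS == "pytest tests/test_foo.py::test_ok"
```

Note that `FAIL_TO_PASS` and `PASS_TO_PASS` arrive as JSON-encoded strings, which is why the converter runs them through `json.loads` before joining the test ids into a pytest command.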
diff --git a/mentat/sampler/sample.py b/mentat/sampler/sample.py
index 5c9ce983d..f74a617fb 100644
--- a/mentat/sampler/sample.py
+++ b/mentat/sampler/sample.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+import re
 from pathlib import Path
 
 import attr
@@ -17,15 +18,19 @@ class Sample:
     id: str = attr.field(default="")
     parent_id: str = attr.field(default="")
     repo: str = attr.field(default="")
+    environment_setup_commit: str | None = attr.field(default=None)
     merge_base: str | None = attr.field(default=None)
     diff_merge_base: str = attr.field(default="")
     diff_active: str = attr.field(default="")
     message_history: list[dict[str, str]] = attr.field(default=[])  # type: ignore
     message_prompt: str = attr.field(default="")
+    hint_text: str = attr.field(default="")
     message_edit: str = attr.field(default="")
     context: list[str] = attr.field(default=[])  # type: ignore
     diff_edit: str = attr.field(default="")
+    test_patch: str = attr.field(default="")
     test_command: str = attr.field(default="")
+    PASS_TO_PASS: str = attr.field(default="")
     version: str = attr.field(default=__version__)
 
     def save(self, fname: str | Path) -> None:
@@ -41,8 +46,66 @@ def load(cls, fname: str | Path) -> Sample:
             kwargs["message_history"] = kwargs.get("message_history", [])[::-1]
             kwargs["version"] = "0.2.0"
         _version = kwargs["version"]
+        if _version < "0.3.0":
+            # Additional fields from SWE-Bench
+            kwargs["environment_setup_commit"] = ""
+            kwargs["hint_text"] = ""
+            kwargs["test_patch"] = ""
+            kwargs["PASS_TO_PASS"] = ""
+            kwargs["version"] = "0.3.0"
+            _version = kwargs["version"]
         if _version != __version__:
             raise SampleError(
                 f"Warning: sample version ({_version}) does not match current"
                 f" version ({__version__})."
             )
         return cls(**kwargs)
+
+    @classmethod
+    def from_swe_bench(cls, benchmark: dict[str, str]) -> Sample:
+        """Create a Sample from a SWE-Bench benchmark.
+
+        SWE-Bench Fields (https://huggingface.co/datasets/princeton-nlp/SWE-bench#dataset-structure)
+        - instance_id: (str) - A formatted instance identifier, usually as repo_owner__repo_name-PR-number.
+        - patch: (str) - The gold patch, the patch generated by the PR (minus test-related code), that resolved
+          the issue.
+        - repo: (str) - The repository owner/name identifier from GitHub.
+        - base_commit: (str) - The commit hash of the repository representing the HEAD of the repository before the
+          solution PR is applied.
+        - hints_text: (str) - Comments made on the issue prior to the creation of the solution PR’s first commit
+          creation date.
+        - created_at: (str) - The creation date of the pull request.
+        - test_patch: (str) - A test-file patch that was contributed by the solution PR.
+        - problem_statement: (str) - The issue title and body.
+        - version: (str) - Installation version to use for running evaluation.
+        - environment_setup_commit: (str) - commit hash to use for environment setup and installation.
+        - FAIL_TO_PASS: (str) - A json list of strings that represent the set of tests resolved by the PR and tied to
+          the issue resolution.
+        - PASS_TO_PASS: (str) - A json list of strings that represent tests that should pass before and after the PR
+          application.
+        """
+        patch = benchmark.get("patch", "")
+        edited_files = re.findall(r"diff --git a/(.*?) b/\1", patch)
+        return cls(
+            title=f"SWE-bench-{benchmark['instance_id']}",
+            description="",
+            id=benchmark["instance_id"],
+            parent_id="",
+            repo=f"https://github.com/{benchmark.get('repo')}",
+            environment_setup_commit=benchmark.get("environment_setup_commit", ""),
+            merge_base=benchmark.get("base_commit"),
+            diff_merge_base="",
+            diff_active="",
+            message_history=[],
+            message_prompt=benchmark.get("problem_statement", ""),
+            hint_text=benchmark.get("hints_text", ""),
+            message_edit="",
+            context=edited_files,
+            diff_edit=patch,
+            test_patch=benchmark.get("test_patch", ""),
+            test_command=(
+                ""
+                if not benchmark.get("FAIL_TO_PASS")
+                else "pytest " + " ".join(json.loads(benchmark["FAIL_TO_PASS"]))
+            ),
+            PASS_TO_PASS=(
+                ""
+                if not benchmark.get("PASS_TO_PASS")
+                else "pytest " + " ".join(json.loads(benchmark["PASS_TO_PASS"]))
+            ),
+        )
diff --git a/tests/sampler_test.py b/tests/sampler_test.py
index 100367276..89e390b41 100644
--- a/tests/sampler_test.py
+++ b/tests/sampler_test.py
@@ -171,7 +171,7 @@ async def test_sample_command(temp_testbed, mock_collect_user_input, mock_call_l
     assert "test_file.py" in edits[1]
     assert "+# forty two" in edits[1]
     assert sample.test_command == "test_test_command"
-    assert sample.version == "0.2.0"
+    assert sample.version == "0.3.0"
 
     test_sample = {