Skip to content

Commit

Permalink
Fix evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
waleko committed Jul 10, 2024
1 parent d4dffc8 commit b1d6dbf
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
2 changes: 1 addition & 1 deletion code_editing/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def extract_patch(resp: str) -> Optional[str]:
if resp.strip() == "":
return ""

if resp.startswith("diff --git"):
if resp.strip().startswith("diff --git") or resp.strip().startswith("--- a/"):
# If the response is a diff, return it as is
return resp

Expand Down
24 changes: 19 additions & 5 deletions code_editing/scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,36 @@
dotenv.load_dotenv()

from code_editing.configs.evaluation_config import RunEvaluationConfig
from code_editing.data_sources import SWEBenchDataSource
from code_editing.data_sources.extract_code_base import CodeBaseExtractor
from code_editing.metrics.base_metric import BaseMetric


@hydra.main(version_base=None, config_path="conf", config_name="evaluation")
def main(cfg: RunEvaluationConfig):
"""This script evaluates a CSV or JSON Lines file with columns diff_true and diff_pred (SWE-bench-style prediction files are converted to this format)."""
# Instantiate the extractor and data source
extractor: CodeBaseExtractor = instantiate(cfg.extractor)
data_source = instantiate(cfg.data_source, extractor=extractor)

# Read the input file specified by the user
if cfg.input_path.endswith(".csv"):
df = pd.read_csv(cfg.input_path)
else:
df = pd.read_json(cfg.input_path, lines=True)
# if SWE format, convert to the expected format
if "model_name_or_path" in df.columns and isinstance(data_source, SWEBenchDataSource):
# load the data source's dataset as a DataFrame (model_name_or_path is copied to model_name below)
dataset_df = data_source._dataset.to_pandas()
# add diff_true, message, repo, base_hash from the dataset merged on instance_id
df = df.merge(dataset_df, left_on="instance_id", right_on="instance_id")
df["model_name"] = df["model_name_or_path"]
df["diff_pred"] = df["model_patch"]
df["diff_true"] = df["patch"]
df["message"] = df["problem_statement"]
df["base_hash"] = df["base_commit"]
df["viewed_lines"] = df["diff_true"].apply(lambda x: "{}")
df = df[["diff_pred", "diff_true", "repo", "base_hash", "message", "viewed_lines", "model_name"]]

# Get the 'diff_pred' column from the dataframe, replace any NaN values with an empty string
diff_pred = df["diff_pred"].fillna("")
Expand All @@ -39,14 +57,10 @@ def main(cfg: RunEvaluationConfig):
run_name = model_name or os.path.split(cfg.input_path)[-2]
tags = []
# If the model name is in the format <run_name>_<8-character hex suffix>, split it into the run name and a "unique:" tag
if model_name[-9] == "_":
if len(model_name) > 8 and model_name[-9] == "_":
run_name = model_name[:-9]
tags.append(f"unique:{model_name[-8:]}")

# Instantiate the extractor and data source
extractor: CodeBaseExtractor = instantiate(cfg.extractor)
data_source = instantiate(cfg.data_source, extractor=extractor)

res = {}
pbar = tqdm(cfg.metrics.items(), position=0, desc="Running metrics")
tags += cfg.metrics.keys()
Expand Down

0 comments on commit b1d6dbf

Please sign in to comment.