Fix lint errors
allenporter committed Aug 4, 2024
1 parent 9a1207c commit fb898b2
Showing 1 changed file with 31 additions and 30 deletions.
home_assistant_datasets/tool/leaderboard/build.py (31 additions, 30 deletions)
@@ -16,12 +16,15 @@
 from collections.abc import Callable
 import math
 import pathlib
-import subprocess
 from typing import Any
 
 import yaml
 
-from .config import REPORT_DIR, DATASETS, IGNORE_REPORTS, REPORT_FILE, eval_reports, EvalReport, COLORS
+from .config import (
+    REPORT_DIR,
+    DATASETS,
+    eval_reports,
+    COLORS,
+)
 
 
 __all__ = []
@@ -60,7 +63,9 @@ def stddev(self) -> float:
         return math.sqrt((p * (1 - p)) / self.total)
 
 
-def best_score_func(model_scores: dict[str, dict[str, ModelRecord]], dataset_name: str) -> Callable[[str], float]:
+def best_score_func(
+    model_scores: dict[str, dict[str, ModelRecord]], dataset_name: str
+) -> Callable[[str], float]:
     """Best score function."""
 
     def func(model_id: str) -> float:
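
Note on the context above: `stddev` treats a dataset score as a binomial proportion, so the `(+/- x%)` figure rendered into the table below is the standard error sqrt(p * (1 - p) / n). A minimal sketch of the same arithmetic, assuming `good`/`total` are the pass and attempt counts held by `ModelRecord`:

```python
import math

def binomial_stddev(good: int, total: int) -> float:
    """Standard error of a proportion, mirroring ModelRecord.stddev above."""
    p = good / total  # observed pass rate
    return math.sqrt((p * (1 - p)) / total)

# 80 passes out of 100 attempts: p = 0.8, stddev = sqrt(0.8 * 0.2 / 100) = 0.04
assert abs(binomial_stddev(80, 100) - 0.04) < 1e-12
```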
@@ -78,7 +83,9 @@ def run(args: argparse.Namespace) -> int:
     for eval_report in eval_reports(report_dir):
         report_file = eval_report.report_file
         if not report_file.exists:
-            raise ValueError(f"Report file {report_file} does not exist, run `prebuild` first")
+            raise ValueError(
+                f"Report file {report_file} does not exist, run `prebuild` first"
+            )
 
         report = yaml.load(eval_report.report_file.read_text(), Loader=yaml.CSafeLoader)
         for model_data in report:
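
One caveat in the unchanged context above: assuming `report_file` is a `pathlib.Path` (it is read with `read_text()` just below), `exists` is a method, so `if not report_file.exists:` tests a bound method object, which is always truthy, and the `prebuild` guard can never fire. The check presumably wants to be:

```python
if not report_file.exists():  # call the method; a bare `.exists` is always truthy
    raise ValueError(
        f"Report file {report_file} does not exist, run `prebuild` first"
    )
```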
@@ -96,7 +103,6 @@ def run(args: argparse.Namespace) -> int:
                 )
             )
 
-
     # Sort reports by their best scores
     for model_id in model_scores:
         for dataset in DATASETS:
@@ -106,8 +112,6 @@ def run(args: argparse.Namespace) -> int:
             records = sorted(records, key=ModelRecord.good_percent_value, reverse=True)
             model_scores[model_id][dataset] = records
 
-
-
     # Generate overall report sorted by the first dataset score
     best_score = best_score_func(model_scores, DATASETS[0])
     sorted_model_ids = sorted(model_scores.keys(), key=best_score, reverse=True)
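
`best_score_func` is a key-function factory: it closes over the score table and returns a callable that `sorted` applies to each model id. A self-contained sketch of the pattern, with hypothetical scores standing in for the real `ModelRecord` data:

```python
from collections.abc import Callable

def make_key(scores: dict[str, float]) -> Callable[[str], float]:
    """Return a sort key that looks up a model's score (hypothetical data)."""
    def func(model_id: str) -> float:
        return scores.get(model_id, 0.0)
    return func

scores = {"model-a": 92.5, "model-b": 88.0}  # illustrative values
ranked = sorted(scores, key=make_key(scores), reverse=True)
assert ranked == ["model-a", "model-b"]
```

The earlier `sorted(records, key=ModelRecord.good_percent_value, reverse=True)` line uses the same idea with the unbound method as the key: `sorted` calls it with each record as `self`.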
@@ -122,13 +126,14 @@ def run(args: argparse.Namespace) -> int:
             records = model_scores[model_id][dataset]
             if records:
                 best_record = records[0]
-                row.append(f"| {best_record.good_percent_value()*100:0.1f}% (+/- {best_record.stddev*100:0.1f}%) {best_record.dataset_label} ")
+                row.append(
+                    f"| {best_record.good_percent_value()*100:0.1f}% (+/- {best_record.stddev*100:0.1f}%) {best_record.dataset_label} "
+                )
             else:
-                row.append(f"| 0 ")
+                row.append("| 0 ")
         row.append("|")
         results.append(row)
 
-
     # Generate a bar chart for each dataset
     for dataset in DATASETS:
 
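For reference, each `row` built above is a list of markdown table cells; once joined, a leaderboard line presumably renders along these lines (model name, numbers, and dataset label all illustrative):

```
| model-a | 80.0% (+/- 4.0%) dataset-label |
```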
@@ -145,25 +150,26 @@ def run(args: argparse.Namespace) -> int:
                 bar.append(0)
                 continue
             best_record = records[0]
-            # x_axis.append(model_id)
+            # x_axis.append(model_id)
             bar.append(float(f"{best_record.good_percent_value()*100:0.2f}"))
 
-
         def make_bar(index: int, bars: list[int]) -> str:
             values = ["0"] * len(bars)
             values[index] = str(bars[index])
             return ", ".join(values)
 
         x_axis_str = ", ".join([model_id for model_id in x_axis])
-        color_str = ", ".join(COLORS[0:len(sorted_model_ids)])
-        bar_str = "\n".join([
-            f" bar [{make_bar(index, bar)}]"
-            for index, model_id in enumerate(sorted_model_ids)
-        ])
-
-        results.extend([
-            "",
-            f"""```mermaid
+        color_str = ", ".join(COLORS[0 : len(sorted_model_ids)])
+        bar_str = "\n".join(
+            [
+                f" bar [{make_bar(index, bar)}]"
+                for index, model_id in enumerate(sorted_model_ids)
+            ]
+        )
+        results.extend(
+            [
+                "",
+                f"""```mermaid
 ---
 config:
   xyChart:
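
The `make_bar` helper above zeroes every slot except one, so the loop emits one `bar [...]` series per model; paired with the `COLORS` slice, each mermaid series (and therefore each model's single bar) gets its own color. With three models scoring 80, 60, and 40 (illustrative values), `bar_str` would come out as:

```
 bar [80.0, 0, 0]
 bar [0, 60.0, 0]
 bar [0, 0, 40.0]
```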
@@ -185,16 +191,11 @@ def make_bar(index: int, bars: list[int]) -> str:
 {bar_str}
 ```
 """,
-        ])
+            ]
+        )
 
     leaderboard_file = report_dir / LEADERBOARD_FILE
     print(f"Updating {leaderboard_file}")
-    leaderboard_file.write_text("\n".join([
-        "".join(row)
-        for row in results
-    ]))
-
-
-
+    leaderboard_file.write_text("\n".join(["".join(row) for row in results]))
 
     return 0
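
Finally, the tidied `write_text` call is a two-level join: the inner `"".join` flattens each row's cells into one markdown line, and the outer `"\n".join` stacks those lines into the file body. A tiny sketch with made-up rows:

```python
results = [["| model-a ", "| 80.0% ", "|"], ["| model-b ", "| 75.0% ", "|"]]
body = "\n".join(["".join(row) for row in results])
assert body == "| model-a | 80.0% |\n| model-b | 75.0% |"
```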
