Merge pull request #81 from etalab-ia/feat/locust
Feat/locust
dtrckd authored Jan 16, 2025
2 parents 73efa7b + 0247d32 commit b9d762d
Showing 5 changed files with 154 additions and 20 deletions.
39 changes: 39 additions & 0 deletions api/alembic/versions/4de07d9a67df_add_locustrun.py
@@ -0,0 +1,39 @@
"""Add LocustRun
Revision ID: 4de07d9a67df
Revises: f4342e16e891
Create Date: 2025-01-17 00:37:14.821839
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '4de07d9a67df'
down_revision: Union[str, None] = 'f4342e16e891'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('locustrun',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=True),
sa.Column('scenario', sa.Text(), nullable=True),
sa.Column('model', sa.Text(), nullable=True),
sa.Column('api_url', sa.Text(), nullable=True),
sa.Column('stats_df', sa.Text(), nullable=True),
sa.Column('history_df', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('locustrun')
# ### end Alembic commands ###
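
A quick way to exercise the new revision is to run it forward and back programmatically; a minimal sketch, assuming an `alembic.ini` configured for the `api/` package:

```
# Minimal sketch: apply the new revision and roll it back. Assumes alembic.ini
# points at the api database; revision ids are taken from the migration above.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "4de07d9a67df")    # creates the locustrun table
command.downgrade(cfg, "f4342e16e891")  # drops it again
```
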
56 changes: 39 additions & 17 deletions api/crud.py
@@ -307,18 +307,19 @@ def upsert_observation(
return db_observation




#
# LeaderBoard
#


def get_leaderboard(db: Session, metric_name: str = "judge_notator", dataset_name: str = None, limit: int = 100):
def get_leaderboard(
db: Session, metric_name: str = "judge_notator", dataset_name: str = None, limit: int = 100
):
# Filter Experiment by dataset and metric
main_metric_subquery = (
db.query(models.Result.experiment_id,
func.max(models.ObservationTable.score).label('main_score'))
db.query(
models.Result.experiment_id, func.max(models.ObservationTable.score).label("main_score")
)
.join(models.ObservationTable)
.filter(models.Result.metric_name == metric_name)
.group_by(models.Result.experiment_id)
@@ -327,12 +328,14 @@ def get_leaderboard(db: Session, metric_name: str = "judge_notator", dataset_nam

# Query
query = (
db.query(models.Experiment.id.label('experiment_id'),
models.Model.name.label('model_name'),
models.Dataset.name.label('dataset_name'),
main_metric_subquery.c.main_score.label('main_metric_score'),
models.Model.sampling_params,
models.Model.extra_params)
db.query(
models.Experiment.id.label("experiment_id"),
models.Model.name.label("model_name"),
models.Dataset.name.label("dataset_name"),
main_metric_subquery.c.main_score.label("main_metric_score"),
models.Model.sampling_params,
models.Model.extra_params,
)
.join(models.Model)
.join(models.Dataset)
.join(main_metric_subquery, models.Experiment.id == main_metric_subquery.c.experiment_id)
@@ -341,19 +344,25 @@ def get_leaderboard(db: Session, metric_name: str = "judge_notator", dataset_nam
if dataset_name:
query = query.filter(models.Dataset.name == dataset_name)

query = query.order_by(desc('main_metric_score')).limit(limit)
query = query.order_by(desc("main_metric_score")).limit(limit)

# Execute
results = query.all()

# Fetch result in leaderboard
entries = []
for result in results:
other_metrics = db.query(models.Result.metric_name, models.ObservationTable.score)\
.join(models.ObservationTable)\
.filter(and_(models.Result.experiment_id == result.experiment_id,
models.Result.metric_name != metric_name))\
other_metrics = (
db.query(models.Result.metric_name, models.ObservationTable.score)
.join(models.ObservationTable)
.filter(
and_(
models.Result.experiment_id == result.experiment_id,
models.Result.metric_name != metric_name,
)
)
.all()
)

entry = schemas.LeaderboardEntry(
experiment_id=result.experiment_id,
@@ -362,9 +371,22 @@ def get_leaderboard(db: Session, metric_name: str = "judge_notator", dataset_nam
main_metric_score=result.main_metric_score,
other_metrics={metric: score for metric, score in other_metrics},
sampling_param={k: str(v) for k, v in (result.sampling_params or {}).items()},
extra_param={k: str(v) for k, v in (result.extra_params or {}).items()}
extra_param={k: str(v) for k, v in (result.extra_params or {}).items()},
)
entries.append(entry)

return schemas.Leaderboard(entries=entries)


#
# LOCUST
#


def create_locustrun(db: Session, run: schemas.LocustRunCreate) -> models.LocustRun:
run = run.to_table_init(db) if isinstance(run, schemas.EgBaseModel) else run
db_run = create_object_from_dict(db, models.LocustRun, run)
db.add(db_run)
db.commit()
db.refresh(db_run)
return db_run
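
A usage sketch for the new helper, assuming a `SessionLocal` session factory in `api.db` (hypothetical import path) and Locust CSVs on disk; the scenario, model, and URL values are placeholders:

```
# Sketch only: persist a run through the new CRUD helper.
# SessionLocal, file names, and field values are assumptions, not part of this commit.
import pandas as pd

from api import crud, schemas
from api.db import SessionLocal  # assumed session factory

with SessionLocal() as db:
    run = schemas.LocustRunCreate(
        scenario="ramp_100_users",           # placeholder scenario name
        model="llama-3-8b",                  # placeholder model id
        api_url="http://localhost:8080/v1",  # URL the load test targeted
        stats_df=pd.read_csv("bench_stats.csv").to_json(),
        history_df=pd.read_csv("bench_stats_history.csv").to_json(),
    )
    db_run = crud.create_locustrun(db, run)
    print(db_run.id, db_run.created_at)
```
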
34 changes: 34 additions & 0 deletions api/endpoints.py
@@ -347,3 +347,37 @@ def read_leaderboard(
db: Session = Depends(get_db),
):
return crud.get_leaderboard(db, metric_name=metric_name, dataset_name=dataset_name, limit=limit)


#
# LOCUST
#


@router.post(
"/locust",
response_model=schemas.LocustRun,
description="""Save a Locust run.
To store the Locust CSV output, serialize each CSV file as a JSON-dumped dataframe:
```
import pandas as pd
stats_df = pd.read_csv("stats.csv").to_json()
# Then pass the serialized data in the POST request along with the other parameters.
```
""",
tags=["locust"],
)
def create_locustrun(run: schemas.LocustRunCreate, db: Session = Depends(get_db)):
try:
db_run = crud.create_locustrun(db, run)
return db_run

except (SchemaError, ValidationError) as e:
raise HTTPException(status_code=400, detail=str(e))
except IntegrityError as e:
return CustomIntegrityError.from_integrity_error(e.orig).to_http_response()
except Exception as e:
raise e
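
The docstring's snippet, rounded out into a complete client call; a sketch assuming the router is mounted at the application root on localhost:8000 and that the CSVs come from `locust --csv=bench`:

```
# Sketch of a client-side upload; URLs, names, and file paths are assumptions.
import pandas as pd
import requests

payload = {
    "scenario": "ramp_100_users",           # placeholder scenario name
    "model": "llama-3-8b",                  # placeholder model id
    "api_url": "http://localhost:8080/v1",  # URL the load test targeted
    "stats_df": pd.read_csv("bench_stats.csv").to_json(),
    "history_df": pd.read_csv("bench_stats_history.csv").to_json(),
}
resp = requests.post("http://localhost:8000/locust", json=payload)
resp.raise_for_status()
print(resp.json()["id"])
```
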
17 changes: 17 additions & 0 deletions api/models.py
@@ -206,3 +206,20 @@ class ExperimentSet(Base):

# Many
experiments = relationship("Experiment", back_populates="experiment_set")


#
# LOCUST
#


class LocustRun(Base):
__tablename__ = "locustrun"

id = Column(Integer, primary_key=True)
created_at = Column(DateTime, server_default=func.now())
scenario = Column(Text)
model = Column(Text)
api_url = Column(Text)
stats_df = Column(Text)
history_df = Column(Text)
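
Because `stats_df` and `history_df` hold JSON-dumped dataframes in `Text` columns, a stored run can be rehydrated with pandas; a sketch, assuming `db_run` is a fetched `LocustRun` instance:

```
# Sketch only: rebuild the dataframes from a fetched LocustRun row.
import io
import pandas as pd

stats = pd.read_json(io.StringIO(db_run.stats_df))
history = pd.read_json(io.StringIO(db_run.history_df))
print(stats.columns.tolist())
```
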
28 changes: 25 additions & 3 deletions api/schemas.py
@@ -427,7 +427,6 @@ class RetryRuns(EgBaseModel):
result_ids: list[int]



#
# LeaderBoard
#
@@ -439,11 +438,34 @@ class LeaderboardEntry(EgBaseModel):
dataset_name: str
main_metric_score: float | None
other_metrics: dict[str, float | None]
sampling_param: dict[str, str | None]
extra_param: dict[str, str | None]
sampling_param: dict[str, str | None]
extra_param: dict[str, str | None]


class Leaderboard(EgBaseModel):
entries: list[LeaderboardEntry]


#
# LOCUST
#


class LocustRunBase(EgBaseModel):
scenario: str = Field(..., description="The locust scenario name.")
model: str | None = Field(None, description="The LLM model name/id targeted if any.")
api_url: str = Field(..., description="The url targeted.")
stats_df: str = Field(..., description="The stats CSV file serialized as a dataframe.")
history_df: str = Field(
..., description="The stats history CSV file serialized as a dataframe."
)


class LocustRunCreate(LocustRunBase):
pass


class LocustRun(LocustRunBase):
id: int
created_at: datetime
