Add RekaClient, the Vibe-Eval (Scenario and Auto-Evaluator) (#2675)
ImKeTT authored May 30, 2024
1 parent 742fa39 commit b1714d2
Showing 11 changed files with 617 additions and 4 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -238,6 +238,7 @@ PyWavelets==1.4.1
PyYAML==6.0.1
referencing==0.35.1
regex==2024.5.10
reka-api==2.0.0
requests==2.31.0
requests-oauthlib==2.0.0
retrying==1.3.4
7 changes: 7 additions & 0 deletions setup.cfg
@@ -161,11 +161,15 @@ models =
crfm-helm[google]
crfm-helm[mistral]
crfm-helm[openai]
crfm-helm[reka]
crfm-helm[together]
crfm-helm[tsinghua]
crfm-helm[yandex]
crfm-helm[openvino]

reka =
reka-api~=2.0.0

vlm =
crfm-helm[openai]

@@ -182,6 +186,9 @@ vlm =
scipy~=1.10
torchvision>=0.14.1,<3.0.0

# For Reka AI
crfm-helm[reka]

# VLM scenarios
crfm-helm[images]
crfm-helm[image2structure]
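With the new `reka` extra declared above, a guard like the following can verify the optional dependency before using the new Reka client (a sketch; it assumes `reka-api` exposes an importable `reka` module and that the extra was installed, e.g. `pip install "crfm-helm[reka]"`):

```python
import importlib.util

# Assumption: the reka-api distribution exposes a top-level `reka` module.
if importlib.util.find_spec("reka") is None:
    raise ModuleNotFoundError(
        'reka-api is not installed; install the optional extra, e.g. pip install "crfm-helm[reka]"'
    )
```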
158 changes: 158 additions & 0 deletions src/helm/benchmark/metrics/reka_vibe_critique_metrics.py
@@ -0,0 +1,158 @@
from typing import Dict, List, Optional
import re

from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.metrics.metric import MetricInterface, MetricResult, PerInstanceStats, add_context
from helm.benchmark.metrics.metric_name import MetricContext, MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat, merge_stat
from helm.common.critique_request import CritiqueTaskTemplate, CritiqueQuestionTemplate, CritiqueRequest, QuestionType
from helm.common.hierarchical_logger import hlog
from helm.common.request import RequestResult, GeneratedOutput
from helm.common.media_object import MultimediaObject, IMAGE_TYPE, TEXT_TYPE, MediaObject


class RekaVibeCritiqueMetric(MetricInterface):
"""
    Critique-based evaluation that uses the Reka Vibe-Eval auto-evaluator to rate the correctness of a
    generated response, given the input image and the gold reference.
"""

# We can add more evaluation aspects here
VIBE_EVAL_NAME: str = "reka_vibe"
REKA_VIBE_PROMPT_WITH_IMAGE: str = """\
[Question]
{{prompt}}
[Assistant Response]
{{generation}}
[Ground Truth Response]
{{reference}}
[System]
Rate whether the assistant response correctly matches the ground truth, in regards to the image above.
The rating should be 1-5, where 1 is incorrect and 5 is correct.
Your response should be in the format:
Short Explanation: (explanation in only one sentence)
Rating: (int)"""

def __init__(self, num_respondents: int, max_tokens: int):
self._num_respondents = num_respondents
self._max_tokens = max_tokens

def __repr__(self) -> str:
return "RekaVibeCritiqueMetric()"

    def _extract_score_from_reka_output(self, evaluator_response: str) -> Optional[int]:
"""
Extract the score from the evaluator response. Refer to the official Vibe-Eval implementation:
https://github.com/reka-ai/reka-vibe-eval/blob/3852d4712da172a7b85dddeffc4f9c3482a6f4c9/evaluate.py#L159-#L164
"""
re_match = re.search(r"Rating:\s*([1-5])", evaluator_response)
if re_match is None:
hlog(f"Error parsing answer: {evaluator_response}. Skipping question (and so the respondent entirely)")
return None
return int(re_match.group(1))

def evaluate(
self,
scenario_state: ScenarioState,
metric_service: MetricService,
eval_cache_path: str,
parallelism: int,
) -> MetricResult:
request_states: List[RequestState] = scenario_state.request_states

all_stats: Dict[MetricName, Stat] = {}
per_instance_stats: List[PerInstanceStats] = []
for request_state in request_states:
context = MetricContext.from_instance(request_state.instance)
stats_without_context = self.evaluate_generation(
scenario_state.adapter_spec,
request_state,
metric_service,
eval_cache_path,
)
stats = [add_context(stat_without_context, context) for stat_without_context in stats_without_context]
for stat in stats:
merge_stat(all_stats, stat)
assert request_state.instance.id is not None
per_instance_stats.append(
PerInstanceStats(
instance_id=request_state.instance.id,
perturbation=request_state.instance.perturbation,
train_trial_index=request_state.train_trial_index,
stats=stats,
)
)
return MetricResult(aggregated_stats=list(all_stats.values()), per_instance_stats=per_instance_stats)

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:
input_content = request_state.instance.input
        # Predicted outputs to be scored by the Vibe-Eval auto-evaluator
assert request_state.result is not None
request_result: RequestResult = request_state.result
        # Get the input image and generated response for the correctness evaluation
assert input_content.multimedia_content is not None
completions: List[GeneratedOutput] = request_result.completions
generated_text: str = completions[0].text
input_media: MultimediaObject = input_content.multimedia_content
ref_text: str = request_state.instance.references[0].output.text

image_objects: List[MediaObject] = [
item for item in input_media.media_objects if item.is_type(IMAGE_TYPE) and item.location
]
input_text: Optional[str] = [item for item in input_media.media_objects if item.is_type(TEXT_TYPE)][0].text

template = CritiqueTaskTemplate(
name="vhelm_vibe_eval",
instructions=self.REKA_VIBE_PROMPT_WITH_IMAGE,
num_respondents=self._num_respondents,
max_tokens=self._max_tokens,
questions=[
CritiqueQuestionTemplate(
name=self.VIBE_EVAL_NAME,
question_type=QuestionType.FREE_RESPONSE,
text="",
options=[],
media_object=image_objects[0], # we only take the first image as input
)
],
)

request = CritiqueRequest(
template=template,
fields={
"prompt": input_text if input_text is not None else "",
"generation": generated_text,
"reference": ref_text,
},
)

        # Send the critique request to the evaluator
result = metric_service.make_critique_request(request)
if not result or not result.responses:
# Skip computing metrics if there aren't any responses yet
hlog("Waiting for responses to be generated.")
return []
stats: Dict[str, Stat] = {}
for question in template.questions:
stats[question.name] = Stat(MetricName(question.name))

for response in result.responses:
for answer_name, answer in response.answers.items():
assert isinstance(answer, str)
                answer_value: Optional[int] = self._extract_score_from_reka_output(answer)
                if answer_value is None:
                    # Skip answers whose rating could not be parsed
                    continue
                stats[answer_name].add(answer_value)

return list(stats.values())
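For reference, a minimal sketch of how the `Rating:` parsing in `_extract_score_from_reka_output` behaves on a typical evaluator reply (the sample reply below is illustrative, not taken from the benchmark):

```python
import re

# Illustrative evaluator reply in the format requested by REKA_VIBE_PROMPT_WITH_IMAGE.
sample_reply = (
    "Short Explanation: The response identifies the landmark but misstates its height.\n"
    "Rating: 3"
)

# Same pattern as the metric: a single digit 1-5 following "Rating:".
match = re.search(r"Rating:\s*([1-5])", sample_reply)
score = int(match.group(1)) if match else None
print(score)  # -> 3
```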
4 changes: 4 additions & 0 deletions src/helm/benchmark/presentation/run_entries_vhelm.conf
@@ -89,6 +89,10 @@ entries: [
{description: "mme:subject=artwork,model=vlm", priority: 1}
{description: "mme:subject=landmark,model=vlm", priority: 1}

# Vibe-Eval
{description: "vibe_eval:subject=difficulty-normal,model=vlm,num_respondents=1", priority: 1}
{description: "vibe_eval:subject=difficulty-hard,model=vlm,num_respondents=1", priority: 1}

####################################################################################################################
# Originality: Does the model generate creative content (e.g., poetry, art)?
####################################################################################################################
33 changes: 33 additions & 0 deletions src/helm/benchmark/run_specs/vlm_run_specs.py
@@ -156,6 +156,18 @@ def get_gpt4v_critique_originality_metric_specs(num_respondents: int) -> List[Me
]


def get_vibe_eval_critique_metric_specs(num_respondents: int, max_tokens: int) -> List[MetricSpec]:
return [
MetricSpec(
class_name="helm.benchmark.metrics.reka_vibe_critique_metrics.RekaVibeCritiqueMetric",
args={
"num_respondents": num_respondents,
"max_tokens": max_tokens,
},
)
]


############################################################
# VHELM run specs

@@ -806,3 +818,24 @@ def get_mementos_spec(subject: str, num_respondents: int) -> RunSpec:
metric_specs=metric_specs,
groups=[run_spec_name],
)


@run_spec_function("vibe_eval")
def get_vibe_eval_spec(subject: str, num_respondents: int) -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.vision_language.vibe_eval_scenario.VibeEvalScenario",
args={"subject": subject},
)
adapter_spec: AdapterSpec = get_open_end_answer_generation_adapter_spec()
metric_specs: List[MetricSpec] = get_vibe_eval_critique_metric_specs(
num_respondents=num_respondents, max_tokens=200
)

run_spec_name: str = "vibe_eval"
return RunSpec(
name=run_spec_name,
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
groups=[run_spec_name],
)
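As a rough sketch of how the pieces fit together, the `vibe_eval:subject=...,num_respondents=1` entries added to run_entries_vhelm.conf resolve to the run spec function above; calling it directly (assuming the `@run_spec_function` decorator returns the undecorated callable) would look roughly like this:

```python
from helm.benchmark.run_specs.vlm_run_specs import get_vibe_eval_spec

# Mirrors the "vibe_eval:subject=difficulty-hard,model=vlm,num_respondents=1" run entry.
run_spec = get_vibe_eval_spec(subject="difficulty-hard", num_respondents=1)

print(run_spec.name)                      # "vibe_eval"
print(run_spec.scenario_spec.class_name)  # "...vibe_eval_scenario.VibeEvalScenario"
print(run_spec.metric_specs[0].args)      # {"num_respondents": 1, "max_tokens": 200}
```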
95 changes: 95 additions & 0 deletions src/helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py
@@ -0,0 +1,95 @@
import os.path
from typing import List

from datasets import load_dataset
from tqdm import tqdm

from helm.benchmark.scenarios.scenario import (
CORRECT_TAG,
TEST_SPLIT,
Instance,
Input,
Output,
Reference,
Scenario,
)
from helm.common.media_object import MediaObject, MultimediaObject
from helm.common.general import ensure_directory_exists


class VibeEvalScenario(Scenario):
"""
Vibe-Eval: A hard evaluation suite for measuring progress of multimodal language models
We introduce Vibe-Eval: a new open benchmark and framework for evaluating multimodal chat
models. Vibe-Eval consists of 269 visual understanding prompts, including 100 of hard
difficulty, complete with gold-standard responses authored by experts. Vibe-Eval is
open-ended and challenging with dual objectives: (i) vibe checking multimodal chat models
for day-to-day tasks and (ii) rigorously testing and probing the capabilities of present
frontier models. Notably, our hard set contains >50% questions that all frontier models
answer incorrectly. We also discuss trade-offs between human and automatic evaluation,
and show that automatic model evaluation using Reka Core roughly correlates to human judgment.
@article{padlewski2024vibe,
title={Vibe-Eval: A hard evaluation suite for measuring progress of multimodal language models},
author={Padlewski, Piotr and Bain, Max and Henderson, Matthew and Zhu, Zhongkai
and Relan, Nishant and Pham, Hai and Ong, Donovan and Aleksiev, Kaloyan and Ormazabal, Aitor
and Phua, Samuel and others},
journal={arXiv preprint arXiv:2405.02287},
year={2024}
}
    Paper: https://arxiv.org/abs/2405.02287
"""

VIBE_EVAL_HUGGINGFACE_DATASET_NAME: str = "RekaAI/VibeEval"

SUBJECTS: List[str] = [
"difficulty-hard",
"difficulty-normal",
]

name = "vibe_eval"
description = "Evaluate multimodal models on ([paper](https://arxiv.org/abs/2405.02287))."
tags = ["vision-language"]

def __init__(self, subject: str):
super().__init__()
assert subject in self.SUBJECTS, f"Invalid subject: {subject}"
self._subject: str = subject

def get_instances(self, output_path: str) -> List[Instance]:
images_path: str = os.path.join(output_path, "images")
ensure_directory_exists(images_path)

instances: List[Instance] = []
# Process the test set
for row in tqdm(
load_dataset(
self.VIBE_EVAL_HUGGINGFACE_DATASET_NAME,
split=TEST_SPLIT,
cache_dir=output_path,
)
):
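            # Keep only examples in the requested difficulty subset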
if row["category"] != self._subject:
continue
example_id: str = row["example_id"].replace("/", "-")
# Save the image locally
local_image_path: str = os.path.join(images_path, f"{example_id}.png")
if not os.path.exists(local_image_path):
row["image"].convert("RGB").save(local_image_path, "PNG", optimize=True)

content: List[MediaObject] = [
MediaObject(location=local_image_path, content_type="image/png"),
MediaObject(text=row["prompt"], content_type="text/plain"),
]
answer: str = row["reference"]
instances.append(
Instance(
Input(multimedia_content=MultimediaObject(content)),
references=[Reference(Output(text=answer), tags=[CORRECT_TAG])],
split=TEST_SPLIT,
)
)

return instances
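A minimal sketch of exercising the new scenario on its own (assumes network access to the Hugging Face Hub; the output path is illustrative):

```python
from helm.benchmark.scenarios.vision_language.vibe_eval_scenario import VibeEvalScenario

# Downloads RekaAI/VibeEval, saves images under <output_path>/images,
# and keeps only the requested difficulty subset.
scenario = VibeEvalScenario(subject="difficulty-normal")
instances = scenario.get_instances(output_path="/tmp/vibe_eval")  # illustrative scratch path

first = instances[0]
print(len(instances))                                                # number of difficulty-normal prompts
print(first.input.multimedia_content.media_objects[0].content_type)  # "image/png"
print(first.references[0].output.text[:80])                          # start of the gold reference
```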