diff --git a/requirements.txt b/requirements.txt
index e2e12ca3e9..570cc28554 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -238,6 +238,7 @@ PyWavelets==1.4.1
 PyYAML==6.0.1
 referencing==0.35.1
 regex==2024.5.10
+reka-api==2.0.0
 requests==2.31.0
 requests-oauthlib==2.0.0
 retrying==1.3.4
diff --git a/setup.cfg b/setup.cfg
index e030d14b7b..2235370aad 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -161,11 +161,15 @@ models =
     crfm-helm[google]
     crfm-helm[mistral]
     crfm-helm[openai]
+    crfm-helm[reka]
     crfm-helm[together]
     crfm-helm[tsinghua]
     crfm-helm[yandex]
     crfm-helm[openvino]
 
+reka =
+    reka-api~=2.0.0
+
 vlm =
     crfm-helm[openai]
@@ -182,6 +186,9 @@ vlm =
     scipy~=1.10
     torchvision>=0.14.1,<3.0.0
 
+    # For Reka AI
+    crfm-helm[reka]
+
     # VLM scenarios
     crfm-helm[images]
     crfm-helm[image2structure]
diff --git a/src/helm/benchmark/metrics/reka_vibe_critique_metrics.py b/src/helm/benchmark/metrics/reka_vibe_critique_metrics.py
new file mode 100644
index 0000000000..faf4cf72a5
--- /dev/null
+++ b/src/helm/benchmark/metrics/reka_vibe_critique_metrics.py
@@ -0,0 +1,158 @@
+from typing import Dict, List, Optional
+import re
+
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.adaptation.scenario_state import ScenarioState
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.metrics.metric import MetricInterface, MetricResult, PerInstanceStats, add_context
+from helm.benchmark.metrics.metric_name import MetricContext, MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat, merge_stat
+from helm.common.critique_request import CritiqueTaskTemplate, CritiqueQuestionTemplate, CritiqueRequest, QuestionType
+from helm.common.hierarchical_logger import hlog
+from helm.common.request import RequestResult, GeneratedOutput
+from helm.common.media_object import MultimediaObject, IMAGE_TYPE, TEXT_TYPE, MediaObject
+
+
+class RekaVibeCritiqueMetric(MetricInterface):
+    """
+    Critique evaluation that judges whether a generated response is correct, given the input image
+    and the reference answer, following the Reka Vibe-Eval protocol.
+    """
+
+    # We can add more evaluation aspects here
+    VIBE_EVAL_NAME: str = "reka_vibe"
+    REKA_VIBE_PROMPT_WITH_IMAGE: str = """\
+[Question]
+{{prompt}}
+
+[Assistant Response]
+{{generation}}
+
+[Ground Truth Response]
+{{reference}}
+
+[System]
+Rate whether the assistant response correctly matches the ground truth, in regards to the image above.
+The rating should be 1-5, where 1 is incorrect and 5 is correct.
+Your response should be in the format:
+Short Explanation: (explanation in only one sentence)
+Rating: (int)"""
+
+    def __init__(self, num_respondents: int, max_tokens: int):
+        self._num_respondents = num_respondents
+        self._max_tokens = max_tokens
+
+    def __repr__(self) -> str:
+        return "RekaVibeCritiqueMetric()"
+
+    def _extract_score_from_reka_output(self, evaluator_response: str) -> Optional[int]:
+        """
+        Extract the score from the evaluator response. Refer to the official Vibe-Eval implementation:
+        https://github.com/reka-ai/reka-vibe-eval/blob/3852d4712da172a7b85dddeffc4f9c3482a6f4c9/evaluate.py#L159-#L164
+        """
+        re_match = re.search(r"Rating:\s*([1-5])", evaluator_response)
+        if re_match is None:
+            hlog(f"Error parsing evaluator answer: {evaluator_response}. Skipping this response.")
+            return None
+        return int(re_match.group(1))
+
+    def evaluate(
+        self,
+        scenario_state: ScenarioState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+        parallelism: int,
+    ) -> MetricResult:
+        request_states: List[RequestState] = scenario_state.request_states
+
+        all_stats: Dict[MetricName, Stat] = {}
+        per_instance_stats: List[PerInstanceStats] = []
+        for request_state in request_states:
+            context = MetricContext.from_instance(request_state.instance)
+            stats_without_context = self.evaluate_generation(
+                scenario_state.adapter_spec,
+                request_state,
+                metric_service,
+                eval_cache_path,
+            )
+            stats = [add_context(stat_without_context, context) for stat_without_context in stats_without_context]
+            for stat in stats:
+                merge_stat(all_stats, stat)
+            assert request_state.instance.id is not None
+            per_instance_stats.append(
+                PerInstanceStats(
+                    instance_id=request_state.instance.id,
+                    perturbation=request_state.instance.perturbation,
+                    train_trial_index=request_state.train_trial_index,
+                    stats=stats,
+                )
+            )
+        return MetricResult(aggregated_stats=list(all_stats.values()), per_instance_stats=per_instance_stats)
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        input_content = request_state.instance.input
+        # The predicted output that will be judged for correctness
+        assert request_state.result is not None
+        request_result: RequestResult = request_state.result
+        # Get the input image and generated response for the correctness evaluation
+        assert input_content.multimedia_content is not None
+        completions: List[GeneratedOutput] = request_result.completions
+        generated_text: str = completions[0].text
+        input_media: MultimediaObject = input_content.multimedia_content
+        ref_text: str = request_state.instance.references[0].output.text
+
+        image_objects: List[MediaObject] = [
+            item for item in input_media.media_objects if item.is_type(IMAGE_TYPE) and item.location
+        ]
+        input_text: Optional[str] = [item for item in input_media.media_objects if item.is_type(TEXT_TYPE)][0].text
+
+        template = CritiqueTaskTemplate(
+            name="vhelm_vibe_eval",
+            instructions=self.REKA_VIBE_PROMPT_WITH_IMAGE,
+            num_respondents=self._num_respondents,
+            max_tokens=self._max_tokens,
+            questions=[
+                CritiqueQuestionTemplate(
+                    name=self.VIBE_EVAL_NAME,
+                    question_type=QuestionType.FREE_RESPONSE,
+                    text="",
+                    options=[],
+                    media_object=image_objects[0],  # we only take the first image as input
+                )
+            ],
+        )
+
+        request = CritiqueRequest(
+            template=template,
+            fields={
+                "prompt": input_text if input_text is not None else "",
+                "generation": generated_text,
+                "reference": ref_text,
+            },
+        )
+
+        # Send the critique request to the evaluator
+        result = metric_service.make_critique_request(request)
+        if not result or not result.responses:
+            # Skip computing metrics if there aren't any responses yet
+            hlog("Waiting for responses to be generated.")
+            return []
+        stats: Dict[str, Stat] = {}
+        for question in template.questions:
+            stats[question.name] = Stat(MetricName(question.name))
+
+        for response in result.responses:
+            for answer_name, answer in response.answers.items():
+                assert isinstance(answer, str)
+                answer_value: Optional[int] = self._extract_score_from_reka_output(answer)
+                if answer_value is not None:
+                    stats[answer_name].add(answer_value)
+
+        return list(stats.values())
diff --git a/src/helm/benchmark/presentation/run_entries_vhelm.conf b/src/helm/benchmark/presentation/run_entries_vhelm.conf
index 609fa0f1ad..130300e15f 100644
--- a/src/helm/benchmark/presentation/run_entries_vhelm.conf
+++ b/src/helm/benchmark/presentation/run_entries_vhelm.conf
@@ -89,6 +89,10 @@ entries: [
   {description: "mme:subject=artwork,model=vlm", priority: 1}
   {description: "mme:subject=landmark,model=vlm", priority: 1}
 
+  # Vibe-Eval
+  {description: "vibe_eval:subject=difficulty-normal,model=vlm,num_respondents=1", priority: 1}
+  {description: "vibe_eval:subject=difficulty-hard,model=vlm,num_respondents=1", priority: 1}
+
   ####################################################################################################################
   # Originality: Does the model generate creative content (e.g., poetry, art)?
   ####################################################################################################################
diff --git a/src/helm/benchmark/run_specs/vlm_run_specs.py b/src/helm/benchmark/run_specs/vlm_run_specs.py
index bac7325e34..11a1da8da0 100644
--- a/src/helm/benchmark/run_specs/vlm_run_specs.py
+++ b/src/helm/benchmark/run_specs/vlm_run_specs.py
@@ -156,6 +156,18 @@ def get_gpt4v_critique_originality_metric_specs(num_respondents: int) -> List[Me
     ]
 
 
+def get_vibe_eval_critique_metric_specs(num_respondents: int, max_tokens: int) -> List[MetricSpec]:
+    return [
+        MetricSpec(
+            class_name="helm.benchmark.metrics.reka_vibe_critique_metrics.RekaVibeCritiqueMetric",
+            args={
+                "num_respondents": num_respondents,
+                "max_tokens": max_tokens,
+            },
+        )
+    ]
+
+
 ############################################################
 # VHELM run specs
 
@@ -806,3 +818,24 @@ def get_mementos_spec(subject: str, num_respondents: int) -> RunSpec:
         metric_specs=metric_specs,
         groups=[run_spec_name],
     )
+
+
+@run_spec_function("vibe_eval")
+def get_vibe_eval_spec(subject: str, num_respondents: int) -> RunSpec:
+    scenario_spec = ScenarioSpec(
+        class_name="helm.benchmark.scenarios.vision_language.vibe_eval_scenario.VibeEvalScenario",
+        args={"subject": subject},
+    )
+    adapter_spec: AdapterSpec = get_open_end_answer_generation_adapter_spec()
+    metric_specs: List[MetricSpec] = get_vibe_eval_critique_metric_specs(
+        num_respondents=num_respondents, max_tokens=200
+    )
+
+    run_spec_name: str = "vibe_eval"
+    return RunSpec(
+        name=run_spec_name,
+        scenario_spec=scenario_spec,
+        adapter_spec=adapter_spec,
+        metric_specs=metric_specs,
+        groups=[run_spec_name],
+    )
diff --git a/src/helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py b/src/helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py
new file mode 100644
index 0000000000..7fbcba068b
--- /dev/null
+++ b/src/helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py
@@ -0,0 +1,95 @@
+import os.path
+from typing import List
+
+from datasets import load_dataset
+from tqdm import tqdm
+
+from helm.benchmark.scenarios.scenario import (
+    CORRECT_TAG,
+    TEST_SPLIT,
+    Instance,
+    Input,
+    Output,
+    Reference,
+    Scenario,
+)
+from helm.common.media_object import MediaObject, MultimediaObject
+from helm.common.general import ensure_directory_exists
+
+
+class VibeEvalScenario(Scenario):
+    """
+    Vibe-Eval: A hard evaluation suite for measuring progress of multimodal language models
+
+    We introduce Vibe-Eval: a new open benchmark and framework for evaluating multimodal chat
+    models. Vibe-Eval consists of 269 visual understanding prompts, including 100 of hard
+    difficulty, complete with gold-standard responses authored by experts. Vibe-Eval is
+    open-ended and challenging with dual objectives: (i) vibe checking multimodal chat models
+    for day-to-day tasks and (ii) rigorously testing and probing the capabilities of present
+    frontier models. Notably, our hard set contains >50% questions that all frontier models
+    answer incorrectly. We also discuss trade-offs between human and automatic evaluation,
+    and show that automatic model evaluation using Reka Core roughly correlates to human judgment.
+
+    @article{padlewski2024vibe,
+        title={Vibe-Eval: A hard evaluation suite for measuring progress of multimodal language models},
+        author={Padlewski, Piotr and Bain, Max and Henderson, Matthew and Zhu, Zhongkai
+            and Relan, Nishant and Pham, Hai and Ong, Donovan and Aleksiev, Kaloyan and Ormazabal, Aitor
+            and Phua, Samuel and others},
+        journal={arXiv preprint arXiv:2405.02287},
+        year={2024}
+    }
+
+    Paper: https://arxiv.org/abs/2405.02287
+    """
+
+    VIBE_EVAL_HUGGINGFACE_DATASET_NAME: str = "RekaAI/VibeEval"
+
+    SUBJECTS: List[str] = [
+        "difficulty-hard",
+        "difficulty-normal",
+    ]
+
+    name = "vibe_eval"
+    description = "Evaluate multimodal models on Vibe-Eval ([paper](https://arxiv.org/abs/2405.02287))."
+    tags = ["vision-language"]
+
+    def __init__(self, subject: str):
+        super().__init__()
+        assert subject in self.SUBJECTS, f"Invalid subject: {subject}"
+        self._subject: str = subject
+
+    def get_instances(self, output_path: str) -> List[Instance]:
+        images_path: str = os.path.join(output_path, "images")
+        ensure_directory_exists(images_path)
+
+        instances: List[Instance] = []
+        # Process the test set
+        for row in tqdm(
+            load_dataset(
+                self.VIBE_EVAL_HUGGINGFACE_DATASET_NAME,
+                split=TEST_SPLIT,
+                cache_dir=output_path,
+            )
+        ):
+            if row["category"] != self._subject:
+                continue
+            example_id: str = row["example_id"].replace("/", "-")
+            # Save the image locally
+            local_image_path: str = os.path.join(images_path, f"{example_id}.png")
+            if not os.path.exists(local_image_path):
+                row["image"].convert("RGB").save(local_image_path, "PNG", optimize=True)
+
+            content: List[MediaObject] = [
+                MediaObject(location=local_image_path, content_type="image/png"),
+                MediaObject(text=row["prompt"], content_type="text/plain"),
+            ]
+            answer: str = row["reference"]
+            instances.append(
+                Instance(
+                    Input(multimedia_content=MultimediaObject(content)),
+                    references=[Reference(Output(text=answer), tags=[CORRECT_TAG])],
+                    split=TEST_SPLIT,
+                )
+            )
+
+        return instances
diff --git a/src/helm/clients/reka_client.py b/src/helm/clients/reka_client.py
new file mode 100644
index 0000000000..1eab894a84
--- /dev/null
+++ b/src/helm/clients/reka_client.py
@@ -0,0 +1,189 @@
+# mypy: check_untyped_defs = False
+import requests
+from typing import Any, Dict, List, Optional, TypedDict
+
+from helm.proxy.retry import NonRetriableException
+from helm.common.cache import CacheConfig
+from helm.common.media_object import TEXT_TYPE
+from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput
+from helm.common.hierarchical_logger import hlog
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.tokenizers.tokenizer import Tokenizer
+from .client import CachingClient, truncate_and_tokenize_response_text
+
+try:
+    import reka
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["reka-api"])
+
+
+class RekaAIRequest(TypedDict):
+    """Data passed between make_request and _send_request. Used as the cache key."""
+
+    model_name: str
+    conversation_history: List[Dict[str, str]]
+    request_output_len: int
+    temperature: float
+    runtime_top_p: float
+    random_seed: Optional[int]
+    stop_words: Optional[List[str]]
+    presence_penalty: float
+    frequency_penalty: float
+
+
+class RekaClient(CachingClient):
+    REKA_CHAT_ROLE_MAPPING: Dict[str, str] = {
+        "user": "human",
+        "assistant": "model",
+    }
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        cache_config: CacheConfig,
+        api_key: Optional[str] = None,
+    ):
+        super().__init__(cache_config=cache_config)
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self.client = reka
+        self.client.API_KEY = api_key
+
+    def _is_reka_model_engine(self, model_engine: str) -> bool:
+        if (
+            model_engine.startswith("reka-edge")
+            or model_engine.startswith("reka-flash")
+            or model_engine.startswith("reka-core")
+        ):
+            return True
+        else:
+            return False
+
+    def _get_model_for_request(self, request: Request) -> str:
+        return request.model_engine
+
+    def _get_random_seed(self, request: Request, completion_index: int) -> Optional[int]:
+        if request.random is None and completion_index == 0:
+            return None
+
+        # Treat the user's request.random as an integer for the random seed.
+        try:
+            request_random_seed = int(request.random) if request.random is not None else 0
+        except ValueError:
+            raise NonRetriableException("RekaAIClient only supports integer values for request.random")
+
+        # A large prime is used so that the resulting values are unlikely to collide
+        # with request.random values chosen by the user.
+        fixed_large_prime = 1911011
+        completion_index_random_seed = completion_index * fixed_large_prime
+
+        return request_random_seed + completion_index_random_seed
+
+    def _convert_messages_to_reka_chat_history(self, messages: List[Dict[str, Any]]):
+        chat_history = []
+        num_images: int = 0
+        for chat_turn, message in enumerate(messages):
+            role = message["role"]
+            content = message["content"]
+            current_chat_history: Dict[str, Any] = {
+                "type": self.REKA_CHAT_ROLE_MAPPING[role],
+                "text": "",  # text placeholder
+                "media_url": None,
+            }
+            for item in content:
+                if item["type"] == "image_url":
+                    if chat_turn == 0 and num_images == 0:
+                        current_chat_history["media_url"] = item["image_url"]["url"]
+                        num_images += 1
+                    else:
+                        raise ValueError(
+                            f"Only the first message can contain one image. Found image input "
+                            f"in message {chat_turn + 1}"
+                        )
+                elif item["type"] == "text":
+                    current_chat_history["text"] = item["text"]
+                else:
+                    raise ValueError(f"Unrecognized message type {item['type']}")
+            chat_history.append(current_chat_history)
+        return chat_history
+
+    def make_request(self, request: Request) -> RequestResult:
+        completions: List[GeneratedOutput] = []
+        messages: Optional[List[Dict[str, Any]]] = request.messages
+        reka_chat_history: List[Dict[str, Any]]
+        if messages is not None:
+            # Checks that all messages have a role and some content
+            for message in messages:
+                if not message.get("role") or not message.get("content"):
+                    raise ValueError("All messages must have a role and content")
+            # Checks that the last role is "user"
+            if messages[-1]["role"] != "user":
+                raise ValueError("Last message must have role 'user'")
+            if request.prompt != "":
+                hlog("WARNING: Since message is set, prompt will be ignored")
+            reka_chat_history = self._convert_messages_to_reka_chat_history(messages)
+        else:
+            current_chat_history: Dict[str, Any] = {
+                "type": "human",
+                "text": "",
+                "media_url": None,
+            }
+            if request.multimodal_prompt is not None:
+                for media_object in request.multimodal_prompt.media_objects:
+                    if media_object.is_type("image") and media_object.location:
+                        from helm.common.images_utils import encode_base64
+
+                        base64_image: str = encode_base64(media_object.location)
+                        current_chat_history["media_url"] = f"data:image/jpeg;base64,{base64_image}"
+                    elif media_object.is_type(TEXT_TYPE):
+                        if media_object.text is None:
+                            raise ValueError("MediaObject of text type has missing text field value")
+                        current_chat_history["text"] = media_object.text
+                    else:
+                        raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
+
+            else:
+                current_chat_history["text"] = request.prompt
+            reka_chat_history = [current_chat_history]
+
+        # `num_completions` is not supported, so instead make `num_completions` separate requests.
+        for completion_index in range(request.num_completions):
+            try:
+                raw_request: RekaAIRequest = {
+                    "model_name": self._get_model_for_request(request),
+                    "conversation_history": reka_chat_history,  # we only use chat_history as the input
+                    "request_output_len": request.max_tokens,
+                    "temperature": request.temperature,
+                    "random_seed": self._get_random_seed(request, completion_index),
+                    "stop_words": request.stop_sequences or None,  # API doesn't like empty list
+                    "runtime_top_p": request.top_p,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
+                }
+
+                def do_it() -> Dict[str, Any]:
+                    return self.client.chat(**raw_request)
+
+                response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
+            except (requests.exceptions.RequestException, AssertionError) as e:
+                error: str = f"RekaClient error: {e}"
+                return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+            response_message: Dict[str, Any] = response
+            assert response_message["type"] == "model"
+            response_text: str = response_message["text"]
+
+            # The Reka API doesn't support echo. If `echo_prompt` is true, combine the prompt and completion.
+            text: str = request.prompt + response_text if request.echo_prompt else response_text
+            completion = truncate_and_tokenize_response_text(text, request, self.tokenizer, self.tokenizer_name)
+            completions.append(completion)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response.get("request_datetime"),
+            completions=completions,
+            embedding=[],
+        )
diff --git a/src/helm/common/critique_request.py b/src/helm/common/critique_request.py
index 677718149b..cecfa4bac6 100644
--- a/src/helm/common/critique_request.py
+++ b/src/helm/common/critique_request.py
@@ -59,6 +59,9 @@ class CritiqueTaskTemplate:
     questions: List[CritiqueQuestionTemplate]
     """List of templates for the questions."""
 
+    max_tokens: Optional[int] = None
+    """Maximum number of tokens to generate for a free-form response."""
+
 
 @dataclass(frozen=True)
 class CritiqueRequest:
diff --git a/src/helm/config/model_deployments.yaml b/src/helm/config/model_deployments.yaml
index fb3a728a42..01eee1a4e5 100644
--- a/src/helm/config/model_deployments.yaml
+++ b/src/helm/config/model_deployments.yaml
@@ -2161,3 +2161,53 @@ model_deployments:
     max_sequence_length: 8191
     client_spec:
       class_name: "helm.clients.vision_language.qwen_vlm_client.QwenVLMClient"
+
+# Reka
+  - name: reka/reka-core
+    model_name: reka/reka-core
+    tokenizer_name: openai/cl100k_base
+    max_sequence_length: 128000
+    client_spec:
+      class_name: "helm.clients.reka_client.RekaClient"
+
+  - name: reka/reka-core-20240415
+    model_name: reka/reka-core-20240415
+    tokenizer_name: openai/cl100k_base
+    max_sequence_length: 128000
+    client_spec:
+      class_name: "helm.clients.reka_client.RekaClient"
+
+  - name: reka/reka-core-20240501
+    model_name: reka/reka-core-20240501
+    tokenizer_name: openai/cl100k_base
+    max_sequence_length: 128000
+    client_spec:
+      class_name: "helm.clients.reka_client.RekaClient"
+
+  - name: reka/reka-flash
+    model_name: reka/reka-flash
+    tokenizer_name: openai/cl100k_base
+    max_sequence_length: 128000
+    client_spec:
+      class_name: "helm.clients.reka_client.RekaClient"
+
+  - name: reka/reka-flash-20240226
+    model_name: reka/reka-flash-20240226
+    tokenizer_name: openai/cl100k_base
+    max_sequence_length: 128000
+    client_spec:
+      class_name: "helm.clients.reka_client.RekaClient"
+
+  - name: reka/reka-edge
+    model_name: reka/reka-edge
+    tokenizer_name: openai/cl100k_base
+    max_sequence_length: 64000
+    client_spec:
+      class_name: "helm.clients.reka_client.RekaClient"
+
+  - name: reka/reka-edge-20240208
+    model_name: reka/reka-edge-20240208
+    tokenizer_name: openai/cl100k_base
+    max_sequence_length: 64000
+    client_spec:
+      class_name: "helm.clients.reka_client.RekaClient"
\ No newline at end of file
diff --git a/src/helm/config/model_metadata.yaml b/src/helm/config/model_metadata.yaml
index 4afe6ab0ca..1199275a0f 100644
--- a/src/helm/config/model_metadata.yaml
+++ b/src/helm/config/model_metadata.yaml
@@ -2434,3 +2434,65 @@ models:
     num_parameters: 100000000000
     release_date: 2022-06-23
     tags: [TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG, ABLATION_MODEL_TAG]
+
+  # Reka
+  - name: reka/reka-core
+    display_name: Reka-Core
+    description: Reka-Core
+    creator_organization_name: Reka AI
+    access: limited
+    release_date: 2024-04-18
+    tags: [VISION_LANGUAGE_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]
+
+  - name: reka/reka-core-20240415
+    display_name: Reka-Core-20240415
+    description: Reka-Core-20240415
+    creator_organization_name: Reka AI
+    access: limited
+    release_date: 2024-04-18
+    tags: [VISION_LANGUAGE_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]
+
+  - name: reka/reka-core-20240501
+    display_name: Reka-Core-20240501
+    description: Reka-Core-20240501
+    creator_organization_name: Reka AI
+    access: limited
+    release_date: 2024-05-01
+    tags: [VISION_LANGUAGE_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]
+
+  - name: reka/reka-flash
+    display_name: Reka-Flash (21B)
+    description: Reka-Flash (21B)
+    creator_organization_name: Reka AI
+    access: limited
+    num_parameters: 21000000000
+    release_date: 2024-04-18
+    tags: [VISION_LANGUAGE_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]
+
+  - name: reka/reka-flash-20240226
+    display_name: Reka-Flash-20240226 (21B)
+    description: Reka-Flash-20240226 (21B)
+    creator_organization_name: Reka AI
+    access: limited
+    num_parameters: 21000000000
+    release_date: 2024-04-18
+    tags: [VISION_LANGUAGE_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]
+
+  - name: reka/reka-edge
+    display_name: Reka-Edge (7B)
+    description: Reka-Edge (7B)
+    creator_organization_name: Reka AI
+    access: limited
+    num_parameters: 7000000000
+    release_date: 2024-04-18
+    tags: [VISION_LANGUAGE_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]
+
+  - name: reka/reka-edge-20240208
+    display_name: Reka-Edge-20240208 (7B)
+    description: Reka-Edge-20240208 (7B)
+    creator_organization_name: Reka AI
+    access: limited
+    num_parameters: 7000000000
+    release_date: 2024-04-18
+    tags: [VISION_LANGUAGE_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]
+
diff --git a/src/helm/proxy/critique/model_critique_client.py b/src/helm/proxy/critique/model_critique_client.py
index 56cf5618de..a5325d5037 100644
--- a/src/helm/proxy/critique/model_critique_client.py
+++ b/src/helm/proxy/critique/model_critique_client.py
@@ -25,6 +25,8 @@ class CritiqueParseError(Exception):
 class ModelCritiqueClient(CritiqueClient):
     """A CritiqueClient that queries a Model to answer CritiqueRequests."""
 
+    VISION_LANGUAGE_MODELS = ["openai/gpt-4-vision", "reka/reka"]
+
     def __init__(self, client: Client, model_name):
         self._client = client
         self._model_name = model_name
@@ -32,7 +34,11 @@ def __init__(self, client: Client, model_name):
             get_default_model_deployment_for_model(model_name, warn_arg_deprecated=False, ignore_deprecated=True)
             or self._model_name
         )
-        self.vision_language = True if model_name.startswith("openai/gpt-4-vision") else False
+        self.vision_language = False
+        for vision_language_model_name in self.VISION_LANGUAGE_MODELS:
+            if model_name.startswith(vision_language_model_name):
+                self.vision_language = True
+                break
 
     def _interpolate_fields(self, text: str, fields: Dict[str, str]) -> str:
         for key, value in fields.items():
@@ -60,10 +66,15 @@ def _task_to_requests(self, task: CritiqueTaskTemplate, fields: Dict[str, str])
         requests: List[Request] = []
         for question in task.questions:
-            prompt: str = base_prompt + "\n\n" + self._question_to_prompt(question, fields)
+            prompt: str
+            if len(question.text) > 0:
+                prompt = base_prompt + "\n\n" + self._question_to_prompt(question, fields)
+            else:
+                # We don't want to add extra newlines or a question prompt
+                # when the question text is empty (e.g., for the Vibe-Eval evaluator).
+                prompt = base_prompt
 
             if question.question_type == "free_response":
-                # TODO: Make max_tokens configurable
-                max_tokens = 100
+                max_tokens = 100 if task.max_tokens is None else task.max_tokens
             elif question.question_type == "checkbox":
                 # We multiply by 2 because the model will generate a comma after each option.
                 max_tokens = len(question.options) * 2
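As a quick, hedged illustration (not part of the diff): the sketch below shows the response format that `REKA_VIBE_PROMPT_WITH_IMAGE` asks the evaluator model to produce and how `_extract_score_from_reka_output` pulls the 1-5 rating out of it. The sample response string is invented for illustration; the regex is the one used by the new metric.

```python
import re

# Hypothetical evaluator output following the "Short Explanation: ... / Rating: (int)"
# format requested by REKA_VIBE_PROMPT_WITH_IMAGE.
sample_evaluator_response = (
    "Short Explanation: The assistant response matches the ground truth shown in the image.\n"
    "Rating: 4"
)

# Same regex as RekaVibeCritiqueMetric._extract_score_from_reka_output.
match = re.search(r"Rating:\s*([1-5])", sample_evaluator_response)
score = int(match.group(1)) if match else None
print(score)  # -> 4; a response without a parsable "Rating:" line yields None and is skipped
```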