diff --git a/setup.cfg b/setup.cfg
index 64e998cc6a..e030d14b7b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -129,6 +129,9 @@ anthropic =
     anthropic~=0.17
     websocket-client~=1.3.2  # For legacy stanford-online-all-v4-s3
 
+cohere =
+    cohere~=5.3
+
 mistral =
     mistralai~=0.0.11
 
@@ -154,6 +157,7 @@ models =
     crfm-helm[allenai]
     crfm-helm[amazon]
     crfm-helm[anthropic]
+    crfm-helm[cohere]
     crfm-helm[google]
     crfm-helm[mistral]
     crfm-helm[openai]
diff --git a/src/helm/clients/cohere_client.py b/src/helm/clients/cohere_client.py
index 555e5e12d3..06c8b939b4 100644
--- a/src/helm/clients/cohere_client.py
+++ b/src/helm/clients/cohere_client.py
@@ -1,8 +1,9 @@
 import json
 import requests
-from typing import List
+from typing import List, Optional, Sequence, TypedDict
 
 from helm.common.cache import CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
     wrap_request_time,
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
@@ -11,8 +12,13 @@
     GeneratedOutput,
     Token,
 )
-from .client import CachingClient, truncate_sequence
-from .cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+from helm.clients.client import CachingClient, truncate_sequence
+from helm.clients.cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+
+try:
+    import cohere
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["cohere"])
 
 
 class CohereClient(CachingClient):
@@ -152,3 +158,92 @@ def do_it():
             completions=completions,
             embedding=[],
         )
+
+
+class CohereRawChatRequest(TypedDict):
+    message: str
+    model: Optional[str]
+    preamble: Optional[str]
+    chat_history: Optional[Sequence[cohere.ChatMessage]]
+    temperature: Optional[float]
+    max_tokens: Optional[int]
+    k: Optional[int]
+    p: Optional[float]
+    seed: Optional[float]
+    stop_sequences: Optional[Sequence[str]]
+    frequency_penalty: Optional[float]
+    presence_penalty: Optional[float]
+
+
+def convert_to_raw_chat_request(request: Request) -> CohereRawChatRequest:
+    # TODO: Support chat
+    model = request.model.replace("cohere/", "")
+    return {
+        "message": request.prompt,
+        "model": model,
+        "preamble": None,
+        "chat_history": None,
+        "temperature": request.temperature,
+        "max_tokens": request.max_tokens,
+        "k": request.top_k_per_token,
+        "p": request.top_p,
+        "stop_sequences": request.stop_sequences,
+        "seed": float(request.random) if request.random is not None else None,
+        "frequency_penalty": request.frequency_penalty,
+        "presence_penalty": request.presence_penalty,
+    }
+
+
+class CohereChatClient(CachingClient):
+    """
+    Leverages the chat endpoint: https://docs.cohere.com/reference/chat
+
+    Cohere models will soon support only the Chat API:
+    https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat
+    """
+
+    def __init__(self, api_key: str, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+        self.client = cohere.Client(api_key=api_key)
+
+    def make_request(self, request: Request) -> RequestResult:
+        if request.embedding:
+            return EMBEDDING_UNAVAILABLE_REQUEST_RESULT
+        # TODO: Support multiple completions
+        assert request.num_completions == 1, "CohereChatClient only supports num_completions=1"
+        # TODO: Support messages
+        assert not request.messages, "CohereChatClient currently does not support the messages API"
+
+        raw_request: CohereRawChatRequest = convert_to_raw_chat_request(request)
+
+        try:
+
+            def do_it():
+                """
+                Send the request to the Cohere Chat API. Responses will be structured like this:
+                cohere.Chat {
+                    message: What's up?
+                    text: Hey there! How's it going? I'm doing well, thank you for asking 😊.
+                    ...
+                }
+                """
+                raw_response = self.client.chat(**raw_request).dict()
+                assert "text" in raw_response, f"Response does not contain text: {raw_response}"
+                return raw_response
+
+            response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
+        except (requests.exceptions.RequestException, AssertionError) as e:
+            error: str = f"CohereChatClient error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = []
+        completion: GeneratedOutput = GeneratedOutput(text=response["text"], logprob=0.0, tokens=[])
+        completions.append(completion)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
diff --git a/src/helm/config/model_deployments.yaml b/src/helm/config/model_deployments.yaml
index 8a07b7b90a..73eb3b0d89 100644
--- a/src/helm/config/model_deployments.yaml
+++ b/src/helm/config/model_deployments.yaml
@@ -325,6 +325,25 @@ model_deployments:
     window_service_spec:
       class_name: "helm.benchmark.window_services.cohere_window_service.CohereWindowService"
 
+  - name: cohere/command-r
+    model_name: cohere/command-r
+    tokenizer_name: cohere/c4ai-command-r-v01
+    max_sequence_length: 128000
+    max_request_length: 128000
+    client_spec:
+      class_name: "helm.clients.cohere_client.CohereChatClient"
+
+  - name: cohere/command-r-plus
+    model_name: cohere/command-r-plus
+    tokenizer_name: cohere/c4ai-command-r-plus
+    # "We have a known issue where prompts between 112K - 128K in length
+    # result in bad generations."
+    # Source: https://docs.cohere.com/docs/command-r-plus
+    max_sequence_length: 110000
+    max_request_length: 110000
+    client_spec:
+      class_name: "helm.clients.cohere_client.CohereChatClient"
+
   # Craiyon
   - name: craiyon/dalle-mini
diff --git a/src/helm/config/model_metadata.yaml b/src/helm/config/model_metadata.yaml
index 06c9ada9d4..d8d4551cde 100644
--- a/src/helm/config/model_metadata.yaml
+++ b/src/helm/config/model_metadata.yaml
@@ -468,7 +468,25 @@ models:
     creator_organization_name: Cohere
     access: limited
     release_date: 2023-09-29
-    tags: [TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
+    tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
+
+  - name: cohere/command-r
+    display_name: Cohere Command R
+    description: Command R is a multilingual 35B-parameter model with a 128K-token context length that has been trained with conversational tool use capabilities.
+    creator_organization_name: Cohere
+    access: open
+    num_parameters: 35000000000
+    release_date: 2024-03-11
+    tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
+
+  - name: cohere/command-r-plus
+    display_name: Cohere Command R Plus
+    description: Command R+ is a multilingual 104B-parameter model with a 128K-token context length that has been trained with conversational tool use capabilities.
+    creator_organization_name: Cohere
+    access: open
+    num_parameters: 104000000000
+    release_date: 2024-04-04
+    tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
 
   # Craiyon
   - name: craiyon/dalle-mini
diff --git a/src/helm/config/tokenizer_configs.yaml b/src/helm/config/tokenizer_configs.yaml
index 430dc0aaf0..51780578fe 100644
--- a/src/helm/config/tokenizer_configs.yaml
+++ b/src/helm/config/tokenizer_configs.yaml
@@ -83,6 +83,22 @@ tokenizer_configs:
     end_of_text_token: ""
     prefix_token: ":"
 
+  - name: cohere/c4ai-command-r-v01
+    tokenizer_spec:
+      class_name: "helm.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"
+      args:
+        pretrained_model_name_or_path: CohereForAI/c4ai-command-r-v01
+    end_of_text_token: "<|END_OF_TURN_TOKEN|>"
+    prefix_token: ""
+
+  - name: cohere/c4ai-command-r-plus
+    tokenizer_spec:
+      class_name: "helm.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"
+      args:
+        pretrained_model_name_or_path: CohereForAI/c4ai-command-r-plus
+    end_of_text_token: "<|END_OF_TURN_TOKEN|>"
+    prefix_token: ""
+
   # Databricks
   - name: databricks/dbrx-instruct
     tokenizer_spec:
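
---
Notes (illustrative, not part of the diff):

A minimal smoke test for the new CohereChatClient could look like the sketch below. It assumes `crfm-helm[cohere]` is installed and that "COHERE_API_KEY" is replaced with a valid key; the prompt and sampling parameters are placeholders, and BlackHoleCacheConfig is HELM's no-op cache backend, used here so the request always hits the API.

    # smoke_test_cohere_chat.py -- a sketch for local verification only
    from helm.common.cache import BlackHoleCacheConfig
    from helm.common.request import Request
    from helm.clients.cohere_client import CohereChatClient

    # BlackHoleCacheConfig disables caching, so do_it() is always invoked.
    client = CohereChatClient(api_key="COHERE_API_KEY", cache_config=BlackHoleCacheConfig())
    result = client.make_request(
        Request(
            model="cohere/command-r",
            model_deployment="cohere/command-r",
            prompt="In one sentence, what is a context window?",
            temperature=0.3,
            max_tokens=64,
            num_completions=1,  # the client asserts num_completions == 1
        )
    )
    print(result.completions[0].text if result.success else result.error)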
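The two new tokenizer_configs.yaml entries set end_of_text_token to "<|END_OF_TURN_TOKEN|>". A quick sanity check against the underlying Hugging Face tokenizer, assuming `transformers` is installed and the CohereForAI repos on the Hub are accessible (they are gated behind a license acceptance):

    # check_eot_token.py -- a sketch for local verification only
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
    # The new tokenizer config expects this to print <|END_OF_TURN_TOKEN|>.
    print(tokenizer.eos_token)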