Add Command R and Command R+ models (#2548)
Co-authored-by: Yifan Mai <[email protected]>
andyt-cohere and yifanmai authored May 20, 2024
1 parent 13abf8f commit c5c451c
Showing 5 changed files with 156 additions and 4 deletions.
4 changes: 4 additions & 0 deletions setup.cfg
@@ -129,6 +129,9 @@ anthropic =
     anthropic~=0.17
     websocket-client~=1.3.2 # For legacy stanford-online-all-v4-s3
 
+cohere =
+    cohere~=5.3
+
 mistral =
     mistralai~=0.0.11
 
@@ -154,6 +157,7 @@ models =
     crfm-helm[allenai]
     crfm-helm[amazon]
     crfm-helm[anthropic]
+    crfm-helm[cohere]
     crfm-helm[google]
     crfm-helm[mistral]
     crfm-helm[openai]
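The new `cohere` extra gates the SDK import in the client below. A quick sanity check, assuming a standard pip environment (this snippet is illustrative, not part of the commit):

```python
# Verify that `pip install "crfm-helm[cohere]"` pulled in a Cohere SDK
# satisfying the ~=5.3 pin above (i.e., any 5.x release at or above 5.3).
from importlib.metadata import version

print(version("cohere"))  # e.g. "5.3.2"
```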
101 changes: 98 additions & 3 deletions src/helm/clients/cohere_client.py
@@ -1,8 +1,9 @@
 import json
 import requests
-from typing import List
+from typing import List, Optional, Sequence, TypedDict
 
 from helm.common.cache import CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
     wrap_request_time,
     EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
@@ -11,8 +12,13 @@
     GeneratedOutput,
     Token,
 )
-from .client import CachingClient, truncate_sequence
-from .cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+from helm.clients.client import CachingClient, truncate_sequence
+from helm.clients.cohere_utils import get_cohere_url, DEFAULT_COHERE_API_VERSION
+
+try:
+    import cohere
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["cohere"])
 
 
 class CohereClient(CachingClient):
@@ -152,3 +158,92 @@ def do_it():
             completions=completions,
             embedding=[],
         )
+
+
+class CohereRawChatRequest(TypedDict):
+    message: str
+    model: Optional[str]
+    preamble: Optional[str]
+    chat_history: Optional[Sequence[cohere.ChatMessage]]
+    temperature: Optional[float]
+    max_tokens: Optional[int]
+    k: Optional[int]
+    p: Optional[float]
+    seed: Optional[float]
+    stop_sequences: Optional[Sequence[str]]
+    frequency_penalty: Optional[float]
+    presence_penalty: Optional[float]
+
+
+def convert_to_raw_chat_request(request: Request) -> CohereRawChatRequest:
+    # TODO: Support chat
+    model = request.model.replace("cohere/", "")
+    return {
+        "message": request.prompt,
+        "model": model,
+        "preamble": None,
+        "chat_history": None,
+        "temperature": request.temperature,
+        "max_tokens": request.max_tokens,
+        "k": request.top_k_per_token,
+        "p": request.top_p,
+        "stop_sequences": request.stop_sequences,
+        "seed": float(request.random) if request.random is not None else None,
+        "frequency_penalty": request.frequency_penalty,
+        "presence_penalty": request.presence_penalty,
+    }
+
+
+class CohereChatClient(CachingClient):
+    """
+    Leverages the chat endpoint: https://docs.cohere.com/reference/chat
+
+    Cohere models will soon support only chat:
+    https://docs.cohere.com/docs/migrating-from-cogenerate-to-cochat
+    """
+
+    def __init__(self, api_key: str, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+        self.client = cohere.Client(api_key=api_key)
+
+    def make_request(self, request: Request) -> RequestResult:
+        if request.embedding:
+            return EMBEDDING_UNAVAILABLE_REQUEST_RESULT
+        # TODO: Support multiple completions
+        assert request.num_completions == 1, "CohereChatClient only supports num_completions=1"
+        # TODO: Support messages
+        assert not request.messages, "CohereChatClient currently does not support the messages API"
+
+        raw_request: CohereRawChatRequest = convert_to_raw_chat_request(request)
+
+        try:
+
+            def do_it():
+                """
+                Send the request to the Cohere Chat API. Responses will be structured like this:
+
+                cohere.Chat {
+                    message: What's up?
+                    text: Hey there! How's it going? I'm doing well, thank you for asking 😊.
+                    ...
+                }
+                """
+                raw_response = self.client.chat(**raw_request).dict()
+                assert "text" in raw_response, f"Response does not contain text: {raw_response}"
+                return raw_response
+
+            response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
+        except (requests.exceptions.RequestException, AssertionError) as e:
+            error: str = f"CohereClient error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = []
+        completion: GeneratedOutput = GeneratedOutput(text=response["text"], logprob=0.0, tokens=[])
+        completions.append(completion)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
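For a concrete view of the new request plumbing, here is a minimal usage sketch (not part of the commit; it assumes the defaults of HELM's `Request` dataclass and the import paths shown above):

```python
# Build a HELM Request and inspect the raw payload that
# CohereChatClient.make_request would send to the Cohere Chat API.
from helm.common.request import Request
from helm.clients.cohere_client import convert_to_raw_chat_request

request = Request(
    model="cohere/command-r",
    prompt="What is the capital of France?",
    temperature=0.3,
    max_tokens=50,
)

raw_request = convert_to_raw_chat_request(request)
print(raw_request["model"])    # "command-r" -- the "cohere/" prefix is stripped
print(raw_request["message"])  # the prompt is sent as a single chat message
```

Note that `preamble` and `chat_history` are always `None` for now, consistent with the `# TODO: Support chat` comment.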
19 changes: 19 additions & 0 deletions src/helm/config/model_deployments.yaml
@@ -325,6 +325,25 @@ model_deployments:
     window_service_spec:
       class_name: "helm.benchmark.window_services.cohere_window_service.CohereWindowService"
 
+  - name: cohere/command-r
+    model_name: cohere/command-r
+    tokenizer_name: cohere/c4ai-command-r-v01
+    max_sequence_length: 128000
+    max_request_length: 128000
+    client_spec:
+      class_name: "helm.clients.cohere_client.CohereChatClient"
+
+  - name: cohere/command-r-plus
+    model_name: cohere/command-r-plus
+    tokenizer_name: cohere/c4ai-command-r-plus
+    # "We have a known issue where prompts between 112K - 128K in length
+    # result in bad generations."
+    # Source: https://docs.cohere.com/docs/command-r-plus
+    max_sequence_length: 110000
+    max_request_length: 110000
+    client_spec:
+      class_name: "helm.clients.cohere_client.CohereChatClient"
+
   # Craiyon
 
   - name: craiyon/dalle-mini
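Each deployment's `client_spec` names the client class that HELM will construct. A rough sketch of the equivalent manual wiring (the API key is a placeholder, `BlackHoleCacheConfig` is assumed here as a no-op cache, and a live call requires a valid Cohere key):

```python
# Instantiate the chat client as the client_spec above would, then issue
# a request against the cohere/command-r deployment.
from helm.common.cache import BlackHoleCacheConfig
from helm.common.request import Request
from helm.clients.cohere_client import CohereChatClient

client = CohereChatClient(
    api_key="YOUR_COHERE_API_KEY",  # placeholder
    cache_config=BlackHoleCacheConfig(),  # assumed: no-op cache for demo purposes
)

result = client.make_request(Request(model="cohere/command-r", prompt="Hello!"))
if result.success:
    print(result.completions[0].text)
```

The lower 110000-token limits on `command-r-plus` sidestep the documented bad-generation issue for prompts between 112K and 128K tokens.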
20 changes: 19 additions & 1 deletion src/helm/config/model_metadata.yaml
@@ -468,7 +468,25 @@ models:
     creator_organization_name: Cohere
     access: limited
     release_date: 2023-09-29
-    tags: [TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
+    tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
 
+  - name: cohere/command-r
+    display_name: Cohere Command R
+    description: Command R is a multilingual 35B parameter model with a context length of 128K that has been trained with conversational tool use capabilities.
+    creator_organization_name: Cohere
+    access: open
+    num_parameters: 35000000000
+    release_date: 2024-03-11
+    tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
+
+  - name: cohere/command-r-plus
+    display_name: Cohere Command R Plus
+    description: Command R+ is a multilingual 104B parameter model with a context length of 128K that has been trained with conversational tool use capabilities.
+    creator_organization_name: Cohere
+    access: open
+    num_parameters: 104000000000
+    release_date: 2024-04-04
+    tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
+
   # Craiyon
   - name: craiyon/dalle-mini
16 changes: 16 additions & 0 deletions src/helm/config/tokenizer_configs.yaml
@@ -83,6 +83,22 @@ tokenizer_configs:
     end_of_text_token: ""
     prefix_token: ":"
 
+  - name: cohere/c4ai-command-r-v01
+    tokenizer_spec:
+      class_name: "helm.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"
+      args:
+        pretrained_model_name_or_path: CohereForAI/c4ai-command-r-v01
+    end_of_text_token: "<|END_OF_TURN_TOKEN|>"
+    prefix_token: "<BOS_TOKEN>"
+
+  - name: cohere/c4ai-command-r-plus
+    tokenizer_spec:
+      class_name: "helm.tokenizers.huggingface_tokenizer.HuggingFaceTokenizer"
+      args:
+        pretrained_model_name_or_path: CohereForAI/c4ai-command-r-plus
+    end_of_text_token: "<|END_OF_TURN_TOKEN|>"
+    prefix_token: "<BOS_TOKEN>"
+
   # Databricks
   - name: databricks/dbrx-instruct
     tokenizer_spec:
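The special tokens above mirror the Hugging Face tokenizers. A small verification sketch, assuming the `transformers` library is installed and you have access to the CohereForAI repositories (the expected values are taken from the config above, not independently confirmed here):

```python
# Check that end_of_text_token and prefix_token in tokenizer_configs.yaml
# match the special tokens reported by the Hugging Face tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
print(tokenizer.bos_token)  # expected: "<BOS_TOKEN>"
print(tokenizer.eos_token)  # expected: "<|END_OF_TURN_TOKEN|>"
```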
