Add Replicate as an LLM provider #707

Open · wants to merge 2 commits into base: main
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -77,6 +77,10 @@ google = [
"tiktoken >= 0.3.3",
"google-cloud-aiplatform>=1.25.0"
]
replicate = [
"replicate >= 0.23.1",
"transformers >= 4.25.0",
]
cohere = [
"cohere>=4.11.2"
]
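Note: the new replicate extra mirrors the existing optional dependency groups, pulling in the replicate client plus transformers (needed for the Llama tokenizer that ReplicateLLM uses for token counting and cost estimation). As the import-error message in replicate.py below states, it is installed with pip install 'refuel-autolabel[replicate]', and REPLICATE_API_TOKEN must be set in the environment before the model is constructed.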
2 changes: 2 additions & 0 deletions src/autolabel/models/__init__.py
@@ -13,6 +13,7 @@
from autolabel.models.palm import PaLMLLM
from autolabel.models.hf_pipeline import HFPipelineLLM
from autolabel.models.hf_pipeline_vision import HFPipelineMultimodal
from autolabel.models.replicate import ReplicateLLM
from autolabel.models.refuel import RefuelLLM

MODEL_REGISTRY = {
@@ -23,6 +24,7 @@
ModelProvider.HUGGINGFACE_PIPELINE: HFPipelineLLM,
ModelProvider.HUGGINGFACE_PIPELINE_VISION: HFPipelineMultimodal,
ModelProvider.GOOGLE: PaLMLLM,
ModelProvider.REPLICATE: ReplicateLLM,
ModelProvider.REFUEL: RefuelLLM,
}

149 changes: 149 additions & 0 deletions src/autolabel/models/replicate.py
@@ -0,0 +1,149 @@
import logging
import os
from time import time
from typing import List, Optional

import requests

from autolabel.models import BaseModel
from autolabel.configs import AutolabelConfig
from autolabel.cache import BaseCache
from autolabel.schema import RefuelLLMResult

logger = logging.getLogger(__name__)


class ReplicateLLM(BaseModel):
REPLICATE_MAINTAINED_MODELS = [
"meta/llama-2-70b",
"meta/llama-2-13b",
"meta/llama-2-7b",
"meta/llama-2-70b-chat",
"meta/llama-2-13b-chat",
"meta/llama-2-7b-chat",
"mistralai/mistral-7b-v0.1",
"mistralai/mistral-7b-instruct-vo.2",
"mistralai/mixtral-8x7b-instruct-v0.1",
]

# Default parameters for ReplicateLLM
DEFAULT_MODEL = "meta/llama-2-7b-chat"

DEFAULT_PARAMS_COMPLETION_ENGINE = {
"max_tokens": 1000,
"temperature": 0.01,
"model_kwargs": {"logprobs": 1},
"request_timeout": 30,
}

# Reference: https://replicate.com/docs/billing
COST_PER_PROMPT_TOKEN = {
"meta/llama-2-70b": 0.65 / 1e6,
"meta/llama-2-13b": 0.10 / 1e6,
"meta/llama-2-7b": 0.05 / 1e6,
"meta/llama-2-70b-chat": 0.65 / 1e6,
"meta/llama-2-13b-chat": 0.10 / 1e6,
"meta/llama-2-7b-chat": 0.05 / 1e6,
"mistralai/mistral-7b-v0.1": 0.05 / 1e6,
"mistralai/mistral-7b-instruct-v0.2": 0.05 / 1e6,
"mistralai/mixtral-8x7b-instruct-v0.1": 0.30 / 1e6,
}
COST_PER_COMPLETION_TOKEN = {
"meta/llama-2-70b": 2.75 / 1e6,
"meta/llama-2-13b": 0.50 / 1e6,
"meta/llama-2-7b": 0.25 / 1e6,
"meta/llama-2-70b-chat": 2.75 / 1e6,
"meta/llama-2-13b-chat": 0.50 / 1e6,
"meta/llama-2-7b-chat": 0.25 / 1e6,
"mistralai/mistral-7b-v0.1": 0.25 / 1e6,
"mistralai/mistral-7b-instruct-v0.2": 0.25 / 1e6,
"mistralai/mixtral-8x7b-instruct-v0.1": 1.00 / 1e6,
}

def __init__(self, config: AutolabelConfig, cache: Optional[BaseCache] = None) -> None:
super().__init__(config, cache)
try:
from langchain_community.llms import Replicate
from transformers import LlamaTokenizerFast
except ImportError:
raise ImportError(
"The replicate and transformers packages are required to use ReplicateLLM. Please install them with the following command: pip install 'refuel-autolabel[replicate]'"
)

if os.getenv("REPLICATE_API_TOKEN") is None:
raise ValueError("REPLICATE_API_TOKEN environment variable not set")

# populate model name
self.model_name = config.model_name() or self.DEFAULT_MODEL

# populate model params and initialize the LLM
model_params = config.model_params()

self.model_params = {
**self.DEFAULT_PARAMS_COMPLETION_ENGINE,
**model_params,
}

# get latest model version, required by langchain to process replicate generations
response = requests.get(
f"https://api.replicate.com/v1/models/{self.model_name}",
headers={"Authorization": f"Token {os.environ['REPLICATE_API_TOKEN']}"},
)
if response.status_code == 404:
raise ValueError(f"Model {self.model_name} not found on Replicate")
latest_model_version = response.json()["latest_version"]["id"]

self.llm = Replicate(
model=f"{self.model_name}:{latest_model_version}",
verbose=False,
**self.model_params,
)

self.tokenizer = LlamaTokenizerFast.from_pretrained(
"hf-internal-testing/llama-tokenizer"
)

def is_model_managed_by_replicate(self) -> bool:
return self.model_name in self.REPLICATE_MAINTAINED_MODELS

def _label(self, prompts: List[str]) -> RefuelLLMResult:
try:
start_time = time()
result = self.llm.generate(prompts)
generations = result.generations
end_time = time()
return RefuelLLMResult(
generations=generations,
errors=[None] * len(generations),
latencies=[end_time - start_time] * len(generations),
)
except Exception as e:
logger.error(f"Error generating from Replicate: {e}. Retrying prompts individually.")
return self._label_individually(prompts)

def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
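# Cost model: prompt_tokens * COST_PER_PROMPT_TOKEN[model] + completion_tokens * COST_PER_COMPLETION_TOKEN[model].
# When no label is provided, max_tokens (1000 by default) is used as an upper bound on completion length,
# so the returned value is a worst-case estimate.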
if self.is_model_managed_by_replicate():
num_prompt_toks = len(self.tokenizer.encode(prompt))
if label:
num_label_toks = len(self.tokenizer.encode(label))
else:
# get an upper bound
num_label_toks = self.model_params["max_tokens"]

cost_per_prompt_token = self.COST_PER_PROMPT_TOKEN[self.model_name]
cost_per_completion_token = self.COST_PER_COMPLETION_TOKEN[self.model_name]
return (num_prompt_toks * cost_per_prompt_token) + (
num_label_toks * cost_per_completion_token
)
else:
# TODO: cost cannot currently be computed for models not maintained by Replicate; see https://github.com/replicate/replicate-python/issues/243
return 0

def returns_token_probs(self) -> bool:
# The langchain Replicate wrapper does not surface token log probabilities.
return False

def get_num_tokens(self, prompt: str) -> int:
return len(self.tokenizer.encode(prompt))
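Reviewer note: a minimal sketch for exercising the new class directly, mirroring the fixtures in tests/unit/llm_test.py; the token value and example prompt below are placeholders, not part of this diff.

import os
from autolabel.configs import AutolabelConfig
from autolabel.models.replicate import ReplicateLLM

# A real token is required: __init__ calls the Replicate models API to resolve the latest model version.
os.environ["REPLICATE_API_TOKEN"] = "<your-replicate-api-token>"

config = AutolabelConfig(config="tests/assets/banking/config_banking_replicate.json")
llm = ReplicateLLM(config=config)

prompt = "Input: My card still has not arrived.\nOutput:"
print(llm.get_num_tokens(prompt))      # token count via the Llama tokenizer
print(llm.get_cost(prompt))            # worst-case cost estimate (max_tokens completion)
result = llm.label([prompt])           # RefuelLLMResult with generations, errors, latencies
print(result.generations[0][0].text)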
1 change: 1 addition & 0 deletions src/autolabel/schema.py
@@ -22,6 +22,7 @@ class ModelProvider(str, Enum):
REFUEL = "refuel"
GOOGLE = "google"
COHERE = "cohere"
REPLICATE = "replicate"
CUSTOM = "custom"


99 changes: 99 additions & 0 deletions tests/assets/banking/config_banking_replicate.json
@@ -0,0 +1,99 @@
{
"task_name": "BankingComplaintsClassification",
"task_type": "classification",
"dataset": {
"label_column": "label",
"delimiter": ","
},
"model": {
"provider": "replicate",
"name": "meta/llama-2-70b-chat"
},
"prompt": {
"task_guidelines": "You are an expert at understanding bank customers support complaints and queries.\nYour job is to correctly classify the provided input example into one of the following categories.\nCategories:\n{labels}",
"output_guidelines": "You will answer with just the the correct output label and nothing else.",
"labels": [
"activate_my_card",
"age_limit",
"apple_pay_or_google_pay",
"atm_support",
"automatic_top_up",
"balance_not_updated_after_bank_transfer",
"balance_not_updated_after_cheque_or_cash_deposit",
"beneficiary_not_allowed",
"cancel_transfer",
"card_about_to_expire",
"card_acceptance",
"card_arrival",
"card_delivery_estimate",
"card_linking",
"card_not_working",
"card_payment_fee_charged",
"card_payment_not_recognised",
"card_payment_wrong_exchange_rate",
"card_swallowed",
"cash_withdrawal_charge",
"cash_withdrawal_not_recognised",
"change_pin",
"compromised_card",
"contactless_not_working",
"country_support",
"declined_card_payment",
"declined_cash_withdrawal",
"declined_transfer",
"direct_debit_payment_not_recognised",
"disposable_card_limits",
"edit_personal_details",
"exchange_charge",
"exchange_rate",
"exchange_via_app",
"extra_charge_on_statement",
"failed_transfer",
"fiat_currency_support",
"get_disposable_virtual_card",
"get_physical_card",
"getting_spare_card",
"getting_virtual_card",
"lost_or_stolen_card",
"lost_or_stolen_phone",
"order_physical_card",
"passcode_forgotten",
"pending_card_payment",
"pending_cash_withdrawal",
"pending_top_up",
"pending_transfer",
"pin_blocked",
"receiving_money",
"Refund_not_showing_up",
"request_refund",
"reverted_card_payment?",
"supported_cards_and_currencies",
"terminate_account",
"top_up_by_bank_transfer_charge",
"top_up_by_card_charge",
"top_up_by_cash_or_cheque",
"top_up_failed",
"top_up_limits",
"top_up_reverted",
"topping_up_by_card",
"transaction_charged_twice",
"transfer_fee_charged",
"transfer_into_account",
"transfer_not_received_by_recipient",
"transfer_timing",
"unable_to_verify_identity",
"verify_my_identity",
"verify_source_of_funds",
"verify_top_up",
"virtual_card_not_working",
"visa_or_mastercard",
"why_verify_identity",
"wrong_amount_of_cash_received",
"wrong_exchange_rate_for_cash_withdrawal"
],
"few_shot_examples": "seed.csv",
"few_shot_selection": "semantic_similarity",
"few_shot_num": 10,
"example_template": "Input: {example}\nOutput: {label}"
}
}
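For context, a sketch of how this config might be used end to end. It assumes the library's existing LabelingAgent and AutolabelDataset entry points, which are not part of this diff, and the test.csv path is a placeholder.

from autolabel import LabelingAgent, AutolabelDataset
from autolabel.configs import AutolabelConfig

config = AutolabelConfig(config="tests/assets/banking/config_banking_replicate.json")
agent = LabelingAgent(config=config)

# Placeholder CSV; seed.csv (few_shot_examples above) is expected alongside it.
ds = AutolabelDataset("test.csv", config=config)
agent.plan(ds)       # dry run: prompt preview plus a cost estimate from ReplicateLLM.get_cost
ds = agent.run(ds)   # labels each example with meta/llama-2-70b-chat hosted on Replicate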
1 change: 1 addition & 0 deletions tests/unit/__init__.py
@@ -6,6 +6,7 @@
{
"REFUEL_API_KEY": "dummy_refuel_api_key",
"OPENAI_API_KEY": "dummy_open_api_key",
"REPLICATE_API_TOKEN": "dummy_replicate_api_token",
"ANTHROPIC_API_KEY": "dummy_anthropic_api_key",
}
)
65 changes: 65 additions & 0 deletions tests/unit/llm_test.py
@@ -4,11 +4,13 @@
from autolabel.models.openai import OpenAILLM
from autolabel.models.openai_vision import OpenAIVisionLLM
from autolabel.models.palm import PaLMLLM
from autolabel.models.replicate import ReplicateLLM
from autolabel.models.refuel import RefuelLLM
from langchain.schema import Generation, LLMResult
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from pytest import approx
import pytest


################### ANTHROPIC TESTS #######################
@@ -198,6 +200,69 @@ def test_gpt4V_return_probs():
################### OPENAI GPT 4V TESTS #######################


################### REPLICATE TESTS #######################
class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
self.status_code = status_code

def json(self):
return self.json_data


@pytest.fixture
def replicate_model(mocker):
mocker.patch(
"requests.get",
return_value=MockResponse({"latest_version": {"id": "valid_id"}}, 200),
)
return ReplicateLLM(
config=AutolabelConfig(
config="tests/assets/banking/config_banking_replicate.json"
)
)


def test_replicate_initialization(replicate_model):
assert isinstance(replicate_model, ReplicateLLM)


def test_replicate_invalid_model_initialization(mocker):
config = AutolabelConfig(
config="tests/assets/banking/config_banking_replicate.json"
)
mocker.patch(
"requests.get",
return_value=MockResponse({"error": "model not found"}, 404),
)

with pytest.raises(ValueError) as excinfo:
model = ReplicateLLM(config)

assert "Model meta/llama-2-70b-chat not found on Replicate" in str(excinfo.value)


def test_replicate_label(mocker, replicate_model):
prompts = ["test1", "test2"]
mocker.patch(
"langchain_community.llms.Replicate.generate",
return_value=LLMResult(
generations=[[Generation(text="Answers")] for _ in prompts]
),
)
x = replicate_model.label(prompts)
assert [i[0].text for i in x.generations] == ["Answers", "Answers"]


def test_replicate_get_cost(replicate_model):
example_prompt = "TestingExamplePrompt"
curr_cost = replicate_model.get_cost(example_prompt)
assert curr_cost == approx(0.00275389, rel=1e-3)


################### REPLICATE TESTS #######################


################### REFUEL TESTS #######################
def test_refuel_initialization():
model = RefuelLLM(