Skip to content

Commit

Permalink
Add log_probs option for chat models
Browse files Browse the repository at this point in the history
  • Loading branch information
TLSDC committed Feb 11, 2025
1 parent fecf700 commit 6b3d0fc
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/agentlab/llm/base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class BaseModelArgs(ABC):
max_new_tokens: int = None
temperature: float = 0.1
vision_support: bool = False
log_probs: bool = False

@abstractmethod
def make_model(self) -> AbstractChatModel:
Expand Down
20 changes: 18 additions & 2 deletions src/agentlab/llm/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def make_model(self):
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
log_probs=self.log_probs,
)


Expand All @@ -100,6 +101,7 @@ def make_model(self):
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
log_probs=self.log_probs,
)


Expand All @@ -115,6 +117,7 @@ def make_model(self):
temperature=self.temperature,
max_tokens=self.max_new_tokens,
deployment_name=self.deployment_name,
log_probs=self.log_probs,
)


Expand Down Expand Up @@ -225,6 +228,7 @@ def __init__(
client_class=OpenAI,
client_args=None,
pricing_func=None,
log_probs=False,
):
assert max_retry > 0, "max_retry should be greater than 0"

Expand All @@ -233,6 +237,7 @@ def __init__(
self.max_tokens = max_tokens
self.max_retry = max_retry
self.min_retry_wait_time = min_retry_wait_time
self.logprobs = log_probs

# Get the API key from the environment variable if not provided
if api_key_env_var:
Expand Down Expand Up @@ -279,6 +284,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
n=n_samples,
temperature=temperature,
max_tokens=self.max_tokens,
logprobs=self.logprobs,
)

if completion.usage is None:
Expand Down Expand Up @@ -308,7 +314,10 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
tracking.TRACKER.instance(input_tokens, output_tokens, cost)

if n_samples == 1:
return AIMessage(completion.choices[0].message.content)
res = AIMessage(completion.choices[0].message.content)
if self.logprobs:
res["logprobs"] = completion.choices[0].logprobs
return res
else:
return [AIMessage(c.message.content) for c in completion.choices]

Expand All @@ -328,6 +337,7 @@ def __init__(
max_tokens=100,
max_retry=4,
min_retry_wait_time=60,
log_probs=False,
):
super().__init__(
model_name=model_name,
Expand All @@ -339,6 +349,7 @@ def __init__(
api_key_env_var="OPENAI_API_KEY",
client_class=OpenAI,
pricing_func=tracking.get_pricing_openai,
log_probs=log_probs,
)


Expand All @@ -351,6 +362,7 @@ def __init__(
max_tokens=100,
max_retry=4,
min_retry_wait_time=60,
log_probs=False,
):
client_args = {
"base_url": "https://openrouter.ai/api/v1",
Expand All @@ -366,6 +378,7 @@ def __init__(
client_class=OpenAI,
client_args=client_args,
pricing_func=tracking.get_pricing_openrouter,
log_probs=log_probs,
)


Expand All @@ -379,6 +392,7 @@ def __init__(
max_tokens=100,
max_retry=4,
min_retry_wait_time=60,
log_probs=False,
):
api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
Expand All @@ -399,6 +413,7 @@ def __init__(
client_class=AzureOpenAI,
client_args=client_args,
pricing_func=tracking.get_pricing_openai,
log_probs=log_probs,
)


Expand All @@ -412,6 +427,7 @@ def __init__(
temperature: Optional[int] = 1e-1,
max_new_tokens: Optional[int] = 512,
n_retry_server: Optional[int] = 4,
log_probs: Optional[bool] = False,
):
super().__init__(model_name, base_model_name, n_retry_server)
if temperature < 1e-3:
Expand All @@ -422,4 +438,4 @@ def __init__(
token = os.environ["TGI_TOKEN"]

client = InferenceClient(model=model_url, token=token)
self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
3 changes: 2 additions & 1 deletion src/agentlab/llm/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,10 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image):


class BaseMessage(dict):
def __init__(self, role: str, content: Union[str, list[dict]]):
def __init__(self, role: str, content: Union[str, list[dict]], **kwargs):
self["role"] = role
self["content"] = deepcopy(content)
self.update(kwargs)

def __str__(self, warn_if_image=False) -> str:
if isinstance(self["content"], str):
Expand Down

0 comments on commit 6b3d0fc

Please sign in to comment.