Skip to content

Commit

Permalink
tool callig samabstudio updated
Browse files Browse the repository at this point in the history
  • Loading branch information
jhpiedrahitao committed Nov 19, 2024
1 parent dde001b commit 8781a4c
Showing 1 changed file with 53 additions and 25 deletions.
78 changes: 53 additions & 25 deletions libs/community/langchain_community/chat_models/sambanova.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,7 @@ class AnswerWithJustification(BaseModel):
if method == "function_calling":
if schema is None:
raise ValueError(
"`schema` must be specified when method is `function_calling`. "
"Received None."
"`schema` must be specified when method is `function_calling`. Received None."

Check failure on line 626 in libs/community/langchain_community/chat_models/sambanova.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (E501)

langchain_community/chat_models/sambanova.py:626:89: E501 Line too long (98 > 88)

Check failure on line 626 in libs/community/langchain_community/chat_models/sambanova.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (E501)

langchain_community/chat_models/sambanova.py:626:89: E501 Line too long (98 > 88)
)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
Expand All @@ -650,8 +649,7 @@ class AnswerWithJustification(BaseModel):
elif method == "json_schema":
if schema is None:
raise ValueError(
"`schema` must be specified when method is not `json_mode`. "
"Received None."
"`schema` must be specified when method is not `json_mode`. Received None."

Check failure on line 652 in libs/community/langchain_community/chat_models/sambanova.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (E501)

langchain_community/chat_models/sambanova.py:652:89: E501 Line too long (95 > 88)

Check failure on line 652 in libs/community/langchain_community/chat_models/sambanova.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (E501)

langchain_community/chat_models/sambanova.py:652:89: E501 Line too long (95 > 88)
)
llm = self
# TODO bind response format when json schema available by API,
Expand Down Expand Up @@ -959,8 +957,8 @@ class ChatSambaStudio(BaseChatModel):
Setup:
To use, you should have the environment variables:
``SAMBASTUDIO_URL`` set with your SambaStudio deployed endpoint URL.
``SAMBASTUDIO_API_KEY`` set with your SambaStudio deployed endpoint Key.
`SAMBASTUDIO_URL` set with your SambaStudio deployed endpoint URL.
`SAMBASTUDIO_API_KEY` set with your SambaStudio deployed endpoint Key.
https://docs.sambanova.ai/sambastudio/latest/index.html
Example:
Expand Down Expand Up @@ -1109,7 +1107,7 @@ class Joke(BaseModel):
Joke(setup="Why did the cat join a band?",
punchline="Because it wanted to be the purr-cussionist!")
See ``ChatSambaStudio.with_structured_output()`` for more.
See `ChatSambaStudio.with_structured_output()` for more.
Token usage:
.. code-block:: python
Expand Down Expand Up @@ -1296,7 +1294,7 @@ def with_structured_output(
- a JSON Schema,
- a TypedDict class,
- or a Pydantic class.
If ``schema`` is a Pydantic class then the model output will be a
If `schema` is a Pydantic class then the model output will be a
Pydantic instance of that class, and the model-generated fields will be
validated by the Pydantic class. Otherwise the model output will be a
dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
Expand Down Expand Up @@ -1324,15 +1322,15 @@ def with_structured_output(
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
an instance of ``schema`` (i.e., a Pydantic object).
If `include_raw` is False and `schema` is a Pydantic class, Runnable outputs
an instance of `schema` (i.e., a Pydantic object).
Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
Otherwise, if `include_raw` is False then Runnable outputs a dict.
If ``include_raw`` is True, then Runnable outputs a dict with keys:
- ``"raw"``: BaseMessage
- ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
- ``"parsing_error"``: Optional[BaseException]
If `include_raw` is True, then Runnable outputs a dict with keys:
- `"raw"`: BaseMessage
- `"parsed"`: None if there was a parsing error, otherwise the type depends on the `schema` as described above.
- `"parsing_error"`: Optional[BaseException]
Example: schema=Pydantic class, method="function_calling", include_raw=False:
.. code-block:: python
Expand Down Expand Up @@ -1600,21 +1598,21 @@ def _get_role(self, message: BaseMessage) -> str:
Returns:
str: Role of the LangChain BaseMessage
"""
if isinstance(message, ChatMessage):
role = message.role
elif isinstance(message, SystemMessage):
if isinstance(message, SystemMessage):
role = "system"
elif isinstance(message, HumanMessage):
role = "user"
elif isinstance(message, AIMessage):
role = "assistant"
elif isinstance(message, ToolMessage):
role = "tool"
elif isinstance(message, ChatMessage):
role = message.role
else:
raise TypeError(f"Got unknown type {message}")
return role

def _messages_to_string(self, messages: List[BaseMessage]) -> str:
def _messages_to_string(self, messages: List[BaseMessage], **kwargs: Any) -> str:
"""
Convert a list of BaseMessages to a:
- dumped json string with Role / content dict structure
Expand All @@ -1632,18 +1630,47 @@ def _messages_to_string(self, messages: List[BaseMessage]) -> str:
messages_dict: Dict[str, Any] = {
"conversation_id": "sambaverse-conversation-id",
"messages": [],
**kwargs,
}
for message in messages:
messages_dict["messages"].append(
{
if isinstance(message, AIMessage):
message_dict = {
"message_id": message.id,
"role": self._get_role(message),
"content": message.content,
}
# TODO add tools msgs id and assistant msgs tool calls
)
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs[
"tool_calls"
]
if message_dict["content"] == "":
message_dict["content"] = None

elif isinstance(message, ToolMessage):
message_dict = {
"message_id": message.id,
"role": self._get_role(message),
"content": message.content,
"tool_call_id": message.tool_call_id,
}

else:
message_dict = {
"message_id": message.id,
"role": self._get_role(message),
"content": message.content,
}

messages_dict["messages"].append(message_dict)

messages_string = json.dumps(messages_dict)

else:
if "tools" in kwargs.keys():
raise NotImplementedError(
"tool calling not supported in API Generic V2 without process_prompt, "

Check failure on line 1671 in libs/community/langchain_community/chat_models/sambanova.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (E501)

langchain_community/chat_models/sambanova.py:1671:89: E501 Line too long (91 > 88)

Check failure on line 1671 in libs/community/langchain_community/chat_models/sambanova.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (E501)

langchain_community/chat_models/sambanova.py:1671:89: E501 Line too long (91 > 88)
"switch to OpenAI compatible API or Generic V2 API with process_prompt=True"

Check failure on line 1672 in libs/community/langchain_community/chat_models/sambanova.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (E501)

langchain_community/chat_models/sambanova.py:1672:89: E501 Line too long (96 > 88)

Check failure on line 1672 in libs/community/langchain_community/chat_models/sambanova.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (E501)

langchain_community/chat_models/sambanova.py:1672:89: E501 Line too long (96 > 88)
)
messages_string = self.special_tokens["start"]
for message in messages:
messages_string += self.special_tokens["start_role"].format(
Expand Down Expand Up @@ -1725,7 +1752,9 @@ def _handle_request(

# create request payload for generic v1 API
elif "api/v2/predict/generic" in self.sambastudio_url:
items = [{"id": "item0", "value": self._messages_to_string(messages)}]
items = [
{"id": "item0", "value": self._messages_to_string(messages, **kwargs)}
]
params: Dict[str, Any] = {
"select_expert": self.model,
"process_prompt": self.process_prompt,
Expand All @@ -1734,7 +1763,6 @@ def _handle_request(
"top_p": self.top_p,
"top_k": self.top_k,
"do_sample": self.do_sample,
**kwargs,
}
if self.model_kwargs is not None:
params = {**params, **self.model_kwargs}
Expand Down

0 comments on commit 8781a4c

Please sign in to comment.