Add missing params in dict() output (#4)
Signed-off-by: B-Step62 <[email protected]>
B-Step62 authored Sep 2, 2024
1 parent a028756 commit 36500e8
Showing 2 changed files with 23 additions and 3 deletions.
13 changes: 13 additions & 0 deletions libs/databricks/langchain_databricks/chat_models.py
@@ -393,6 +393,19 @@ def bind_tools(
             kwargs["tool_choice"] = tool_choice
         return super().bind(tools=formatted_tools, **kwargs)

+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        return self._default_params
+
+    def _get_invocation_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
+        return {
+            **self._default_params,
+            **super()._get_invocation_params(stop=stop, **kwargs),
+        }
+
     @property
     def _llm_type(self) -> str:
         """Return type of chat model."""
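For context, BaseChatModel.dict() builds its output from _identifying_params plus _llm_type, so returning _default_params here is what puts endpoint, target_uri, and the generation defaults into the serialized form. A minimal sketch of the intended effect, assuming a configured Databricks environment; the endpoint name and parameter values below are illustrative, not taken from this commit:

from langchain_databricks import ChatDatabricks

# Placeholder endpoint; constructing the model requires Databricks credentials.
llm = ChatDatabricks(
    endpoint="databricks-meta-llama-3-70b-instruct",
    temperature=0.1,
    max_tokens=256,
)

# With _identifying_params returning _default_params, dict() now includes
# the generation defaults alongside _type:
print(llm.dict())
# Expected to contain roughly:
# {"_type": "chat-databricks", "endpoint": "...", "target_uri": "databricks",
#  "temperature": 0.1, "max_tokens": 256, ...}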
13 changes: 10 additions & 3 deletions libs/databricks/tests/unit_tests/test_chat_models.py
@@ -159,7 +159,14 @@ def llm() -> ChatDatabricks:
     )


-def test_chat_mlflow_predict(llm: ChatDatabricks) -> None:
+def test_dict(llm: ChatDatabricks) -> None:
+    d = llm.dict()
+    assert d["_type"] == "chat-databricks"
+    assert d["endpoint"] == "databricks-meta-llama-3-70b-instruct"
+    assert d["target_uri"] == "databricks"
+
+
+def test_chat_model_predict(llm: ChatDatabricks) -> None:
     res = llm.invoke(
         [
             {"role": "system", "content": "You are a helpful assistant."},
@@ -169,7 +176,7 @@ def test_chat_mlflow_predict(llm: ChatDatabricks) -> None:
     assert res.content == _MOCK_CHAT_RESPONSE["choices"][0]["message"]["content"]  # type: ignore[index]


-def test_chat_mlflow_stream(llm: ChatDatabricks) -> None:
+def test_chat_model_stream(llm: ChatDatabricks) -> None:
     res = llm.stream(
         [
             {"role": "system", "content": "You are a helpful assistant."},
@@ -180,7 +187,7 @@ def test_chat_mlflow_stream(llm: ChatDatabricks) -> None:
         assert chunk.content == expected["choices"][0]["delta"]["content"]  # type: ignore[index]


-def test_chat_mlflow_bind_tools(llm: ChatDatabricks) -> None:
+def test_chat_model_bind_tools(llm: ChatDatabricks) -> None:
     class GetWeather(BaseModel):
         """Get the current weather in a given location"""

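The _get_invocation_params override is aimed at callbacks and tracing: LangChain's callback manager forwards its result to handlers as invocation_params, so merging in _default_params should surface the endpoint and generation settings there as well. A rough sketch of a handler observing this, assuming the standard callback flow; the handler and parameter values are illustrative, not part of this commit:

from typing import Any

from langchain_core.callbacks import BaseCallbackHandler
from langchain_databricks import ChatDatabricks


class ParamLogger(BaseCallbackHandler):
    """Illustrative handler that prints the invocation params it receives."""

    def on_chat_model_start(self, serialized: dict, messages: Any, **kwargs: Any) -> None:
        # After this change, invocation_params should also carry the model's
        # default params (endpoint, temperature, max_tokens, ...).
        print(kwargs.get("invocation_params"))


# Placeholder endpoint; running this needs Databricks credentials.
llm = ChatDatabricks(endpoint="databricks-meta-llama-3-70b-instruct", temperature=0.1)
llm.invoke("Hello", config={"callbacks": [ParamLogger()]})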
