fix identifying params #16

Merged: 2 commits, Apr 16, 2024
23 changes: 14 additions & 9 deletions libs/aws/langchain_aws/llms/bedrock.py
@@ -19,9 +19,9 @@
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.llms import LLM
+from langchain_core.language_models import LLM, BaseLanguageModel
 from langchain_core.outputs import GenerationChunk
-from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
+from langchain_core.pydantic_v1 import Extra, Field, root_validator
 from langchain_core.utils import get_from_dict_or_env
 
 from langchain_aws.utils import (
@@ -292,7 +292,7 @@ async def aprepare_output_stream(
         )
 
 
-class BedrockBase(BaseModel, ABC):
+class BedrockBase(BaseLanguageModel, ABC):
     """Base class for Bedrock models."""
 
     client: Any = Field(exclude=True)  #: :meta private:
@@ -325,7 +325,7 @@ class BedrockBase(BaseModel, ABC):
     equivalent to the modelId property in the list-foundation-models api. For custom and
     provisioned models, an ARN value is expected."""
 
-    model_kwargs: Optional[Dict] = None
+    model_kwargs: Optional[Dict[str, Any]] = None
     """Keyword arguments to pass to the model."""
 
     endpoint_url: Optional[str] = None
@@ -440,11 +440,14 @@ def validate_environment(cls, values: Dict) -> Dict:
         return values
 
     @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        """Get the identifying parameters."""
+    def _identifying_params(self) -> Dict[str, Any]:
         _model_kwargs = self.model_kwargs or {}
         return {
-            **{"model_kwargs": _model_kwargs},
+            "model_id": self.model_id,
+            "provider": self._get_provider(),
+            "stream": self.streaming,
+            "guardrails": self.guardrails,
+            **_model_kwargs,
         }
Comment on lines 445 to 451

Collaborator:
With this change, would we need to call _get_provider in any of the other functions, or can we just use self.provider? For example:
https://github.com/langchain-ai/langchain-aws/blob/main/libs/aws/langchain_aws/llms/bedrock.py#L462

provider = self._get_provider()

Member Author:
I think this is fixed! This might have been a comment from before I changed self.provider to self._get_provider().

Member Author:
Also just calling out: this moves the model_kwargs into a flat structure of invocation params, to match most other model providers. Let me know if that sounds good to you.
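
To illustrate the shape change: assuming an instance configured with model_kwargs={"temperature": 0.5}, _identifying_params goes from a nested dict to a flat one. The model_id and provider values below are made up for the example, not taken from the PR.

# Before this PR, everything was nested under a single "model_kwargs" entry:
before = {"model_kwargs": {"temperature": 0.5}}

# After this PR, the kwargs sit flat alongside the other invocation params:
after = {
    "model_id": "anthropic.claude-v2",
    "provider": "anthropic",
    "stream": False,
    "guardrails": None,
    "temperature": 0.5,
}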

Member Author:
I think I misunderstood the question. We still need to use self._get_provider() everywhere because provider isn't always set.

It's likely worth adding a root validator that populates that field based on the model string, though (see the sketch below).
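
A minimal sketch of such a root validator, shown on a standalone model for illustration; it is not part of this PR, and the ARN handling is an assumption:

from typing import Dict, Optional

from langchain_core.pydantic_v1 import BaseModel, root_validator


class ProviderFromModelId(BaseModel):
    """Illustration only: fill in `provider` from the model string when unset."""

    model_id: str
    provider: Optional[str] = None

    @root_validator
    def set_default_provider(cls, values: Dict) -> Dict:
        # Foundation-model ids look like "<provider>.<model>", e.g.
        # "amazon.titan-text-express-v1". ARN-style ids for custom or
        # provisioned models carry no provider prefix, so leave those as-is
        # and keep requiring an explicit provider there.
        model_id = values.get("model_id", "")
        if not values.get("provider") and not model_id.startswith("arn:"):
            values["provider"] = model_id.split(".")[0]
        return values

The same idea could live directly on BedrockBase's existing provider field.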


     def _get_provider(self) -> str:
@@ -617,7 +620,8 @@ def _prepare_input_and_invoke_stream(
 
             # stop sequence from _generate() overrides
             # stop sequences in the class attribute
-            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+            if k := self.provider_stop_sequence_key_name_map.get(provider):
+                _model_kwargs[k] = stop
 
Comment on lines +623 to 625

Collaborator:
This is great. Ty.

         if provider == "cohere":
             _model_kwargs["stream"] = True
@@ -679,7 +683,8 @@ async def _aprepare_input_and_invoke_stream(
                 raise ValueError(
                     f"Stop sequence key name for {provider} is not supported."
                 )
-            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+            if k := self.provider_stop_sequence_key_name_map.get(provider):
+                _model_kwargs[k] = stop
 
         if provider == "cohere":
             _model_kwargs["stream"] = True