diff --git a/src/controlflow/llm/models.py b/src/controlflow/llm/models.py
index 11a66c8..7b0af5b 100644
--- a/src/controlflow/llm/models.py
+++ b/src/controlflow/llm/models.py
@@ -52,6 +52,8 @@ def get_model(
                 "To use Google as an LLM provider, please install the `langchain_google_genai` package."
             )
         cls = ChatGoogleGenerativeAI
+        if temperature is None:
+            temperature = 0.7
     elif provider == "groq":
         try:
             from langchain_groq import ChatGroq
@@ -60,6 +62,8 @@
                 "To use Groq as an LLM provider, please install the `langchain_groq` package."
             )
         cls = ChatGroq
+        if temperature is None:
+            temperature = 0.7
     elif provider == "ollama":
         try:
             from langchain_ollama import ChatOllama
@@ -73,7 +77,9 @@
             f"Could not load provider `{provider}` automatically. Please provide the LLM class manually."
         )
 
-    return cls(model=model, temperature=temperature, **kwargs)
+    if temperature is not None:
+        kwargs["temperature"] = temperature
+    return cls(model=model, **kwargs)
 
 
 def _get_initial_default_model() -> BaseChatModel: