From 402287741b2d76fb7bcd25fb1d8148730d82f13b Mon Sep 17 00:00:00 2001
From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Fri, 20 Dec 2024 08:39:17 +0000
Subject: [PATCH] fix: Improve Ollama client detection and error handling (#373)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add connection retry logic with exponential backoff
- Implement URL validation and connection testing
- Add comprehensive error messages and logging
- Support OLLAMA_HOST environment variable
- Add test suite with mock testing

Co-Authored-By: Erkin Alp Güney
---
 requirements.txt         |   6 +-
 src/llm/ollama_client.py | 105 ++++++++++++++++++++++++++-----
 tests/test_ollama.py     | 129 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 223 insertions(+), 17 deletions(-)
 create mode 100644 tests/test_ollama.py

diff --git a/requirements.txt b/requirements.txt
index 91666960..bf522945 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ flask
 flask-cors
 toml
 urllib3
-requests
+requests>=2.31.0
 colorama
 fastlogging
 Jinja2
@@ -12,7 +12,7 @@ pdfminer.six
 playwright
 pytest-playwright
 tiktoken
-ollama
+ollama>=0.1.6
 openai
 anthropic
 google-generativeai
@@ -31,3 +31,5 @@ orjson
 gevent
 gevent-websocket
 curl_cffi
+pytest>=7.4.0
+pytest-mock>=3.12.0
diff --git a/src/llm/ollama_client.py b/src/llm/ollama_client.py
index 95b1e365..86d3cd98 100644
--- a/src/llm/ollama_client.py
+++ b/src/llm/ollama_client.py
@@ -1,25 +1,100 @@
+import os
+import time
+import requests
+from typing import Optional, List, Dict, Any
+from urllib.parse import urlparse
+
 import ollama
 from src.logger import Logger
 from src.config import Config
 
 log = Logger()
 
-
 class Ollama:
     def __init__(self):
-        try:
-            self.client = ollama.Client(Config().get_ollama_api_endpoint())
-            self.models = self.client.list()["models"]
-            log.info("Ollama available")
-        except:
-            self.client = None
-            log.warning("Ollama not available")
-            log.warning("run ollama server to use ollama models otherwise use API models")
+        """Initialize Ollama client with retry logic and proper error handling."""
+        self.host = os.getenv("OLLAMA_HOST", Config().get_ollama_api_endpoint())
+        self.client = None
+        self.models = []
+        self._initialize_client()
+
+    def _initialize_client(self, max_retries: int = 3, initial_delay: float = 1.0) -> None:
+        """Initialize Ollama client with retry logic.
+
+        Args:
+            max_retries: Maximum number of connection attempts
+            initial_delay: Initial delay between retries in seconds
+        """
+        delay = initial_delay
+        for attempt in range(max_retries):
+            try:
+                # Validate URL format
+                parsed_url = urlparse(self.host)
+                if not parsed_url.scheme or not parsed_url.netloc:
+                    raise ValueError(f"Invalid Ollama server URL: {self.host}")
+
+                # Test server connection
+                response = requests.get(f"{self.host}/api/version")
+                if response.status_code != 200:
+                    raise ConnectionError(f"Ollama server returned status {response.status_code}")
+
+                # Initialize client and fetch models
+                self.client = ollama.Client(self.host)
+                self.models = self.client.list()["models"]
+                log.info(f"Ollama available at {self.host}")
+                log.info(f"Found {len(self.models)} models: {[m['name'] for m in self.models]}")
+                return
+
+            except requests.exceptions.ConnectionError as e:
+                log.warning(f"Connection failed to Ollama server at {self.host}")
+                log.warning(f"Error: {str(e)}")
+
+            except ValueError as e:
+                log.error(f"Configuration error: {str(e)}")
+                return
+
+            except Exception as e:
+                log.warning(f"Failed to initialize Ollama client: {str(e)}")
+
+            if attempt < max_retries - 1:
+                log.info(f"Retrying in {delay:.1f} seconds...")
+                time.sleep(delay)
+                delay *= 2  # Exponential backoff
+            else:
+                log.warning("Max retries reached. Please ensure Ollama server is running")
+                log.warning("Run 'ollama serve' to start the server")
+                log.warning("Or set OLLAMA_HOST environment variable to correct server URL")
+
+        self.client = None
+        self.models = []
 
     def inference(self, model_id: str, prompt: str) -> str:
-        response = self.client.generate(
-            model=model_id,
-            prompt=prompt.strip(),
-            options={"temperature": 0}
-        )
-        return response['response']
+        """Run inference using specified model.
+
+        Args:
+            model_id: Name of the Ollama model to use
+            prompt: Input prompt for the model
+
+        Returns:
+            Model response text
+
+        Raises:
+            RuntimeError: If client is not initialized or model is not found
+        """
+        if not self.client:
+            raise RuntimeError("Ollama client not initialized. Please check server connection.")
+
+        if not any(m['name'] == model_id for m in self.models):
+            raise RuntimeError(f"Model {model_id} not found in available models: {[m['name'] for m in self.models]}")
+
+        try:
+            response = self.client.generate(
+                model=model_id,
+                prompt=prompt.strip(),
+                options={"temperature": 0}
+            )
+            return response['response']
+
+        except Exception as e:
+            log.error(f"Inference failed for model {model_id}: {str(e)}")
+            raise RuntimeError(f"Failed to get response from Ollama: {str(e)}")
diff --git a/tests/test_ollama.py b/tests/test_ollama.py
new file mode 100644
index 00000000..afc9c9d9
--- /dev/null
+++ b/tests/test_ollama.py
@@ -0,0 +1,129 @@
+import pytest
+import os
+import requests
+from unittest.mock import patch, MagicMock
+from src.llm.ollama_client import Ollama
+from src.config import Config
+
+def test_ollama_client_initialization():
+    """Test Ollama client initialization with default config"""
+    with patch('requests.get') as mock_get, \
+         patch('ollama.Client') as mock_client:
+        mock_get.return_value = MagicMock(status_code=200)
+        mock_client.return_value.list.return_value = {"models": []}
+
+        client = Ollama()
+        assert client.host == Config().get_ollama_api_endpoint()
+        assert client.client is not None
+        assert isinstance(client.models, list)
+
+def test_ollama_client_initialization_with_env():
+    """Test Ollama client initialization with environment variable"""
+    with patch('requests.get') as mock_get, \
+         patch('ollama.Client') as mock_client, \
+         patch.dict(os.environ, {'OLLAMA_HOST': 'http://ollama-service:11434'}):
+        mock_get.return_value = MagicMock(status_code=200)
+        mock_client.return_value.list.return_value = {"models": []}
+
+        client = Ollama()
+        assert client.host == "http://ollama-service:11434"
+        assert client.client is not None
+
+def test_ollama_client_connection_retry():
+    """Test Ollama client connection retry logic"""
+    with patch('requests.get') as mock_get, \
+         patch('ollama.Client') as mock_client, \
+         patch('time.sleep') as mock_sleep:
+        # Simulate first two failures, then success
+        mock_get.side_effect = [
+            requests.exceptions.ConnectionError(),
+            requests.exceptions.ConnectionError(),
+            MagicMock(status_code=200)
+        ]
+        mock_client.return_value.list.return_value = {"models": []}
+
+        client = Ollama()
+        assert client.client is not None
+        assert mock_get.call_count == 3
+        assert mock_sleep.call_count == 2
+
+def test_ollama_client_invalid_url():
+    """Test Ollama client with invalid URL"""
+    with patch.dict(os.environ, {'OLLAMA_HOST': 'invalid-url'}):
+        client = Ollama()
+        assert client.client is None
+        assert len(client.models) == 0
+
+def test_ollama_client_models_list():
+    """Test Ollama client models list retrieval"""
+    mock_models = {
+        "models": [
+            {"name": "llama2"},
+            {"name": "codellama"}
+        ]
+    }
+    with patch('requests.get') as mock_get, \
+         patch('ollama.Client') as mock_client:
+        mock_get.return_value = MagicMock(status_code=200)
+        mock_client.return_value.list.return_value = mock_models
+
+        client = Ollama()
+        assert len(client.models) == 2
+        assert client.models[0]["name"] == "llama2"
+        assert client.models[1]["name"] == "codellama"
+
+def test_ollama_client_inference():
+    """Test Ollama client inference"""
+    mock_models = {
+        "models": [
+            {"name": "llama2"}
+        ]
+    }
+    mock_response = {
+        "response": "Test response"
+    }
+    with patch('requests.get') as mock_get, \
+         patch('ollama.Client') as mock_client:
+        mock_get.return_value = MagicMock(status_code=200)
+        mock_client.return_value.list.return_value = mock_models
+        mock_client.return_value.generate.return_value = mock_response
+
+        client = Ollama()
+        response = client.inference("llama2", "Test prompt")
+        assert response == "Test response"
+        mock_client.return_value.generate.assert_called_once()
+
+def test_ollama_client_inference_invalid_model():
+    """Test Ollama client inference with invalid model"""
+    mock_models = {
+        "models": [
+            {"name": "llama2"}
+        ]
+    }
+    with patch('requests.get') as mock_get, \
+         patch('ollama.Client') as mock_client:
+        mock_get.return_value = MagicMock(status_code=200)
+        mock_client.return_value.list.return_value = mock_models
+
+        client = Ollama()
+        with pytest.raises(RuntimeError) as exc_info:
+            client.inference("invalid-model", "Test prompt")
+        assert "Model invalid-model not found" in str(exc_info.value)
+
+def test_ollama_client_inference_server_error():
+    """Test Ollama client inference with server error"""
+    mock_models = {
+        "models": [
+            {"name": "llama2"}
+        ]
+    }
+    with patch('requests.get') as mock_get, \
+         patch('ollama.Client') as mock_client:
+        mock_get.return_value = MagicMock(status_code=200)
+        mock_client.return_value.list.return_value = mock_models
+        mock_client.return_value.generate.side_effect = Exception("Server error")
+
+        client = Ollama()
+        with pytest.raises(RuntimeError) as exc_info:
+            client.inference("llama2", "Test prompt")
+        assert "Failed to get response from Ollama" in str(exc_info.value)
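
For reference, a minimal sketch of how the reworked client is meant to be driven. The calling script below is illustrative only; it assumes the repository root is on PYTHONPATH and that a local Ollama server has the "llama2" model pulled. Only Ollama, inference, and OLLAMA_HOST come from the patch itself.

# Illustrative caller (not part of the patch); assumes a local server with "llama2" pulled.
import os

# Optional override, read once in Ollama.__init__ via os.getenv("OLLAMA_HOST", ...).
os.environ["OLLAMA_HOST"] = "http://localhost:11434"

from src.llm.ollama_client import Ollama

client = Ollama()  # makes up to 3 connection attempts with exponential backoff
if client.client is not None:
    # inference() raises RuntimeError for unknown models or server failures.
    print(client.inference("llama2", "Say hello in one word."))
else:
    print("Ollama unavailable; fall back to API models.")

The new test suite runs with "pytest tests/test_ollama.py"; pytest and pytest-mock are now listed in requirements.txt.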