diff --git a/lumigator/python/mzai/backend/backend/api/deps.py b/lumigator/python/mzai/backend/backend/api/deps.py index 9ecef9ccf..de2ade340 100644 --- a/lumigator/python/mzai/backend/backend/api/deps.py +++ b/lumigator/python/mzai/backend/backend/api/deps.py @@ -11,7 +11,7 @@ from backend.db import session_manager from backend.repositories.datasets import DatasetRepository from backend.repositories.jobs import JobRepository, JobResultRepository -from backend.services.completions import MistralCompletionService, OpenAICompletionService +from backend.services.completions import LiteLLMCompletionService, MistralCompletionService from backend.services.datasets import DatasetService from backend.services.jobs import JobService from backend.settings import settings @@ -68,10 +68,10 @@ def get_mistral_completion_service() -> MistralCompletionService: ] -def get_openai_completion_service() -> OpenAICompletionService: - return OpenAICompletionService() +def get_openai_completion_service() -> LiteLLMCompletionService: + return LiteLLMCompletionService() -OpenAICompletionServiceDep = Annotated[ - OpenAICompletionService, Depends(get_mistral_completion_service) +LiteLLMCompletionServiceDep = Annotated[ + LiteLLMCompletionService, Depends(get_openai_completion_service) ] diff --git a/lumigator/python/mzai/backend/backend/api/routes/completions.py b/lumigator/python/mzai/backend/backend/api/routes/completions.py index dc8375dd0..ac6a1f54e 100644 --- a/lumigator/python/mzai/backend/backend/api/routes/completions.py +++ b/lumigator/python/mzai/backend/backend/api/routes/completions.py @@ -1,7 +1,7 @@ from fastapi import APIRouter from lumigator_schemas.completions import CompletionRequest -from backend.api.deps import MistralCompletionServiceDep, OpenAICompletionServiceDep +from backend.api.deps import LiteLLMCompletionServiceDep, MistralCompletionServiceDep router = APIRouter() @@ -20,5 +20,5 @@ def get_mistral_completion(request: CompletionRequest, service: MistralCompletio 
@router.post(f"/{VENDOR_OPENAI}") -def get_openai_completion(request: CompletionRequest, service: OpenAICompletionServiceDep): +def get_openai_completion(request: CompletionRequest, service: LiteLLMCompletionServiceDep): return service.get_completions_response(request) diff --git a/lumigator/python/mzai/backend/backend/services/completions.py b/lumigator/python/mzai/backend/backend/services/completions.py index bc62dd8e6..5b15a6e10 100644 --- a/lumigator/python/mzai/backend/backend/services/completions.py +++ b/lumigator/python/mzai/backend/backend/services/completions.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod +import litellm import mistralai.client from lumigator_schemas.completions import CompletionRequest, CompletionResponse from mistralai.client import MistralClient from mistralai.models.chat_completion import ChatMessage -from openai import OpenAI from backend.settings import settings @@ -44,21 +44,20 @@ def get_completions_response(self, request: CompletionRequest) -> CompletionResp return CompletionResponse(text=response) -class OpenAICompletionService(CompletionService): +class LiteLLMCompletionService(CompletionService): def __init__(self): - self.client = OpenAI(api_key=settings.OAI_API_KEY) self.model = "gpt-4o-mini" self.max_tokens = 256 self.temperature = 1 self.top_p = 1 - def get_models(self) -> mistralai.client.ModelList: - response = self.client.list_models() + def get_models(self) -> list[str]: + response = litellm.models_by_provider.get("openai", []) return response def get_completions_response(self, request: CompletionRequest) -> CompletionResponse: - response = self.client.chat.completions.create( + response = litellm.completion( model=self.model, messages=[ {"role": "system", "content": self.prompt},