Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added support for custom openai api #43

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docker-compose.dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ services:
- BING_API_KEY=${BING_API_KEY}
- SERPER_API_KEY=${SERPER_API_KEY}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- CUSTOM_HOST=${CUSTOM_HOST:-http://host.docker.internal:1234/v1}
- CUSTOM_API_KEY=${CUSTOM_API_KEY}
- GROQ_API_KEY=${GROQ_API_KEY}
- ENABLE_LOCAL_MODELS=${ENABLE_LOCAL_MODELS:-True}
- SEARCH_PROVIDER=${SEARCH_PROVIDER:-tavily}
Expand Down
5 changes: 5 additions & 0 deletions src/backend/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def get_llm(model: ChatModel) -> LLM:
)
elif model == ChatModel.LLAMA_3_70B:
return Groq(model=model_mappings[model])
elif model == ChatModel.CUSTOM:
return OpenAI(model=model_mappings[model], # alternatively, OpenAILike could be used here: https://docs.llamaindex.ai/en/stable/api_reference/llms/openai_like/
api_base=os.environ.get("CUSTOM_HOST", "http://localhost:1234/v1"),
api_key=os.environ.get("CUSTOM_API_KEY", "custom-key")
)
else:
raise ValueError(f"Unknown model: {model}")

Expand Down
3 changes: 3 additions & 0 deletions src/backend/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
LOCAL_GEMMA_MODEL = "gemma:7b"
LOCAL_MISTRAL_MODEL = "mistral"
LOCAL_PHI3_14B = "phi3:14b"
CUSTOM_MODEL = "gpt-4" # llama_index requires a known model context length; using "gpt-4" yields an 8k-token default.


class ChatModel(str, Enum):
Expand All @@ -21,6 +22,7 @@ class ChatModel(str, Enum):
LOCAL_GEMMA = "gemma"
LOCAL_MISTRAL = "mistral"
LOCAL_PHI3_14B = "phi3:14b"
CUSTOM = "custom-api"


model_mappings: dict[ChatModel, str] = {
Expand All @@ -31,4 +33,5 @@ class ChatModel(str, Enum):
ChatModel.LOCAL_GEMMA: LOCAL_GEMMA_MODEL,
ChatModel.LOCAL_MISTRAL: LOCAL_MISTRAL_MODEL,
ChatModel.LOCAL_PHI3_14B: LOCAL_PHI3_14B,
ChatModel.CUSTOM: CUSTOM_MODEL,
}
8 changes: 8 additions & 0 deletions src/backend/related_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ def instructor_client(model: ChatModel) -> instructor.AsyncInstructor:
)
elif model == ChatModel.LLAMA_3_70B:
return instructor.from_groq(groq.AsyncGroq(), mode=instructor.Mode.JSON)
elif model == ChatModel.CUSTOM:
return instructor.from_openai(
openai.AsyncOpenAI(
base_url=os.environ.get("CUSTOM_HOST", "http://localhost:1234/v1"),
api_key=os.environ.get("CUSTOM_API_KEY", "custom-key")
),
mode=instructor.Mode.JSON,
)
else:
raise ValueError(f"Unknown model: {model}")

Expand Down
4 changes: 4 additions & 0 deletions src/backend/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def validate_model(model: ChatModel):
LOCAL_MODELS_ENABLED = strtobool(os.getenv("ENABLE_LOCAL_MODELS", False))
if not LOCAL_MODELS_ENABLED:
raise ValueError("Local models are not enabled")
elif model == ChatModel.CUSTOM:
CUSTOM_HOST = os.getenv("CUSTOM_HOST")
if not CUSTOM_HOST:
raise ValueError("CUSTOM_HOST environment variable not found")
else:
raise ValueError("Invalid model")
return True
1 change: 1 addition & 0 deletions src/frontend/generated/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export enum ChatModel {
GEMMA = "gemma",
MISTRAL = "mistral",
LOCAL_PHI3_14B = "phi3:14b",
CUSTOM = "custom-api",
}

export type ChatRequest = {
Expand Down
7 changes: 7 additions & 0 deletions src/frontend/src/components/model-selection.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ const modelMap: Record<ChatModel, Model> = {
smallIcon: <FlameIcon className="w-4 h-4 text-green-500" />,
icon: <FlameIcon className="w-5 h-5 text-green-500" />,
},
[ChatModel.CUSTOM]: {
name: "Custom API",
description: "Custom API",
value: ChatModel.CUSTOM,
smallIcon: <MagicWandIcon className="w-4 h-4 text-[#FF0080]" />,
icon: <MagicWandIcon className="w-5 h-5 text-[#FF0080]" />,
},
};

const localModelMap: Partial<Record<ChatModel, Model>> = _.pickBy(
Expand Down