Skip to content

Commit

Permalink
Add support for LLaVA
Browse files Browse the repository at this point in the history
  • Loading branch information
vietanhdev committed Sep 29, 2024
1 parent 53cffff commit 4657ca6
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,22 @@ This assistant can run offline on your local machine, and it respects your priva
- 📝 Text-only models:
- [Llama 3.2](https://github.com/facebookresearch/llama) - 1B, 3B (4/8-bit quantized)
- [Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GGUF) (4-bit quantized)
- And other models that [LlamaCPP](https://github.com/ggerganov/llama.cpp) supports via custom models. [See the list](https://github.com/ggerganov/llama.cpp).

- 🖼️ Multimodal models:
- [Moondream2](https://huggingface.co/vikhyatk/moondream2)
- [MiniCPM-v2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf)
- [LLaVA 1.5/1.6](https://llava-vl.github.io/)
- Besides supported models, you can try other variants via custom models.

## TODO

- [x] 🖼️ Support multimodal model: [moondream2](https://huggingface.co/vikhyatk/moondream2).
- [x] 🗣️ Add wake word detection: "Hey Llama!".
- [x] Custom models: Add support for custom models.
- [x] 📚 Support 5 other text models.
- [x] 🖼️ Support 5 other multimodal models.
- [ ] 🎙️ Add offline STT support: WhisperCPP. [Experimental Code](llama_assistant/speech_recognition_whisper_experimental.py).
- [ ] 📚 Support 5 other text models.
- [ ] 🖼️ Support 5 other multimodal models.
- [ ] 🧠 Knowledge database: Langchain or LlamaIndex?.
- [ ] 🔌 Plugin system for extensibility.
- [ ] 📰 News and weather updates.
Expand Down
16 changes: 16 additions & 0 deletions llama_assistant/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@
"repo_id": "vikhyatk/moondream2",
"filename": "*text-model*",
},
{
"model_name": "Llava-1.5",
"model_id": "mys/ggml_llava-v1.5-7b/q4_k",
"model_type": "image",
"model_path": None,
"repo_id": "mys/ggml_llava-v1.5-7b",
"filename": "*q4_k.gguf",
},
{
"model_name": "Llava-1.5",
"model_id": "mys/ggml_llava-v1.5-7b/f16",
"model_type": "image",
"model_path": None,
"repo_id": "mys/ggml_llava-v1.5-7b",
"filename": "*f16.gguf",
},
{
"model_name": "MiniCPM-V-2_6-gguf",
"model_id": "openbmb/MiniCPM-V-2_6-gguf-Q4_K_M",
Expand Down
33 changes: 30 additions & 3 deletions llama_assistant/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
import time
from threading import Timer
from llama_cpp import Llama
from llama_cpp.llama_chat_format import MoondreamChatHandler, MiniCPMv26ChatHandler
from llama_cpp.llama_chat_format import (
MoondreamChatHandler,
MiniCPMv26ChatHandler,
Llava15ChatHandler,
Llava16ChatHandler,
)

from llama_assistant import config

Expand Down Expand Up @@ -69,7 +74,7 @@ def load_model(self, model_id: str) -> Optional[Dict]:
n_ctx=2048,
)
elif model.model_type == "image":
if "moondream2" in model.repo_id:
if "moondream2" in model.model_id:
chat_handler = MoondreamChatHandler.from_pretrained(
repo_id="vikhyatk/moondream2",
filename="*mmproj*",
Expand All @@ -80,7 +85,7 @@ def load_model(self, model_id: str) -> Optional[Dict]:
chat_handler=chat_handler,
n_ctx=2048,
)
elif "MiniCPM" in model.repo_id:
elif "MiniCPM" in model.model_id:
chat_handler = MiniCPMv26ChatHandler.from_pretrained(
repo_id=model.repo_id,
filename="*mmproj*",
Expand All @@ -91,6 +96,28 @@ def load_model(self, model_id: str) -> Optional[Dict]:
chat_handler=chat_handler,
n_ctx=2048,
)
elif "llava-v1.5" in model.model_id:
chat_handler = Llava15ChatHandler.from_pretrained(
repo_id=model.repo_id,
filename="*mmproj*",
)
loaded_model = Llama.from_pretrained(
repo_id=model.repo_id,
filename=model.filename,
chat_handler=chat_handler,
n_ctx=2048,
)
elif "llava-v1.6" in model.model_id:
chat_handler = Llava16ChatHandler.from_pretrained(
repo_id=model.repo_id,
filename="*mmproj*",
)
loaded_model = Llama.from_pretrained(
repo_id=model.repo_id,
filename=model.filename,
chat_handler=chat_handler,
n_ctx=2048,
)
else:
print(f"Unsupported model type: {model.model_type}")
return None
Expand Down

0 comments on commit 4657ca6

Please sign in to comment.