Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(openai): implement voice mode #315

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions basilisk/conversation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import Literal
from typing import Any, Literal, Optional

from pydantic import BaseModel, Field

Expand All @@ -27,8 +27,8 @@ class TextMessageContent(BaseModel):

class Message(BaseModel):
role: MessageRoleEnum
content: list[TextMessageContent | ImageUrlMessageContent] | str = Field(
discrminator="type"
content: list[TextMessageContent | ImageUrlMessageContent] | str | Any = (
Field(discrminator="type")
)


Expand All @@ -39,6 +39,8 @@ class MessageBlock(BaseModel):
temperature: float = Field(default=1)
max_tokens: int = Field(default=4096)
top_p: float = Field(default=1)
modalities: Optional[list[str]] = Field(default=None)
audio: Optional[dict[str, str]] = Field(default=None)
stream: bool = Field(default=False)
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
Expand Down
57 changes: 35 additions & 22 deletions basilisk/gui/base_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def GetName(self, childId):
class BaseConversation:
def __init__(self):
self.accounts_engines: dict[UUID, BaseEngine] = {}
self.voice_mode = False

@property
def current_engine(self) -> Optional[BaseEngine]:
Expand All @@ -39,14 +40,14 @@ def current_engine(self) -> Optional[BaseEngine]:
return None
return self.accounts_engines[account.id]

def create_account_widget(self) -> wx.StaticText:
def create_account_widget(self, parent: wx.Window) -> wx.StaticText:
label = wx.StaticText(
self,
parent,
# Translators: This is a label for account in the main window
label=_("&Account:"),
)
self.account_combo = wx.ComboBox(
self, style=wx.CB_READONLY, choices=self.get_display_accounts()
parent, style=wx.CB_READONLY, choices=self.get_display_accounts()
)
self.account_combo.Bind(wx.EVT_COMBOBOX, self.on_account_change)
return label
Expand Down Expand Up @@ -95,22 +96,22 @@ def on_account_change(self, event) -> Optional[config.Account]:
self.update_model_list()
return account

def create_system_prompt_widget(self) -> wx.StaticText:
def create_system_prompt_widget(self, parent: wx.Window) -> wx.StaticText:
label = wx.StaticText(
self,
parent,
# Translators: This is a label for system prompt in the main window
label=_("S&ystem prompt:"),
)
self.system_prompt_txt = wx.TextCtrl(
self,
parent,
size=(800, 100),
style=wx.TE_MULTILINE | wx.TE_WORDWRAP | wx.HSCROLL,
)
return label

def create_model_widget(self) -> wx.StaticText:
label = wx.StaticText(self, label=_("M&odels:"))
self.model_list = wx.ListCtrl(self, style=wx.LC_REPORT)
def create_model_widget(self, parent: wx.Window) -> wx.StaticText:
label = wx.StaticText(parent, label=_("M&odels:"))
self.model_list = wx.ListCtrl(parent, style=wx.LC_REPORT)
# Translators: This label appears in the main window's list of models
self.model_list.InsertColumn(0, _("Name"))
# Translators: This label appears in the main window's list of models to indicate whether the model supports images
Expand All @@ -137,7 +138,11 @@ def get_display_models(self) -> list[tuple[str, str, str]]:
engine = self.current_engine
if not engine:
return []
return [m.display_model for m in engine.models]
return [
m.display_model
for m in engine.models
if m.voice_mode == self.voice_mode
]

def set_model_list(self, model: Optional[ProviderAIModel]):
engine = self.current_engine
Expand All @@ -163,7 +168,15 @@ def current_model(self) -> Optional[ProviderAIModel]:
model_index = self.model_list.GetFirstSelected()
if model_index == wx.NOT_FOUND:
return None
return engine.models[model_index]
filtered_models = [
model
for model in engine.models
if model.voice_mode == self.voice_mode
]
try:
return filtered_models[model_index]
except IndexError:
return None

def on_model_change(self, event):
model = self.current_model
Expand Down Expand Up @@ -235,24 +248,24 @@ def on_show_model_details(self, event: wx.CommandEvent):
dlg.ShowModal()
dlg.Destroy()

def create_max_tokens_widget(self) -> wx.StaticText:
def create_max_tokens_widget(self, parent: wx.Window) -> wx.StaticText:
self.max_tokens_spin_label = wx.StaticText(
self,
parent,
# Translators: This is a label for max tokens in the main window
label=_("Max to&kens:"),
)
self.max_tokens_spin_ctrl = wx.SpinCtrl(
self, value='0', min=0, max=2000000
parent, value='0', min=0, max=2000000
)

def create_temperature_widget(self) -> wx.StaticText:
def create_temperature_widget(self, parent: wx.Window) -> wx.StaticText:
self.temperature_spinner_label = wx.StaticText(
self,
parent,
# Translators: This is a label for temperature in the main window
label=_("&Temperature:"),
)
self.temperature_spinner = FloatSpin(
self,
parent,
min_val=0.0,
max_val=2.0,
increment=0.01,
Expand All @@ -266,14 +279,14 @@ def create_temperature_widget(self) -> wx.StaticText:
)
self.temperature_spinner._textctrl.SetAccessible(float_spin_accessible)

def create_top_p_widget(self) -> wx.StaticText:
def create_top_p_widget(self, parent: wx.Window) -> wx.StaticText:
self.top_p_spinner_label = wx.StaticText(
self,
parent,
# Translators: This is a label for top P in the main window
label=_("Probabilit&y Mass (top P):"),
)
self.top_p_spinner = FloatSpin(
self,
parent,
min_val=0.0,
max_val=1.0,
increment=0.01,
Expand All @@ -287,9 +300,9 @@ def create_top_p_widget(self) -> wx.StaticText:
)
self.top_p_spinner._textctrl.SetAccessible(float_spin_accessible)

def create_stream_widget(self):
def create_stream_widget(self, parent: wx.Window) -> wx.CheckBox:
self.stream_mode = wx.CheckBox(
self,
parent,
# Translators: This is a label for stream mode in the main window
label=_("&Stream mode"),
)
Expand Down
29 changes: 15 additions & 14 deletions basilisk/gui/conversation_profile_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
class EditConversationProfileDialog(wx.Dialog, BaseConversation):
def __init__(
self,
parent,
parent: wx.Window,
title: str,
size=(400, 400),
size: tuple[int, int] = (800, 600),
profile: Optional[ConversationProfile] = None,
):
wx.Dialog.__init__(self, parent=parent, title=title, size=size)
Expand All @@ -26,46 +26,47 @@ def __init__(
self.adjust_advanced_mode_setting()

def init_ui(self):
panel = wx.Panel(self)
self.sizer = wx.BoxSizer(wx.VERTICAL)

label = wx.StaticText(
self,
panel,
# translators: Label for the name of a conversation profile
label=_("profile &name:"),
)
self.sizer.Add(label, 0, wx.ALL, 5)

self.profile_name_txt = wx.TextCtrl(self)
self.profile_name_txt = wx.TextCtrl(panel)
self.sizer.Add(self.profile_name_txt, 0, wx.ALL | wx.EXPAND, 5)

label = self.create_account_widget()
label = self.create_account_widget(panel)
self.sizer.Add(label, 0, wx.ALL, 5)
self.sizer.Add(self.account_combo, 0, wx.ALL | wx.EXPAND, 5)
self.include_account_checkbox = wx.CheckBox(
self,
panel,
# translators: Label for including an account in a conversation profile
label=_("&Include account in profile"),
)
self.sizer.Add(self.include_account_checkbox, 0, wx.ALL, 5)
label = self.create_system_prompt_widget()
label = self.create_system_prompt_widget(panel)
self.sizer.Add(label, 0, wx.ALL, 5)
self.sizer.Add(self.system_prompt_txt, 0, wx.ALL | wx.EXPAND, 5)
label = self.create_model_widget()
label = self.create_model_widget(panel)
self.sizer.Add(label, 0, wx.ALL, 5)
self.sizer.Add(self.model_list, 0, wx.ALL | wx.EXPAND, 5)
self.create_max_tokens_widget()
self.create_max_tokens_widget(panel)
self.sizer.Add(self.max_tokens_spin_label, 0, wx.ALL, 5)
self.sizer.Add(self.max_tokens_spin_ctrl, 0, wx.ALL | wx.EXPAND, 5)
self.create_temperature_widget()
self.create_temperature_widget(panel)
self.sizer.Add(self.temperature_spinner_label, 0, wx.ALL, 5)
self.sizer.Add(self.temperature_spinner, 0, wx.ALL | wx.EXPAND, 5)
self.create_top_p_widget()
self.create_top_p_widget(panel)
self.sizer.Add(self.top_p_spinner_label, 0, wx.ALL, 5)
self.sizer.Add(self.top_p_spinner, 0, wx.ALL | wx.EXPAND, 5)
self.create_stream_widget()
self.create_stream_widget(panel)
self.sizer.Add(self.stream_mode, 0, wx.ALL | wx.EXPAND, 5)
self.ok_button = wx.Button(self, wx.ID_OK)
self.cancel_button = wx.Button(self, wx.ID_CANCEL)
self.ok_button = wx.Button(panel, wx.ID_OK)
self.cancel_button = wx.Button(panel, wx.ID_CANCEL)
self.Bind(wx.EVT_BUTTON, self.on_ok, self.ok_button)
self.Bind(wx.EVT_BUTTON, self.on_cancel, self.cancel_button)
self.SetDefaultItem(self.ok_button)
Expand Down
32 changes: 24 additions & 8 deletions basilisk/gui/conversation_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
wx.Panel.__init__(self, parent)
BaseConversation.__init__(self)
self.title = title
self.SetStatusText = parent.GetParent().GetParent().SetStatusText
self.SetStatusText = parent.GetTopLevelParent().SetStatusText
self.conversation = Conversation()
self.image_files = []
self.last_time = 0
Expand All @@ -94,11 +94,11 @@ def __init__(

def init_ui(self):
sizer = wx.BoxSizer(wx.VERTICAL)
label = self.create_account_widget()
label = self.create_account_widget(self)
sizer.Add(label, proportion=0, flag=wx.EXPAND)
sizer.Add(self.account_combo, proportion=0, flag=wx.EXPAND)

label = self.create_system_prompt_widget()
label = self.create_system_prompt_widget(self)
sizer.Add(label, proportion=0, flag=wx.EXPAND)
sizer.Add(self.system_prompt_txt, proportion=1, flag=wx.EXPAND)

Expand Down Expand Up @@ -157,19 +157,19 @@ def init_ui(self):
self.images_list.SetColumnWidth(2, 100)
self.images_list.SetColumnWidth(3, 200)
sizer.Add(self.images_list, proportion=0, flag=wx.ALL | wx.EXPAND)
label = self.create_model_widget()
label = self.create_model_widget(self)
sizer.Add(label, proportion=0, flag=wx.EXPAND)
sizer.Add(self.model_list, proportion=0, flag=wx.ALL | wx.EXPAND)
self.create_max_tokens_widget()
self.create_max_tokens_widget(self)
sizer.Add(self.max_tokens_spin_label, proportion=0, flag=wx.EXPAND)
sizer.Add(self.max_tokens_spin_ctrl, proportion=0, flag=wx.EXPAND)
self.create_temperature_widget()
self.create_temperature_widget(self)
sizer.Add(self.temperature_spinner_label, proportion=0, flag=wx.EXPAND)
sizer.Add(self.temperature_spinner, proportion=0, flag=wx.EXPAND)
self.create_top_p_widget()
self.create_top_p_widget(self)
sizer.Add(self.top_p_spinner_label, proportion=0, flag=wx.EXPAND)
sizer.Add(self.top_p_spinner, proportion=0, flag=wx.EXPAND)
self.create_stream_widget()
self.create_stream_widget(self)
sizer.Add(self.stream_mode, proportion=0, flag=wx.EXPAND)

btn_sizer = wx.BoxSizer(wx.HORIZONTAL)
Expand Down Expand Up @@ -983,6 +983,22 @@ def stop_recording(self):
self.toggle_record_btn.SetLabel(_("Record") + " (Ctrl+R)")
self.submit_btn.Enable()

def on_voice_mode(self, event: wx.CommandEvent = None):
cur_provider = self.current_engine
if ProviderCapability.VOICE not in cur_provider.capabilities:
wx.MessageBox(
_("The selected provider does not support voice mode"),
_("Error"),
wx.OK | wx.ICON_ERROR,
)
return
from .conversation_voice_mode_dialog import ConversationVoiceModeDialog

account = self.current_account
voice_mode_dialog = ConversationVoiceModeDialog(self, account)
voice_mode_dialog.ShowModal()
voice_mode_dialog.Destroy()

@ensure_no_task_running
def on_submit(self, event: wx.CommandEvent):
if not self.submit_btn.IsEnabled():
Expand Down
Loading