Skip to content

Commit

Permalink
feat(gui): enhance temperature and top_p control precision (#112)
Browse files Browse the repository at this point in the history
* feat(gui): enhance temperature and top_p control precision

* fix: correct max_temperature handling as float across app

* fix: remove unused `SetName` method in `FloatSpinTextCtrlAccessible` class

* fix: standardize temperature settings across models

* fix: reorder import statements
  • Loading branch information
AAClause authored Jul 7, 2024
1 parent 0e6fad0 commit b9d0bb1
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 13 deletions.
54 changes: 44 additions & 10 deletions basilisk/gui/conversation_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from uuid import UUID

import wx
from wx.lib.agw.floatspin import FloatSpin

import basilisk.config as config
from basilisk import global_vars
Expand All @@ -31,6 +32,17 @@
log = logging.getLogger(__name__)


class FloatSpinTextCtrlAccessible(wx.Accessible):
	"""Accessible wrapper for a FloatSpin's inner text control.

	Screen readers query the control's name via ``GetName``; the stock
	FloatSpin text field has none, so this class supplies one explicitly.
	"""

	def __init__(self, win: wx.Window = None, name: str = None):
		# Keep the window binding behavior of wx.Accessible; remember the
		# label to report for assistive technology.
		super().__init__(win)
		self._name = name

	def GetName(self, childId):
		# Report the stored label when one was provided (non-empty);
		# otherwise defer to the default wx.Accessible lookup.
		if not self._name:
			return super().GetName(childId)
		return (wx.ACC_OK, self._name)


class ConversationTab(wx.Panel):
def __init__(self, parent: wx.Window):
wx.Panel.__init__(self, parent)
Expand Down Expand Up @@ -161,9 +173,20 @@ def init_ui(self):
label=_("&Temperature:"),
)
sizer.Add(self.temperature_label, proportion=0, flag=wx.EXPAND)
self.temperature_spinner = wx.SpinCtrl(
self, value="100", min=0, max=200
)
self.temperature_spinner = FloatSpin(
self,
min_val=0.0,
max_val=2.0,
increment=0.01,
value=0.5,
digits=2,
name="temperature",
)
float_spin_accessible = FloatSpinTextCtrlAccessible(
win=self.temperature_spinner._textctrl,
name=self.temperature_label.GetLabel().replace("&", ""),
)
self.temperature_spinner._textctrl.SetAccessible(float_spin_accessible)
sizer.Add(self.temperature_spinner, proportion=0, flag=wx.EXPAND)

self.top_p_label = wx.StaticText(
Expand All @@ -172,7 +195,20 @@ def init_ui(self):
label=_("Probabilit&y Mass (top P):"),
)
sizer.Add(self.top_p_label, proportion=0, flag=wx.EXPAND)
self.top_p_spinner = wx.SpinCtrl(self, value="100", min=0, max=100)
self.top_p_spinner = FloatSpin(
self,
min_val=0.0,
max_val=1.0,
increment=0.01,
value=1.0,
digits=2,
name="Top P",
)
float_spin_accessible = FloatSpinTextCtrlAccessible(
win=self.top_p_spinner._textctrl,
name=self.top_p_label.GetLabel().replace("&", ""),
)
self.top_p_spinner._textctrl.SetAccessible(float_spin_accessible)
sizer.Add(self.top_p_spinner, proportion=0, flag=wx.EXPAND)

self.stream_mode = wx.CheckBox(
Expand Down Expand Up @@ -424,10 +460,8 @@ def on_model_change(self, event: wx.CommandEvent):
if model_index == wx.NOT_FOUND:
return
model = self.current_engine.models[model_index]
self.temperature_spinner.SetMax(int(model.max_temperature * 100))
self.temperature_spinner.SetValue(
str(int(model.max_temperature / 2 * 100))
)
self.temperature_spinner.SetMax(model.max_temperature)
self.temperature_spinner.SetValue(model.default_temperature)
max_tokens = model.max_output_tokens
if max_tokens < 1:
max_tokens = model.context_window
Expand Down Expand Up @@ -760,8 +794,8 @@ def on_submit(self, event: wx.CommandEvent):
),
),
model=model,
temperature=self.temperature_spinner.GetValue() / 100,
top_p=self.top_p_spinner.GetValue() / 100,
temperature=self.temperature_spinner.GetValue(),
top_p=self.top_p_spinner.GetValue(),
max_tokens=self.max_tokens_spin_ctrl.GetValue(),
stream=self.stream_mode.GetValue(),
)
Expand Down
4 changes: 2 additions & 2 deletions basilisk/provider_ai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class ProviderAIModel:
description: str | None = field(default=None)
context_window: int = field(default=0)
max_output_tokens: int = field(default=-1)
max_temperature: float = field(default=2)
default_temperature: float = field(default=1)
max_temperature: float = field(default=1.0)
default_temperature: float = field(default=1.0)
vision: bool = field(default=False)
preview: bool = field(default=False)
extra_info: dict[str, Any] = field(default_factory=dict)
Expand Down
10 changes: 10 additions & 0 deletions basilisk/provider_engine/openai_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def models(self) -> list[ProviderAIModel]:
context_window=128000,
max_output_tokens=4096,
vision=True,
max_temperature=2.0,
),
ProviderAIModel(
id="gpt-4-turbo",
Expand All @@ -77,6 +78,7 @@ def models(self) -> list[ProviderAIModel]:
context_window=128000,
max_output_tokens=4096,
vision=True,
max_temperature=2.0,
),
ProviderAIModel(
id="gpt-3.5-turbo",
Expand All @@ -86,6 +88,7 @@ def models(self) -> list[ProviderAIModel]:
),
context_window=16385,
max_output_tokens=4096,
max_temperature=2.0,
),
ProviderAIModel(
id="gpt-3.5-turbo-0125",
Expand All @@ -95,6 +98,7 @@ def models(self) -> list[ProviderAIModel]:
),
context_window=16385,
max_output_tokens=4096,
max_temperature=2.0,
),
ProviderAIModel(
id="gpt-4-turbo-preview",
Expand All @@ -104,6 +108,7 @@ def models(self) -> list[ProviderAIModel]:
),
context_window=128000,
max_output_tokens=4096,
max_temperature=2.0,
),
ProviderAIModel(
id="gpt-4-0125-preview",
Expand All @@ -113,6 +118,7 @@ def models(self) -> list[ProviderAIModel]:
),
context_window=128000,
max_output_tokens=4096,
max_temperature=2.0,
),
ProviderAIModel(
id="gpt-4-1106-preview",
Expand All @@ -122,6 +128,7 @@ def models(self) -> list[ProviderAIModel]:
),
context_window=128000,
max_output_tokens=4096,
max_temperature=2.0,
),
ProviderAIModel(
id="gpt-4-vision-preview",
Expand All @@ -132,6 +139,7 @@ def models(self) -> list[ProviderAIModel]:
context_window=128000,
max_output_tokens=4096,
vision=True,
max_temperature=2.0,
),
ProviderAIModel(
id="gpt-4-0613",
Expand All @@ -140,6 +148,7 @@ def models(self) -> list[ProviderAIModel]:
"More capable than any GPT-3.5 ProviderAIModel, able to do more complex tasks, and optimized for chat"
),
max_output_tokens=8192,
max_temperature=2.0,
),
ProviderAIModel(
id="gpt-4-32k-0613",
Expand All @@ -149,6 +158,7 @@ def models(self) -> list[ProviderAIModel]:
),
context_window=32768,
max_output_tokens=8192,
max_temperature=2.0,
),
]

Expand Down
2 changes: 1 addition & 1 deletion basilisk/provider_engine/openrouter_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def models(self) -> list[ProviderAIModel]:
"max_completion_tokens"
)
or -1,
max_temperature=2,
max_temperature=2.0,
vision="#multimodal" in model['description'],
preview="-preview" in model['id'],
extra_info={
Expand Down

0 comments on commit b9d0bb1

Please sign in to comment.