From 28273d4bd6309a3c4b54bbee086086e5c4c1c2ef Mon Sep 17 00:00:00 2001 From: Binozo <70137898+Binozo@users.noreply.github.com> Date: Fri, 20 Sep 2024 14:45:36 +0200 Subject: [PATCH] go : add temperature options (#2417) * Fixed go cuda bindings building * Added note to go bindings Readme to build using cuda support * Added temperature bindings for Go --------- Co-authored-by: Binozo --- bindings/go/params.go | 12 +++++++++++ bindings/go/pkg/whisper/context.go | 11 ++++++++++ bindings/go/pkg/whisper/interface.go | 30 +++++++++++++++------------- 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/bindings/go/params.go b/bindings/go/params.go index 9c075b6a2cb..95c5bfaf934 100644 --- a/bindings/go/params.go +++ b/bindings/go/params.go @@ -131,6 +131,16 @@ func (p *Params) SetEntropyThold(t float32) { p.entropy_thold = C.float(t) } +func (p *Params) SetTemperature(t float32) { + p.temperature = C.float(t) +} + +// Sets the fallback temperature incrementation +// Pass -1.0 to disable this feature +func (p *Params) SetTemperatureFallback(t float32) { + p.temperature_inc = C.float(t) +} + // Set initial prompt func (p *Params) SetInitialPrompt(prompt string) { p.initial_prompt = C.CString(prompt) @@ -162,6 +172,8 @@ func (p *Params) String() string { str += fmt.Sprintf(" audio_ctx=%d", p.audio_ctx) str += fmt.Sprintf(" initial_prompt=%s", C.GoString(p.initial_prompt)) str += fmt.Sprintf(" entropy_thold=%f", p.entropy_thold) + str += fmt.Sprintf(" temperature=%f", p.temperature) + str += fmt.Sprintf(" temperature_inc=%f", p.temperature_inc) str += fmt.Sprintf(" beam_size=%d", p.beam_search.beam_size) if p.translate { str += " translate" diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index dc34aa18bb8..06376b1b870 100644 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -140,6 +140,17 @@ func (context *context) SetEntropyThold(t float32) { context.params.SetEntropyThold(t) } +// Set Temperature +func (context *context) SetTemperature(t float32) { + context.params.SetTemperature(t) +} + +// Set the fallback temperature incrementation +// Pass -1.0 to disable this feature +func (context *context) SetTemperatureFallback(t float32) { + context.params.SetTemperatureFallback(t) +} + // Set initial prompt func (context *context) SetInitialPrompt(prompt string) { context.params.SetInitialPrompt(prompt) diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index 6eb692ef610..8981b1a8116 100644 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -38,20 +38,22 @@ type Context interface { IsMultilingual() bool // Return true if the model is multilingual. Language() string // Get language - SetOffset(time.Duration) // Set offset - SetDuration(time.Duration) // Set duration - SetThreads(uint) // Set number of threads to use - SetSplitOnWord(bool) // Set split on word flag - SetTokenThreshold(float32) // Set timestamp token probability threshold - SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold - SetMaxSegmentLength(uint) // Set max segment length in characters - SetTokenTimestamps(bool) // Set token timestamps flag - SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit) - SetAudioCtx(uint) // Set audio encoder context - SetMaxContext(n int) // Set maximum number of text context tokens to store - SetBeamSize(n int) // Set Beam Size - SetEntropyThold(t float32) // Set Entropy threshold - SetInitialPrompt(prompt string) // Set initial prompt + SetOffset(time.Duration) // Set offset + SetDuration(time.Duration) // Set duration + SetThreads(uint) // Set number of threads to use + SetSplitOnWord(bool) // Set split on word flag + SetTokenThreshold(float32) // Set timestamp token probability threshold + SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold + SetMaxSegmentLength(uint) // Set max segment length in characters + SetTokenTimestamps(bool) // Set token timestamps flag + SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit) + SetAudioCtx(uint) // Set audio encoder context + SetMaxContext(n int) // Set maximum number of text context tokens to store + SetBeamSize(n int) // Set Beam Size + SetEntropyThold(t float32) // Set Entropy threshold + SetInitialPrompt(prompt string) // Set initial prompt + SetTemperature(t float32) // Set temperature + SetTemperatureFallback(t float32) // Set temperature incrementation // Process mono audio data and return any errors. // If defined, newly generated segments are passed to the