Skip to content

Commit

Permalink
Implementing Encoder Begin Callback for golang binding
Browse files Browse the repository at this point in the history
Adding in EncoderBeginCallback to the Context's Process callback.
This optional callback function returns false if computation should
be aborted.
  • Loading branch information
Amanda Der Bedrosian authored and aderbedr committed Jun 27, 2024
1 parent c118733 commit 4c8651a
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 20 deletions.
2 changes: 1 addition & 1 deletion bindings/go/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func main() {
if err != nil {
panic(err)
}
if err := context.Process(samples, nil, nil); err != nil {
if err := context.Process(samples, nil, nil, nil); err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion bindings/go/examples/go-whisper/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
context.ResetTimings()
if err := context.Process(data, cb, nil); err != nil {
if err := context.Process(data, nil, cb, nil); err != nil {
return err
}

Expand Down
35 changes: 19 additions & 16 deletions bindings/go/pkg/whisper/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
// Process new sample data and return any errors
func (context *context) Process(
data []float32,
callEncoderBegin EncoderBeginCallback,
callNewSegment SegmentCallback,
callProgress ProgressCallback,
) error {
Expand All @@ -177,30 +178,32 @@ func (context *context) Process(
// We don't do parallel processing at the moment
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin,
func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin,
func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}); err != nil {
}); err != nil {
return err
}

Expand Down
8 changes: 6 additions & 2 deletions bindings/go/pkg/whisper/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ type SegmentCallback func(Segment)
// processing. It is called during the Process function
type ProgressCallback func(int)

// EncoderBeginCallback is the callback function for checking if we want to
// continue processing. It is called during the Process function
type EncoderBeginCallback func() bool

// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
type Model interface {
Expand All @@ -31,7 +35,7 @@ type Model interface {
Languages() []string
}

// Context is the speach recognition context.
// Context is the speech recognition context.
type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
SetTranslate(bool) // Set translate flag
Expand All @@ -53,7 +57,7 @@ type Context interface {
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
// callback function during processing.
Process([]float32, SegmentCallback, ProgressCallback) error
Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error

// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.
Expand Down

0 comments on commit 4c8651a

Please sign in to comment.