Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

音声データを Ogg ファイルで出力できる機能の追加 #199

Merged
merged 11 commits into from
Jan 8, 2025
9 changes: 9 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@

## develop

- [ADD] 受信した音声データを Ogg ファイルで保存するかを指定する enable_ogg_file_output を追加する
- 保存するファイル名は、sora-session-id ヘッダーと sora-connection-id ヘッダーの値を使用して作成する
- ${sora-session-id}-${sora-connection-id}.ogg
- デフォルト値: false
- @Hexa
- [ADD] 受信した音声データを Ogg ファイルで保存する場合の保存先ディレクトリを指定する ogg_dir を追加する
- デフォルト値: .
- @Hexa

### misc

- [CHANGE] GitHub Actions の ubuntu-latest を ubuntu-24.04 に変更する
Expand Down
7 changes: 5 additions & 2 deletions amazon_transcribe_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,13 @@ func (h *AmazonTranscribeHandler) ResetRetryCount() int {
return h.RetryCount
}

func (h *AmazonTranscribeHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) {
func (h *AmazonTranscribeHandler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) {
at := NewAmazonTranscribe(h.Config, h.LanguageCode, int64(h.SampleRate), int64(h.ChannelCount))

packetReader := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config)
packetReader, err := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config, header)
if err != nil {
return nil, err
}

stream, err := at.Start(ctx, packetReader)
if err != nil {
Expand Down
7 changes: 7 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ type Config struct {
SampleRate int `ini:"audio_sample_rate"`
ChannelCount int `ini:"audio_channel_count"`

EnableOggFileOutput bool `ini:"enable_ogg_file_output"`
OggDir string `ini:"ogg_dir"`

DumpFile string `ini:"dump_file"`

LogDir string `ini:"log_dir"`
Expand Down Expand Up @@ -173,6 +176,10 @@ func setDefaultsConfig(config *Config) {
if config.RetryIntervalMs == 0 {
config.RetryIntervalMs = DefaultRetryIntervalMs
}

if config.OggDir == "" {
config.OggDir = "."
}
}

func validateConfig(config *Config) error {
Expand Down
4 changes: 4 additions & 0 deletions config_example.ini
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ retry_interval_ms = 100
# aws の場合は IsPartial が false, gcp の場合は IsFinal が true の場合の最終的な結果のみを返す指定
final_result_only = true

# 受信した音声データを Ogg ファイルで保存するかどうかです
enable_ogg_file_output = false
# Ogg ファイルの保存先ディレクトリです
ogg_dir = "."

# 採用する結果の信頼スコアの最小値です(aws 指定時のみ有効)
# minimum_confidence_score が 0.0 の場合は信頼スコアによるフィルタリングは無効です
Expand Down
52 changes: 39 additions & 13 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"fmt"
"io"
"net/http"
"os"
"path"
"strings"
"time"

Expand Down Expand Up @@ -41,6 +43,16 @@ func NewSuzuErrorResponse(err error) TranscriptionResult {
}
}

type soraHeader struct {
SoraChannelID string `header:"sora-channel-id"`
SoraSessionID string `header:"sora-session-id"`
// SoraClientID string `header:"sora-client-id"`
SoraConnectionID string `header:"sora-connection-id"`
// SoraAudioCodecType string `header:"sora-audio-codec-type"`
// SoraAudioSampleRate int64 `header:"sora-audio-sample-rate"`
SoraAudioStreamingLanguageCode string `header:"sora-audio-streaming-language-code"`
}

func getServiceHandler(serviceType string, config Config, channelID, connectionID string, sampleRate uint32, channelCount uint16, languageCode string, onResultFunc any) (serviceHandlerInterface, error) {
newHandlerFunc, err := NewServiceHandlerFuncs.get(serviceType)
if err != nil {
Expand All @@ -65,15 +77,7 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte
return echo.NewHTTPError(http.StatusBadRequest)
}

h := struct {
SoraChannelID string `header:"Sora-Channel-Id"`
// SoraSessionID string `header:"sora-session-id"`
// SoraClientID string `header:"sora-client-id"`
SoraConnectionID string `header:"sora-connection-id"`
// SoraAudioCodecType string `header:"sora-audio-codec-type"`
// SoraAudioSampleRate int64 `header:"sora-audio-sample-rate"`
SoraAudioStreamingLanguageCode string `header:"sora-audio-streaming-language-code"`
}{}
h := soraHeader{}
if err := (&echo.DefaultBinder{}).BindHeaders(c, &h); err != nil {
zlog.Error().
Err(err).
Expand Down Expand Up @@ -153,7 +157,7 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte
serviceHandlerCtx, cancelServiceHandler := context.WithCancel(ctx)
defer cancelServiceHandler()

reader, err := serviceHandler.Handle(serviceHandlerCtx, opusCh)
reader, err := serviceHandler.Handle(serviceHandlerCtx, opusCh, h)
if err != nil {
zlog.Error().
Err(err).
Expand Down Expand Up @@ -459,17 +463,39 @@ func readOpus(ctx context.Context, reader io.Reader) chan opusChannel {
return opusCh
}

func opus2ogg(ctx context.Context, opusCh chan opusChannel, sampleRate uint32, channelCount uint16, c Config) io.ReadCloser {
func opus2ogg(ctx context.Context, opusCh chan opusChannel, sampleRate uint32, channelCount uint16, c Config, header soraHeader) (io.ReadCloser, error) {
oggReader, oggWriter := io.Pipe()

writers := []io.Writer{}

var f *os.File
if c.EnableOggFileOutput {
fileName := fmt.Sprintf("%s-%s.ogg", header.SoraSessionID, header.SoraConnectionID)
filePath := path.Join(c.OggDir, fileName)

var err error
f, err = os.Create(filePath)
if err != nil {
return nil, err
}
writers = append(writers, f)
}
writers = append(writers, oggWriter)

multiWriter := io.MultiWriter(writers...)

go func() {
o, err := NewWith(oggWriter, sampleRate, channelCount)
o, err := NewWith(multiWriter, sampleRate, channelCount)
if err != nil {
oggWriter.CloseWithError(err)
return
}
defer o.Close()

if c.EnableOggFileOutput {
o.fd = f
}

for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -501,7 +527,7 @@ func opus2ogg(ctx context.Context, opusCh chan opusChannel, sampleRate uint32, c
}
}()

return oggReader
return oggReader, nil
}

type opusRequest struct {
Expand Down
165 changes: 165 additions & 0 deletions handler_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package suzu

import (
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"testing"
"time"

Expand Down Expand Up @@ -310,3 +314,164 @@ func TestReadPacketWithHeader(t *testing.T) {
})
}
}

func TestOggFileWriting(t *testing.T) {
t.Run("success", func(t *testing.T) {
oggDir, err := os.MkdirTemp("", "ogg-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(oggDir)

c := Config{
EnableOggFileOutput: true,
OggDir: oggDir,
}

header := soraHeader{
SoraChannelID: "ogg-test",
SoraSessionID: "C2TFB1QBDS4WD5SX317SWMJ6FM",
SoraConnectionID: "1X0Z8JXZAD5A93X68M2S9NTC4G",
}

opusCh := make(chan opusChannel)
defer close(opusCh)

sampleRate := uint32(48000)
channelCount := uint16(1)

ctx := context.Background()
reader, err := opus2ogg(ctx, opusCh, sampleRate, channelCount, c, header)
if assert.NoError(t, err) {
assert.NotNil(t, reader)
}
defer reader.Close()

// ファイルへの書き込み待ち
time.Sleep(100 * time.Millisecond)

filename := fmt.Sprintf("%s-%s.ogg", header.SoraSessionID, header.SoraConnectionID)
filePath := filepath.Join(oggDir, filename)
_, err = os.Stat(filePath)
assert.NoError(t, err)

// Ogg ファイルのヘッダーを確認
f, err := os.Open(filePath)
if err != nil {
t.Fatal(err)
}
defer f.Close()

buf := make([]byte, 4)
n, err := f.Read(buf)
assert.NoError(t, err)
assert.Equal(t, []byte(`OggS`), buf[:n])
})

t.Run("disable_ogg_file_output", func(t *testing.T) {
oggDir, err := os.MkdirTemp("", "ogg-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(oggDir)

c := Config{
EnableOggFileOutput: false,
OggDir: oggDir,
}

header := soraHeader{
SoraChannelID: "ogg-test",
SoraSessionID: "C2TFB1QBDS4WD5SX317SWMJ6FM",
SoraConnectionID: "1X0Z8JXZAD5A93X68M2S9NTC4G",
}

opusCh := make(chan opusChannel)
defer close(opusCh)

sampleRate := uint32(48000)
channelCount := uint16(1)

ctx := context.Background()
reader, err := opus2ogg(ctx, opusCh, sampleRate, channelCount, c, header)
assert.NoError(t, err)
assert.NotNil(t, reader)
defer reader.Close()

filename := fmt.Sprintf("%s-%s.ogg", header.SoraSessionID, header.SoraConnectionID)
filePath := filepath.Join(oggDir, filename)
_, err = os.Stat(filePath)
assert.ErrorIs(t, err, os.ErrNotExist)
})

t.Run("no permission", func(t *testing.T) {
oggDir, err := os.MkdirTemp("", "ogg-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(oggDir)

// 書き込み権限を剥奪
if err := os.Chmod(oggDir, 0000); err != nil {
t.Fatal(err)
}
defer func() {
if err := os.Chmod(oggDir, 0700); err != nil {
t.Fatal(err)
}
}()

c := Config{
EnableOggFileOutput: true,
OggDir: oggDir,
}

header := soraHeader{
SoraChannelID: "ogg-test",
SoraSessionID: "C2TFB1QBDS4WD5SX317SWMJ6FM",
SoraConnectionID: "1X0Z8JXZAD5A93X68M2S9NTC4G",
}

opusCh := make(chan opusChannel)
defer close(opusCh)

sampleRate := uint32(48000)
channelCount := uint16(1)

ctx := context.Background()
reader, err := opus2ogg(ctx, opusCh, sampleRate, channelCount, c, header)
assert.ErrorIs(t, err, os.ErrPermission)
assert.Nil(t, reader)
})

t.Run("directory does not exist", func(t *testing.T) {
oggDir, err := os.MkdirTemp("", "ogg-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(oggDir)

c := Config{
EnableOggFileOutput: true,
// 既存のディレクトリ名に 0 を付与して存在しないディレクトリを指定する
OggDir: oggDir + "0",
}

header := soraHeader{
SoraChannelID: "ogg-test",
SoraSessionID: "C2TFB1QBDS4WD5SX317SWMJ6FM",
SoraConnectionID: "1X0Z8JXZAD5A93X68M2S9NTC4G",
}

opusCh := make(chan opusChannel)
defer close(opusCh)

sampleRate := uint32(48000)
channelCount := uint16(1)

ctx := context.Background()
reader, err := opus2ogg(ctx, opusCh, sampleRate, channelCount, c, header)
assert.ErrorIs(t, err, os.ErrNotExist)
assert.Nil(t, reader)
})
}
2 changes: 1 addition & 1 deletion packet_dump_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (h *PacketDumpHandler) ResetRetryCount() int {
return h.RetryCount
}

func (h *PacketDumpHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) {
func (h *PacketDumpHandler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) {
c := h.Config
filename := c.DumpFile
channelID := h.ChannelID
Expand Down
2 changes: 1 addition & 1 deletion service_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var (
)

type serviceHandlerInterface interface {
Handle(context.Context, chan opusChannel) (*io.PipeReader, error)
Handle(context.Context, chan opusChannel, soraHeader) (*io.PipeReader, error)
UpdateRetryCount() int
GetRetryCount() int
ResetRetryCount() int
Expand Down
7 changes: 5 additions & 2 deletions speech_to_text_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,13 @@ func (h *SpeechToTextHandler) ResetRetryCount() int {
return h.RetryCount
}

func (h *SpeechToTextHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) {
func (h *SpeechToTextHandler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) {
stt := NewSpeechToText(h.Config, h.LanguageCode, int32(h.SampleRate), int32(h.ChannelCount))

packetReader := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config)
packetReader, err := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config, header)
if err != nil {
return nil, err
}

stream, err := stt.Start(ctx, packetReader)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion test_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (h *TestHandler) ResetRetryCount() int {
return h.RetryCount
}

func (h *TestHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) {
func (h *TestHandler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) {
r, w := io.Pipe()

reader := opusChannelToIOReadCloser(ctx, opusCh)
Expand Down