diff --git a/CHANGES.md b/CHANGES.md index f916279..dd02009 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,6 +11,15 @@ ## develop +- [ADD] 受信した音声データを Ogg ファイルで保存するかを指定する enable_ogg_file_output を追加する + - 保存するファイル名は、sora-session-id ヘッダーと sora-connection-id ヘッダーの値を使用して作成する + - ${sora-session-id}-${sora-connection-id}.ogg + - デフォルト値: false + - @Hexa +- [ADD] 受信した音声データを Ogg ファイルで保存する場合の保存先ディレクトリを指定する ogg_dir を追加する + - デフォルト値: . + - @Hexa + ### misc - [CHANGE] GitHub Actions の ubuntu-latest を ubuntu-24.04 に変更する diff --git a/amazon_transcribe_handler.go b/amazon_transcribe_handler.go index adc73cc..737999f 100644 --- a/amazon_transcribe_handler.go +++ b/amazon_transcribe_handler.go @@ -96,10 +96,13 @@ func (h *AmazonTranscribeHandler) ResetRetryCount() int { return h.RetryCount } -func (h *AmazonTranscribeHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) { +func (h *AmazonTranscribeHandler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) { at := NewAmazonTranscribe(h.Config, h.LanguageCode, int64(h.SampleRate), int64(h.ChannelCount)) - packetReader := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config) + packetReader, err := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config, header) + if err != nil { + return nil, err + } stream, err := at.Start(ctx, packetReader) if err != nil { diff --git a/config.go b/config.go index aeba0ad..cd138e8 100644 --- a/config.go +++ b/config.go @@ -65,6 +65,9 @@ type Config struct { SampleRate int `ini:"audio_sample_rate"` ChannelCount int `ini:"audio_channel_count"` + EnableOggFileOutput bool `ini:"enable_ogg_file_output"` + OggDir string `ini:"ogg_dir"` + DumpFile string `ini:"dump_file"` LogDir string `ini:"log_dir"` @@ -173,6 +176,10 @@ func setDefaultsConfig(config *Config) { if config.RetryIntervalMs == 0 { config.RetryIntervalMs = DefaultRetryIntervalMs } + + if config.OggDir == "" { + config.OggDir = "." + } } func validateConfig(config *Config) error { diff --git a/config_example.ini b/config_example.ini index b9578a3..6493c2f 100644 --- a/config_example.ini +++ b/config_example.ini @@ -55,6 +55,10 @@ retry_interval_ms = 100 # aws の場合は IsPartial が false, gcp の場合は IsFinal が true の場合の最終的な結果のみを返す指定 final_result_only = true +# 受信した音声データを Ogg ファイルで保存するかどうかです +enable_ogg_file_output = false +# Ogg ファイルの保存先ディレクトリです +ogg_dir = "." # 採用する結果の信頼スコアの最小値です(aws 指定時のみ有効) # minimum_confidence_score が 0.0 の場合は信頼スコアによるフィルタリングは無効です diff --git a/handler.go b/handler.go index 9586524..5833a3e 100644 --- a/handler.go +++ b/handler.go @@ -8,6 +8,8 @@ import ( "fmt" "io" "net/http" + "os" + "path" "strings" "time" @@ -41,6 +43,16 @@ func NewSuzuErrorResponse(err error) TranscriptionResult { } } +type soraHeader struct { + SoraChannelID string `header:"sora-channel-id"` + SoraSessionID string `header:"sora-session-id"` + // SoraClientID string `header:"sora-client-id"` + SoraConnectionID string `header:"sora-connection-id"` + // SoraAudioCodecType string `header:"sora-audio-codec-type"` + // SoraAudioSampleRate int64 `header:"sora-audio-sample-rate"` + SoraAudioStreamingLanguageCode string `header:"sora-audio-streaming-language-code"` +} + func getServiceHandler(serviceType string, config Config, channelID, connectionID string, sampleRate uint32, channelCount uint16, languageCode string, onResultFunc any) (serviceHandlerInterface, error) { newHandlerFunc, err := NewServiceHandlerFuncs.get(serviceType) if err != nil { @@ -65,15 +77,7 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte return echo.NewHTTPError(http.StatusBadRequest) } - h := struct { - SoraChannelID string `header:"Sora-Channel-Id"` - // SoraSessionID string `header:"sora-session-id"` - // SoraClientID string `header:"sora-client-id"` - SoraConnectionID string `header:"sora-connection-id"` - // SoraAudioCodecType string `header:"sora-audio-codec-type"` - // SoraAudioSampleRate int64 `header:"sora-audio-sample-rate"` - SoraAudioStreamingLanguageCode string `header:"sora-audio-streaming-language-code"` - }{} + h := soraHeader{} if err := (&echo.DefaultBinder{}).BindHeaders(c, &h); err != nil { zlog.Error(). Err(err). @@ -153,7 +157,7 @@ func (s *Server) createSpeechHandler(serviceType string, onResultFunc func(conte serviceHandlerCtx, cancelServiceHandler := context.WithCancel(ctx) defer cancelServiceHandler() - reader, err := serviceHandler.Handle(serviceHandlerCtx, opusCh) + reader, err := serviceHandler.Handle(serviceHandlerCtx, opusCh, h) if err != nil { zlog.Error(). Err(err). @@ -459,17 +463,39 @@ func readOpus(ctx context.Context, reader io.Reader) chan opusChannel { return opusCh } -func opus2ogg(ctx context.Context, opusCh chan opusChannel, sampleRate uint32, channelCount uint16, c Config) io.ReadCloser { +func opus2ogg(ctx context.Context, opusCh chan opusChannel, sampleRate uint32, channelCount uint16, c Config, header soraHeader) (io.ReadCloser, error) { oggReader, oggWriter := io.Pipe() + writers := []io.Writer{} + + var f *os.File + if c.EnableOggFileOutput { + fileName := fmt.Sprintf("%s-%s.ogg", header.SoraSessionID, header.SoraConnectionID) + filePath := path.Join(c.OggDir, fileName) + + var err error + f, err = os.Create(filePath) + if err != nil { + return nil, err + } + writers = append(writers, f) + } + writers = append(writers, oggWriter) + + multiWriter := io.MultiWriter(writers...) + go func() { - o, err := NewWith(oggWriter, sampleRate, channelCount) + o, err := NewWith(multiWriter, sampleRate, channelCount) if err != nil { oggWriter.CloseWithError(err) return } defer o.Close() + if c.EnableOggFileOutput { + o.fd = f + } + for { select { case <-ctx.Done(): @@ -501,7 +527,7 @@ func opus2ogg(ctx context.Context, opusCh chan opusChannel, sampleRate uint32, c } }() - return oggReader + return oggReader, nil } type opusRequest struct { diff --git a/handler_test.go b/handler_test.go index ce3d365..f9ae64d 100644 --- a/handler_test.go +++ b/handler_test.go @@ -1,8 +1,12 @@ package suzu import ( + "context" "errors" + "fmt" "io" + "os" + "path/filepath" "testing" "time" @@ -310,3 +314,164 @@ func TestReadPacketWithHeader(t *testing.T) { }) } } + +func TestOggFileWriting(t *testing.T) { + t.Run("success", func(t *testing.T) { + oggDir, err := os.MkdirTemp("", "ogg-") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(oggDir) + + c := Config{ + EnableOggFileOutput: true, + OggDir: oggDir, + } + + header := soraHeader{ + SoraChannelID: "ogg-test", + SoraSessionID: "C2TFB1QBDS4WD5SX317SWMJ6FM", + SoraConnectionID: "1X0Z8JXZAD5A93X68M2S9NTC4G", + } + + opusCh := make(chan opusChannel) + defer close(opusCh) + + sampleRate := uint32(48000) + channelCount := uint16(1) + + ctx := context.Background() + reader, err := opus2ogg(ctx, opusCh, sampleRate, channelCount, c, header) + if assert.NoError(t, err) { + assert.NotNil(t, reader) + } + defer reader.Close() + + // ファイルへの書き込み待ち + time.Sleep(100 * time.Millisecond) + + filename := fmt.Sprintf("%s-%s.ogg", header.SoraSessionID, header.SoraConnectionID) + filePath := filepath.Join(oggDir, filename) + _, err = os.Stat(filePath) + assert.NoError(t, err) + + // Ogg ファイルのヘッダーを確認 + f, err := os.Open(filePath) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + buf := make([]byte, 4) + n, err := f.Read(buf) + assert.NoError(t, err) + assert.Equal(t, []byte(`OggS`), buf[:n]) + }) + + t.Run("disable_ogg_file_output", func(t *testing.T) { + oggDir, err := os.MkdirTemp("", "ogg-") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(oggDir) + + c := Config{ + EnableOggFileOutput: false, + OggDir: oggDir, + } + + header := soraHeader{ + SoraChannelID: "ogg-test", + SoraSessionID: "C2TFB1QBDS4WD5SX317SWMJ6FM", + SoraConnectionID: "1X0Z8JXZAD5A93X68M2S9NTC4G", + } + + opusCh := make(chan opusChannel) + defer close(opusCh) + + sampleRate := uint32(48000) + channelCount := uint16(1) + + ctx := context.Background() + reader, err := opus2ogg(ctx, opusCh, sampleRate, channelCount, c, header) + assert.NoError(t, err) + assert.NotNil(t, reader) + defer reader.Close() + + filename := fmt.Sprintf("%s-%s.ogg", header.SoraSessionID, header.SoraConnectionID) + filePath := filepath.Join(oggDir, filename) + _, err = os.Stat(filePath) + assert.ErrorIs(t, err, os.ErrNotExist) + }) + + t.Run("no permission", func(t *testing.T) { + oggDir, err := os.MkdirTemp("", "ogg-") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(oggDir) + + // 書き込み権限を剥奪 + if err := os.Chmod(oggDir, 0000); err != nil { + t.Fatal(err) + } + defer func() { + if err := os.Chmod(oggDir, 0700); err != nil { + t.Fatal(err) + } + }() + + c := Config{ + EnableOggFileOutput: true, + OggDir: oggDir, + } + + header := soraHeader{ + SoraChannelID: "ogg-test", + SoraSessionID: "C2TFB1QBDS4WD5SX317SWMJ6FM", + SoraConnectionID: "1X0Z8JXZAD5A93X68M2S9NTC4G", + } + + opusCh := make(chan opusChannel) + defer close(opusCh) + + sampleRate := uint32(48000) + channelCount := uint16(1) + + ctx := context.Background() + reader, err := opus2ogg(ctx, opusCh, sampleRate, channelCount, c, header) + assert.ErrorIs(t, err, os.ErrPermission) + assert.Nil(t, reader) + }) + + t.Run("directory does not exist", func(t *testing.T) { + oggDir, err := os.MkdirTemp("", "ogg-") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(oggDir) + + c := Config{ + EnableOggFileOutput: true, + // 既存のディレクトリ名に 0 を付与して存在しないディレクトリを指定する + OggDir: oggDir + "0", + } + + header := soraHeader{ + SoraChannelID: "ogg-test", + SoraSessionID: "C2TFB1QBDS4WD5SX317SWMJ6FM", + SoraConnectionID: "1X0Z8JXZAD5A93X68M2S9NTC4G", + } + + opusCh := make(chan opusChannel) + defer close(opusCh) + + sampleRate := uint32(48000) + channelCount := uint16(1) + + ctx := context.Background() + reader, err := opus2ogg(ctx, opusCh, sampleRate, channelCount, c, header) + assert.ErrorIs(t, err, os.ErrNotExist) + assert.Nil(t, reader) + }) +} diff --git a/packet_dump_handler.go b/packet_dump_handler.go index e35e855..565bc20 100644 --- a/packet_dump_handler.go +++ b/packet_dump_handler.go @@ -67,7 +67,7 @@ func (h *PacketDumpHandler) ResetRetryCount() int { return h.RetryCount } -func (h *PacketDumpHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) { +func (h *PacketDumpHandler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) { c := h.Config filename := c.DumpFile channelID := h.ChannelID diff --git a/service_handler.go b/service_handler.go index 05a3cf2..5b61ec3 100644 --- a/service_handler.go +++ b/service_handler.go @@ -15,7 +15,7 @@ var ( ) type serviceHandlerInterface interface { - Handle(context.Context, chan opusChannel) (*io.PipeReader, error) + Handle(context.Context, chan opusChannel, soraHeader) (*io.PipeReader, error) UpdateRetryCount() int GetRetryCount() int ResetRetryCount() int diff --git a/speech_to_text_handler.go b/speech_to_text_handler.go index 7d0649f..31d8e27 100644 --- a/speech_to_text_handler.go +++ b/speech_to_text_handler.go @@ -91,10 +91,13 @@ func (h *SpeechToTextHandler) ResetRetryCount() int { return h.RetryCount } -func (h *SpeechToTextHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) { +func (h *SpeechToTextHandler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) { stt := NewSpeechToText(h.Config, h.LanguageCode, int32(h.SampleRate), int32(h.ChannelCount)) - packetReader := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config) + packetReader, err := opus2ogg(ctx, opusCh, h.SampleRate, h.ChannelCount, h.Config, header) + if err != nil { + return nil, err + } stream, err := stt.Start(ctx, packetReader) if err != nil { diff --git a/test_handler.go b/test_handler.go index 005ea90..287eb60 100644 --- a/test_handler.go +++ b/test_handler.go @@ -73,7 +73,7 @@ func (h *TestHandler) ResetRetryCount() int { return h.RetryCount } -func (h *TestHandler) Handle(ctx context.Context, opusCh chan opusChannel) (*io.PipeReader, error) { +func (h *TestHandler) Handle(ctx context.Context, opusCh chan opusChannel, header soraHeader) (*io.PipeReader, error) { r, w := io.Pipe() reader := opusChannelToIOReadCloser(ctx, opusCh)