From ce50006ecaa0ca0d257f0d2841dba8e0aa5ef679 Mon Sep 17 00:00:00 2001 From: Miroslav Batchkarov Date: Tue, 30 Jan 2024 10:15:57 +0100 Subject: [PATCH 1/4] add headers for multiple language identification --- amazon_transcribe/client.py | 16 +++++++++++++++- amazon_transcribe/model.py | 16 +++++++++++++++- amazon_transcribe/serialize.py | 22 +++++++++++++++++++++- tests/functional/test_serialize.py | 28 ++++++++++++++++++++++++++-- tests/integration/test_client.py | 15 ++++++++------- 5 files changed, 85 insertions(+), 12 deletions(-) diff --git a/amazon_transcribe/client.py b/amazon_transcribe/client.py index 5229e74..153395e 100644 --- a/amazon_transcribe/client.py +++ b/amazon_transcribe/client.py @@ -84,6 +84,8 @@ async def start_stream_transcription( enable_partial_results_stabilization: Optional[bool] = None, partial_results_stability: Optional[str] = None, language_model_name: Optional[str] = None, + identify_multiple_languages=False, + language_options=None ) -> StartStreamTranscriptionEventStream: """Coordinate transcription settings and start stream. @@ -100,7 +102,8 @@ async def start_stream_transcription( than 5 minutes. :param language_code: - Indicates the source language used in the input audio stream. + Indicates the source language used in the input audio stream. Set to + None if identify_multiple_languages is set to True :param media_sample_rate_hz: The sample rate, in Hertz, of the input audio. We suggest that you use 8000 Hz for low quality audio and 16000 Hz for high quality audio. @@ -144,6 +147,15 @@ async def start_stream_transcription( overall transcription accuracy. Defaults to "high" if not set explicitly. :param language_model_name: The name of the language model you want to use. + :param identify_multiple_languages: + If true, all languages spoken in the stream are identified. A multilingual + transcripts is created your transcript using each identified language. 
+ You must also provide at least two language_options and set + language_code to None + :param language_options: + A list of possible languages to use when identify_multiple_languages is + set to True. Note that not all languages supported by Transcribe are + supported for multiple language identification """ transcribe_streaming_request = StartStreamTranscriptionRequest( language_code, @@ -159,6 +171,8 @@ async def start_stream_transcription( enable_partial_results_stabilization, partial_results_stability, language_model_name, + identify_multiple_languages, + language_options ) endpoint = await self._endpoint_resolver.resolve(self.region) diff --git a/amazon_transcribe/model.py b/amazon_transcribe/model.py index 38e241e..2eb7379 100644 --- a/amazon_transcribe/model.py +++ b/amazon_transcribe/model.py @@ -171,7 +171,8 @@ class StartStreamTranscriptionRequest: """Transcription Request :param language_code: - Indicates the source language used in the input audio stream. + Indicates the source language used in the input audio stream. Set to + None if identify_multiple_languages is set to True :param media_sample_rate_hz: The sample rate, in Hertz, of the input audio. We suggest that you @@ -226,6 +227,15 @@ class StartStreamTranscriptionRequest: overall transcription accuracy. :param language_model_name: The name of the language model you want to use. + :param identify_multiple_languages: + If true, all languages spoken in the stream are identified. A multilingual + transcripts is created your transcript using each identified language. + You must also provided at least two language_options and set + language_code to None + : param language_options: + A list of possible language to use when identify_multiple_languages is + set to True. 
Note that not all languages supported by Transcribe are + supported for multiple language identification """ def __init__( @@ -243,6 +253,8 @@ def __init__( enable_partial_results_stabilization=None, partial_results_stability=None, language_model_name=None, + identify_multiple_languages=False, + language_options=None ): self.language_code: Optional[str] = language_code @@ -262,6 +274,8 @@ def __init__( ] = enable_partial_results_stabilization self.partial_results_stability: Optional[str] = partial_results_stability self.language_model_name: Optional[str] = language_model_name + self.identify_multiple_languages: Optional[bool] = identify_multiple_languages + self.language_options: Optional[List[str]] = language_options or [] class StartStreamTranscriptionResponse: diff --git a/amazon_transcribe/serialize.py b/amazon_transcribe/serialize.py index 88056ed..a6d2368 100644 --- a/amazon_transcribe/serialize.py +++ b/amazon_transcribe/serialize.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. 
-from typing import Any, Dict, Tuple, Optional +from typing import Any, Dict, Tuple, Optional, List from amazon_transcribe.request import Request from amazon_transcribe.structures import BufferableByteStream @@ -56,6 +56,12 @@ def _serialize_bool_header( ) -> Dict[str, str]: return self._serialize_header(header, value) + def _serialize_list_header( + self, header: str, value: List[str] + ) -> Dict[str, str]: + languages = ",".join(value) + return self._serialize_str_header(header, languages) + def serialize_start_stream_transcription_request( self, endpoint: str, request_shape: StartStreamTranscriptionRequest ) -> Request: @@ -130,6 +136,20 @@ def serialize_start_stream_transcription_request( ) ) + headers.update( + self._serialize_bool_header( + "identify-multiple-languages", + request_shape.identify_multiple_languages, + ) + ) + + headers.update( + self._serialize_list_header( + "language-options", + request_shape.language_options, + ) + ) + _add_required_headers(endpoint, headers) body = BufferableByteStream() diff --git a/tests/functional/test_serialize.py b/tests/functional/test_serialize.py index 405a11c..4cd7ede 100644 --- a/tests/functional/test_serialize.py +++ b/tests/functional/test_serialize.py @@ -21,9 +21,21 @@ def request_shape(): ) +@pytest.fixture +def multi_lid_request(): + return StartStreamTranscriptionRequest( + language_code=None, + media_sample_rate_hz=9000, + media_encoding="pcm", + identify_multiple_languages=True, + language_options=["en-US", "de-DE"] + ) + + +request_serializer = TranscribeStreamingSerializer() + class TestStartStreamTransactionRequest: def test_serialization(self, request_shape): - request_serializer = TranscribeStreamingSerializer() request = request_serializer.serialize_start_stream_transcription_request( endpoint="https://transcribe.aws.com", request_shape=request_shape, @@ -37,7 +49,6 @@ def test_serialization(self, request_shape): assert isinstance(request.body, BufferableByteStream) def 
test_serialization_with_missing_endpoint(self, request_shape): - request_serializer = TranscribeStreamingSerializer() with pytest.raises(ValidationException): request_serializer.serialize_start_stream_transcription_request( endpoint=None, @@ -45,6 +56,19 @@ def test_serialization_with_missing_endpoint(self, request_shape): ) + def test_serialization_with_multi_lid(self, multi_lid_request): + request = request_serializer.serialize_start_stream_transcription_request( + endpoint="https://transcribe.aws.com", + request_shape=multi_lid_request, + ).prepare() + + assert "x-amzn-transcribe-language-code" not in request.headers + assert request.headers["x-amzn-transcribe-sample-rate"] == "9000" + assert request.headers["x-amzn-transcribe-media-encoding"] == "pcm" + assert request.headers["x-amzn-transcribe-identify-multiple-languages"] == "True" + assert request.headers["x-amzn-transcribe-language-options"] == "en-US,de-DE" + + class TestAudioEventSerializer: def test_serialization(self): audio_event = AudioEvent(audio_chunk=b"foo") diff --git a/tests/integration/test_client.py b/tests/integration/test_client.py index 9d0a48d..419c0b9 100644 --- a/tests/integration/test_client.py +++ b/tests/integration/test_client.py @@ -9,7 +9,11 @@ ) from tests.integration import TEST_WAV_PATH - +request_options = [ + {"language_code": "en-US", "media_sample_rate_hz": 16000, "media_encoding": "pcm"}, + {"language_code": None, "media_sample_rate_hz": 16000, "media_encoding": "pcm", + "identify_multiple_languages": True, "language_options": ["en-US", "de-DE"]}, +] class TestClientStreaming: @pytest.fixture def client(self): @@ -31,12 +35,9 @@ async def byte_generator(): return byte_generator @pytest.mark.asyncio - async def test_client_start_transcribe_stream(self, client, wav_bytes): - stream = await client.start_stream_transcription( - language_code="en-US", - media_sample_rate_hz=16000, - media_encoding="pcm", - ) + @pytest.mark.parametrize('request_args', request_options) + async def 
test_client_start_transcribe_stream(self, client, wav_bytes, request_args): + stream = await client.start_stream_transcription(**request_args) async for chunk in wav_bytes(): await stream.input_stream.send_audio_event(audio_chunk=chunk) From 63ffee55cd5a86cff0e751b3b4ceac035a2b5875 Mon Sep 17 00:00:00 2001 From: Miroslav Batchkarov Date: Tue, 30 Jan 2024 10:50:33 +0100 Subject: [PATCH 2/4] add headers for language identification --- amazon_transcribe/client.py | 15 +++++++++++++-- amazon_transcribe/model.py | 8 ++++++-- amazon_transcribe/serialize.py | 30 ++++++++++++++++++++++-------- tests/functional/test_serialize.py | 8 +++++--- tests/integration/test_client.py | 30 ++++++++++++++++++++++++------ 5 files changed, 70 insertions(+), 21 deletions(-) diff --git a/amazon_transcribe/client.py b/amazon_transcribe/client.py index 153395e..da4f73a 100644 --- a/amazon_transcribe/client.py +++ b/amazon_transcribe/client.py @@ -84,8 +84,10 @@ async def start_stream_transcription( enable_partial_results_stabilization: Optional[bool] = None, partial_results_stability: Optional[str] = None, language_model_name: Optional[str] = None, + identify_language: Optional[bool] = False, + preferred_language: Optional[str] = None, identify_multiple_languages=False, - language_options=None + language_options=None, ) -> StartStreamTranscriptionEventStream: """Coordinate transcription settings and start stream. @@ -147,6 +149,13 @@ async def start_stream_transcription( overall transcription accuracy. Defaults to "high" if not set explicitly. :param language_model_name: The name of the language model you want to use. + :param identify_language: + if True, the language of the stream will be automatically detected. Set + language_code to None and provide at least two language_options when + identify_language is True. + :param preferred_language: + Adding a preferred language can speed up the language identification + process, which is helpful for short audio clips. 
:param identify_multiple_languages: If true, all languages spoken in the stream are identified. A multilingual transcripts is created your transcript using each identified language. @@ -171,8 +180,10 @@ async def start_stream_transcription( enable_partial_results_stabilization, partial_results_stability, language_model_name, + identify_language, + preferred_language, identify_multiple_languages, - language_options + language_options, ) endpoint = await self._endpoint_resolver.resolve(self.region) diff --git a/amazon_transcribe/model.py b/amazon_transcribe/model.py index 2eb7379..1aee443 100644 --- a/amazon_transcribe/model.py +++ b/amazon_transcribe/model.py @@ -230,7 +230,7 @@ class StartStreamTranscriptionRequest: :param identify_multiple_languages: If true, all languages spoken in the stream are identified. A multilingual transcripts is created your transcript using each identified language. - You must also provided at least two language_options and set + You must also provide at least two language_options and set language_code to None : param language_options: A list of possible language to use when identify_multiple_languages is @@ -253,8 +253,10 @@ def __init__( enable_partial_results_stabilization=None, partial_results_stability=None, language_model_name=None, + identify_language=None, + preferred_language=None, identify_multiple_languages=False, - language_options=None + language_options=None, ): self.language_code: Optional[str] = language_code @@ -274,6 +276,8 @@ def __init__( ] = enable_partial_results_stabilization self.partial_results_stability: Optional[str] = partial_results_stability self.language_model_name: Optional[str] = language_model_name + self.identify_language: Optional[bool] = identify_language + self.preferred_language: Optional[str] = preferred_language self.identify_multiple_languages: Optional[bool] = identify_multiple_languages self.language_options: Optional[List[str]] = language_options or [] diff --git 
a/amazon_transcribe/serialize.py b/amazon_transcribe/serialize.py index a6d2368..e5aa44b 100644 --- a/amazon_transcribe/serialize.py +++ b/amazon_transcribe/serialize.py @@ -56,9 +56,7 @@ def _serialize_bool_header( ) -> Dict[str, str]: return self._serialize_header(header, value) - def _serialize_list_header( - self, header: str, value: List[str] - ) -> Dict[str, str]: + def _serialize_list_header(self, header: str, value: List[str]) -> Dict[str, str]: languages = ",".join(value) return self._serialize_str_header(header, languages) @@ -138,18 +136,34 @@ def serialize_start_stream_transcription_request( headers.update( self._serialize_bool_header( - "identify-multiple-languages", - request_shape.identify_multiple_languages, + "identify-language", + request_shape.identify_language, ) ) headers.update( - self._serialize_list_header( - "language-options", - request_shape.language_options, + self._serialize_str_header( + "preferred_language", + request_shape.preferred_language, ) ) + if request_shape.identify_multiple_languages: + headers.update( + self._serialize_bool_header( + "identify-multiple-languages", + request_shape.identify_multiple_languages, + ) + ) + + if request_shape.language_options: + headers.update( + self._serialize_list_header( + "language-options", + request_shape.language_options, + ) + ) + _add_required_headers(endpoint, headers) body = BufferableByteStream() diff --git a/tests/functional/test_serialize.py b/tests/functional/test_serialize.py index 4cd7ede..8dda2df 100644 --- a/tests/functional/test_serialize.py +++ b/tests/functional/test_serialize.py @@ -28,12 +28,13 @@ def multi_lid_request(): media_sample_rate_hz=9000, media_encoding="pcm", identify_multiple_languages=True, - language_options=["en-US", "de-DE"] + language_options=["en-US", "de-DE"], ) request_serializer = TranscribeStreamingSerializer() + class TestStartStreamTransactionRequest: def test_serialization(self, request_shape): request = 
request_serializer.serialize_start_stream_transcription_request( @@ -55,7 +56,6 @@ def test_serialization_with_missing_endpoint(self, request_shape): request_shape=request_shape, ) - def test_serialization_with_multi_lid(self, multi_lid_request): request = request_serializer.serialize_start_stream_transcription_request( endpoint="https://transcribe.aws.com", @@ -65,7 +65,9 @@ def test_serialization_with_multi_lid(self, multi_lid_request): assert "x-amzn-transcribe-language-code" not in request.headers assert request.headers["x-amzn-transcribe-sample-rate"] == "9000" assert request.headers["x-amzn-transcribe-media-encoding"] == "pcm" - assert request.headers["x-amzn-transcribe-identify-multiple-languages"] == "True" + assert ( + request.headers["x-amzn-transcribe-identify-multiple-languages"] == "True" + ) assert request.headers["x-amzn-transcribe-language-options"] == "en-US,de-DE" diff --git a/tests/integration/test_client.py b/tests/integration/test_client.py index 419c0b9..540c3fd 100644 --- a/tests/integration/test_client.py +++ b/tests/integration/test_client.py @@ -10,10 +10,24 @@ from tests.integration import TEST_WAV_PATH request_options = [ - {"language_code": "en-US", "media_sample_rate_hz": 16000, "media_encoding": "pcm"}, - {"language_code": None, "media_sample_rate_hz": 16000, "media_encoding": "pcm", - "identify_multiple_languages": True, "language_options": ["en-US", "de-DE"]}, + # plain request with a known language + {"language_code": "en-US"}, + # language identification + { + "language_code": None, + "identify_language": True, + "language_options": ["en-US", "de-DE"], + "preferred_language": "en-US", + }, + # multiple language identification + { + "language_code": None, + "identify_multiple_languages": True, + "language_options": ["en-US", "de-DE"], + }, ] + + class TestClientStreaming: @pytest.fixture def client(self): @@ -35,9 +49,13 @@ async def byte_generator(): return byte_generator @pytest.mark.asyncio - 
@pytest.mark.parametrize('request_args', request_options) - async def test_client_start_transcribe_stream(self, client, wav_bytes, request_args): - stream = await client.start_stream_transcription(**request_args) + @pytest.mark.parametrize("request_args", request_options) + async def test_client_start_transcribe_stream( + self, client, wav_bytes, request_args + ): + stream = await client.start_stream_transcription( + media_sample_rate_hz=16000, media_encoding="pcm", **request_args + ) async for chunk in wav_bytes(): await stream.input_stream.send_audio_event(audio_chunk=chunk) From 7736dc4c8306d7b647fffe0c6102131a8dad3dbf Mon Sep 17 00:00:00 2001 From: Miroslav Batchkarov Date: Sat, 27 Apr 2024 19:39:21 +0200 Subject: [PATCH 3/4] Fix typo in header name Header names are separated by hyphens, not underscores https://docs.aws.amazon.com/transcribe/latest/dg/lang-id-stream.html Co-authored-by: Gunwoo Kim --- amazon_transcribe/serialize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/amazon_transcribe/serialize.py b/amazon_transcribe/serialize.py index e5aa44b..3687618 100644 --- a/amazon_transcribe/serialize.py +++ b/amazon_transcribe/serialize.py @@ -143,7 +143,7 @@ def serialize_start_stream_transcription_request( headers.update( self._serialize_str_header( - "preferred_language", + "preferred-language", request_shape.preferred_language, ) ) From 107357e659c228325e563ebddb9e8b372990683b Mon Sep 17 00:00:00 2001 From: Miroslav Batchkarov Date: Fri, 7 Feb 2025 13:33:01 +0100 Subject: [PATCH 4/4] address CR comments and add language_code in response --- amazon_transcribe/client.py | 8 ++++---- amazon_transcribe/deserialize.py | 1 + amazon_transcribe/model.py | 2 ++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/amazon_transcribe/client.py b/amazon_transcribe/client.py index da4f73a..088b422 100644 --- a/amazon_transcribe/client.py +++ b/amazon_transcribe/client.py @@ -14,7 +14,7 @@ import re from binascii import unhexlify -from 
typing import Optional +from typing import Optional, List from amazon_transcribe import AWSCRTEventLoop from amazon_transcribe.auth import AwsCrtCredentialResolver, CredentialResolver @@ -84,10 +84,10 @@ async def start_stream_transcription( enable_partial_results_stabilization: Optional[bool] = None, partial_results_stability: Optional[str] = None, language_model_name: Optional[str] = None, - identify_language: Optional[bool] = False, + identify_language: Optional[bool] = None, preferred_language: Optional[str] = None, - identify_multiple_languages=False, - language_options=None, + identify_multiple_languages: Optional[bool] = None, + language_options: Optional[List[str]] = None, ) -> StartStreamTranscriptionEventStream: """Coordinate transcription settings and start stream. diff --git a/amazon_transcribe/deserialize.py b/amazon_transcribe/deserialize.py index 4d31585..0aa2568 100644 --- a/amazon_transcribe/deserialize.py +++ b/amazon_transcribe/deserialize.py @@ -187,6 +187,7 @@ def _parse_result(self, current_node: Any) -> Result: is_partial=current_node.get("IsPartial"), alternatives=alternatives, channel_id=current_node.get("ChannelId"), + language_code=current_node.get("LanguageCode"), ) def _parse_alternative_list(self, current_node: Any) -> List[Alternative]: diff --git a/amazon_transcribe/model.py b/amazon_transcribe/model.py index 1aee443..7889d64 100644 --- a/amazon_transcribe/model.py +++ b/amazon_transcribe/model.py @@ -158,6 +158,7 @@ def __init__( is_partial: Optional[bool] = None, alternatives: Optional[List[Alternative]] = None, channel_id: Optional[str] = None, + language_code: Optional[str] = None, ): self.result_id = result_id self.start_time = start_time @@ -165,6 +166,7 @@ def __init__( self.is_partial = is_partial self.alternatives = alternatives self.channel_id = channel_id + self.language_code = language_code class StartStreamTranscriptionRequest: