Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add headers for multiple language identification #99

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion amazon_transcribe/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ async def start_stream_transcription(
enable_partial_results_stabilization: Optional[bool] = None,
partial_results_stability: Optional[str] = None,
language_model_name: Optional[str] = None,
identify_language: Optional[bool] = False,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
identify_language: Optional[bool] = False,
identify_language: Optional[bool] = None,

preferred_language: Optional[str] = None,
identify_multiple_languages=False,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
identify_multiple_languages=False,
identify_multiple_languages: Optional[bool] = None,

language_options=None,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not familiar with AWS coding standards, but is there a reason why there isn't type hiting on identify_multiple_languages and language_options ?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
language_options=None,
language_options: Optional[List[str]] = None,

) -> StartStreamTranscriptionEventStream:
"""Coordinate transcription settings and start stream.

Expand All @@ -100,7 +104,8 @@ async def start_stream_transcription(
than 5 minutes.

:param language_code:
Indicates the source language used in the input audio stream.
Indicates the source language used in the input audio stream. Set to
None if identify_multiple_languages is set to True
:param media_sample_rate_hz:
The sample rate, in Hertz, of the input audio. We suggest that you
use 8000 Hz for low quality audio and 16000 Hz for high quality audio.
Expand Down Expand Up @@ -144,6 +149,22 @@ async def start_stream_transcription(
overall transcription accuracy. Defaults to "high" if not set explicitly.
:param language_model_name:
The name of the language model you want to use.
:param identify_language:
if True, the language of the stream will be automatically detected. Set
language_code to None and provide at least two language_options when
identify_language is True.
:param preferred_language:
Adding a preferred language can speed up the language identification
process, which is helpful for short audio clips.
:param identify_multiple_languages:
If true, all languages spoken in the stream are identified. A multilingual
transcripts is created your transcript using each identified language.
You must also provide at least two language_options and set
language_code to None
:param language_options:
A list of possible language to use when identify_multiple_languages is
set to True. Note that not all languages supported by Transcribe are
supported for multiple language identification
"""
transcribe_streaming_request = StartStreamTranscriptionRequest(
language_code,
Expand All @@ -159,6 +180,10 @@ async def start_stream_transcription(
enable_partial_results_stabilization,
partial_results_stability,
language_model_name,
identify_language,
preferred_language,
identify_multiple_languages,
language_options,
)
endpoint = await self._endpoint_resolver.resolve(self.region)

Expand Down
20 changes: 19 additions & 1 deletion amazon_transcribe/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ class StartStreamTranscriptionRequest:
"""Transcription Request

:param language_code:
Indicates the source language used in the input audio stream.
Indicates the source language used in the input audio stream. Set to
None if identify_multiple_languages is set to True

:param media_sample_rate_hz:
The sample rate, in Hertz, of the input audio. We suggest that you
Expand Down Expand Up @@ -226,6 +227,15 @@ class StartStreamTranscriptionRequest:
overall transcription accuracy.
:param language_model_name:
The name of the language model you want to use.
:param identify_multiple_languages:
If true, all languages spoken in the stream are identified. A multilingual
transcripts is created your transcript using each identified language.
You must also provide at least two language_options and set
language_code to None
: param language_options:
A list of possible language to use when identify_multiple_languages is
set to True. Note that not all languages supported by Transcribe are
supported for multiple language identification
"""

def __init__(
Expand All @@ -243,6 +253,10 @@ def __init__(
enable_partial_results_stabilization=None,
partial_results_stability=None,
language_model_name=None,
identify_language=None,
preferred_language=None,
identify_multiple_languages=False,
language_options=None,
):

self.language_code: Optional[str] = language_code
Expand All @@ -262,6 +276,10 @@ def __init__(
] = enable_partial_results_stabilization
self.partial_results_stability: Optional[str] = partial_results_stability
self.language_model_name: Optional[str] = language_model_name
self.identify_language: Optional[bool] = identify_language
self.preferred_language: Optional[str] = preferred_language
self.identify_multiple_languages: Optional[bool] = identify_multiple_languages
self.language_options: Optional[List[str]] = language_options or []


class StartStreamTranscriptionResponse:
Expand Down
36 changes: 35 additions & 1 deletion amazon_transcribe/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# language governing permissions and limitations under the License.


from typing import Any, Dict, Tuple, Optional
from typing import Any, Dict, Tuple, Optional, List

from amazon_transcribe.request import Request
from amazon_transcribe.structures import BufferableByteStream
Expand Down Expand Up @@ -56,6 +56,10 @@ def _serialize_bool_header(
) -> Dict[str, str]:
return self._serialize_header(header, value)

def _serialize_list_header(self, header: str, value: List[str]) -> Dict[str, str]:
languages = ",".join(value)
return self._serialize_str_header(header, languages)

def serialize_start_stream_transcription_request(
self, endpoint: str, request_shape: StartStreamTranscriptionRequest
) -> Request:
Expand Down Expand Up @@ -130,6 +134,36 @@ def serialize_start_stream_transcription_request(
)
)

headers.update(
self._serialize_bool_header(
"identify-language",
request_shape.identify_language,
)
)

headers.update(
self._serialize_str_header(
"preferred-language",
request_shape.preferred_language,
)
)

if request_shape.identify_multiple_languages:
headers.update(
self._serialize_bool_header(
"identify-multiple-languages",
request_shape.identify_multiple_languages,
)
)

if request_shape.language_options:
headers.update(
self._serialize_list_header(
"language-options",
request_shape.language_options,
)
)

_add_required_headers(endpoint, headers)

body = BufferableByteStream()
Expand Down
30 changes: 28 additions & 2 deletions tests/functional/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,22 @@ def request_shape():
)


@pytest.fixture
def multi_lid_request():
return StartStreamTranscriptionRequest(
language_code=None,
media_sample_rate_hz=9000,
media_encoding="pcm",
identify_multiple_languages=True,
language_options=["en-US", "de-DE"],
)


request_serializer = TranscribeStreamingSerializer()


class TestStartStreamTransactionRequest:
def test_serialization(self, request_shape):
request_serializer = TranscribeStreamingSerializer()
request = request_serializer.serialize_start_stream_transcription_request(
endpoint="https://transcribe.aws.com",
request_shape=request_shape,
Expand All @@ -37,13 +50,26 @@ def test_serialization(self, request_shape):
assert isinstance(request.body, BufferableByteStream)

def test_serialization_with_missing_endpoint(self, request_shape):
request_serializer = TranscribeStreamingSerializer()
with pytest.raises(ValidationException):
request_serializer.serialize_start_stream_transcription_request(
endpoint=None,
request_shape=request_shape,
)

def test_serialization_with_multi_lid(self, multi_lid_request):
request = request_serializer.serialize_start_stream_transcription_request(
endpoint="https://transcribe.aws.com",
request_shape=multi_lid_request,
).prepare()

assert "x-amzn-transcribe-language-code" not in request.headers
assert request.headers["x-amzn-transcribe-sample-rate"] == "9000"
assert request.headers["x-amzn-transcribe-media-encoding"] == "pcm"
assert (
request.headers["x-amzn-transcribe-identify-multiple-languages"] == "True"
)
assert request.headers["x-amzn-transcribe-language-options"] == "en-US,de-DE"


class TestAudioEventSerializer:
def test_serialization(self):
Expand Down
27 changes: 23 additions & 4 deletions tests/integration/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,24 @@
)
from tests.integration import TEST_WAV_PATH

request_options = [
# plain request with a known language
{"language_code": "en-US"},
# language identification
{
"language_code": None,
"identify_language": True,
"language_options": ["en-US", "de-DE"],
"preferred_language": "en-US",
},
# multiple language identification
{
"language_code": None,
"identify_multiple_languages": True,
"language_options": ["en-US", "de-DE"],
},
]


class TestClientStreaming:
@pytest.fixture
Expand All @@ -31,11 +49,12 @@ async def byte_generator():
return byte_generator

@pytest.mark.asyncio
async def test_client_start_transcribe_stream(self, client, wav_bytes):
@pytest.mark.parametrize("request_args", request_options)
async def test_client_start_transcribe_stream(
self, client, wav_bytes, request_args
):
stream = await client.start_stream_transcription(
language_code="en-US",
media_sample_rate_hz=16000,
media_encoding="pcm",
media_sample_rate_hz=16000, media_encoding="pcm", **request_args
)

async for chunk in wav_bytes():
Expand Down