Skip to content

Commit

Permalink
Merge pull request #174 from jsichi/tee-client
Browse files Browse the repository at this point in the history
Add support for processing same audio stream via multiple clients running different tasks.
  • Loading branch information
makaveli10 authored Mar 17, 2024
2 parents 754f22d + 7c7a446 commit 2e37216
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 158 deletions.
55 changes: 50 additions & 5 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import os
import scipy
import websocket
import copy
import unittest
from unittest.mock import patch, MagicMock
from whisper_live.client import TranscriptionClient
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
from whisper_live.utils import resample
from pathlib import Path


class BaseTestCase(unittest.TestCase):
Expand All @@ -24,14 +26,14 @@ def setUp(self, mock_pyaudio, mock_websocket):

self.mock_pyaudio = mock_pyaudio
self.mock_websocket = mock_websocket
self.mock_audio_packet = b'\x00\x01\x02\x03'

def tearDown(self):
self.client.close_websocket()
self.mock_pyaudio.stop()
self.mock_websocket.stop()
del self.client


class TestClientWebSocketCommunication(BaseTestCase):
def test_websocket_communication(self):
expected_url = 'ws://localhost:9090'
Expand Down Expand Up @@ -106,6 +108,49 @@ def test_resample_audio(self):

class TestSendingAudioPacket(BaseTestCase):
def test_send_packet(self):
mock_audio_packet = b'\x00\x01\x02\x03'
self.client.send_packet_to_server(mock_audio_packet)
self.client.client_socket.send.assert_called_with(mock_audio_packet, websocket.ABNF.OPCODE_BINARY)
self.client.send_packet_to_server(self.mock_audio_packet)
self.client.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)

class TestTee(BaseTestCase):
@patch('whisper_live.client.websocket.WebSocketApp')
@patch('whisper_live.client.pyaudio.PyAudio')
def setUp(self, mock_audio, mock_websocket):
super().setUp()
self.client2 = Client(host='localhost', port=9090, lang="es", translate=False, srt_file_path="transcript.srt")
self.client3 = Client(host='localhost', port=9090, lang="es", translate=True, srt_file_path="translation.srt")
# need a separate mock for each websocket
self.client3.client_socket = copy.deepcopy(self.client3.client_socket)
self.tee = TranscriptionTeeClient([self.client2, self.client3])

def tearDown(self):
self.tee.close_all_clients()
del self.tee
super().tearDown()

def test_invalid_constructor(self):
with self.assertRaises(Exception) as context:
TranscriptionTeeClient([])

def test_multicast_unconditional(self):
self.tee.multicast_packet(self.mock_audio_packet, True)
for client in self.tee.clients:
client.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)

def test_multicast_conditional(self):
self.client2.recording = False
self.client3.recording = True
self.tee.multicast_packet(self.mock_audio_packet, False)
self.client2.client_socket.send.assert_not_called()
self.client3.client_socket.send.assert_called_with(self.mock_audio_packet, websocket.ABNF.OPCODE_BINARY)

def test_close_all(self):
self.tee.close_all_clients()
for client in self.tee.clients:
client.client_socket.close.assert_called()

def test_write_all_srt(self):
for client in self.tee.clients:
client.server_backend = "faster_whisper"
self.tee.write_all_clients_srt()
self.assertTrue(Path("transcript.srt").is_file())
self.assertTrue(Path("translation.srt").is_file())
37 changes: 25 additions & 12 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from websockets.exceptions import ConnectionClosed
from whisper_live.server import TranscriptionServer
from whisper_live.client import TranscriptionClient
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
from whisper.normalizers import EnglishTextNormalizer


Expand Down Expand Up @@ -69,6 +69,10 @@ def test_recv_audio_exception_handling(self, mock_websocket):
class TestServerInferenceAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mock_pyaudio_patch = mock.patch('pyaudio.PyAudio')
cls.mock_pyaudio = cls.mock_pyaudio_patch.start()
cls.mock_pyaudio.return_value.open.return_value = mock.MagicMock()

cls.server_process = subprocess.Popen(["python", "run_server.py"])
time.sleep(2)

Expand All @@ -77,21 +81,13 @@ def tearDownClass(cls):
cls.server_process.terminate()
cls.server_process.wait()

@mock.patch('pyaudio.PyAudio')
def setUp(self, mock_pyaudio):
self.mock_pyaudio = mock_pyaudio.return_value
self.mock_stream = mock.MagicMock()
self.mock_pyaudio.open.return_value = self.mock_stream
def setUp(self):
self.metric = evaluate.load("wer")
self.normalizer = EnglishTextNormalizer()
self.client = TranscriptionClient(
"localhost", "9090", model="base.en", lang="en",
)

def test_inference(self):
def check_prediction(self, srt_path):
gt = "And so my fellow Americans, ask not, what your country can do for you. Ask what you can do for your country!"
self.client("assets/jfk.flac")
with open("output.srt", "r") as f:
with open(srt_path, "r") as f:
lines = f.readlines()
prediction = " ".join([line.strip() for line in lines[2::4]])
prediction_normalized = self.normalizer(prediction)
Expand All @@ -104,6 +100,23 @@ def test_inference(self):
)
self.assertLess(wer, 0.05)

def test_inference(self):
client = TranscriptionClient(
"localhost", "9090", model="base.en", lang="en",
)
client("assets/jfk.flac")
self.check_prediction("output.srt")

def test_simultaneous_inference(self):
client1 = Client(
"localhost", "9090", model="base.en", lang="en", srt_file_path="transcript1.srt")
client2 = Client(
"localhost", "9090", model="base.en", lang="en", srt_file_path="transcript2.srt")
tee = TranscriptionTeeClient([client1, client2])
tee("assets/jfk.flac")
self.check_prediction("transcript1.srt")
self.check_prediction("transcript2.srt")


class TestExceptionHandling(unittest.TestCase):
def setUp(self):
Expand Down
Loading

0 comments on commit 2e37216

Please sign in to comment.