Skip to content

Commit

Permalink
Merge pull request #284 from benxu3/async-interpreter
Browse files Browse the repository at this point in the history
add --debug flag
  • Loading branch information
KillianLucas authored Jul 10, 2024
2 parents c401530 + d8d7658 commit dbb920b
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 33 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ If you want to run local speech-to-text using Whisper, you must install Rust. Fo

To customize the behavior of the system, edit the [system message, model, skills library path,](https://docs.openinterpreter.com/settings/all-settings) etc. in the `profiles` directory under the `server` directory. This file sets up an interpreter, and is powered by Open Interpreter.

To specify the text-to-speech service for the 01 `base_device.py`, set `interpreter.tts` to either "openai" for OpenAI, "elevenlabs" for ElevenLabs, or "coqui" for Coqui (local) in a profile. For the 01 Light, set `SPEAKER_SAMPLE_RATE` to 24000 for Coqui (local) or 22050 for OpenAI TTS. We currently don't support ElevenLabs TTS on the 01 Light.
To specify the text-to-speech service for the 01 `base_device.py`, set `interpreter.tts` to either "openai" for OpenAI, "elevenlabs" for ElevenLabs, or "coqui" for Coqui (local) in a profile. For the 01 Light, set `SPEAKER_SAMPLE_RATE` in `client.ino` under the `esp32` client directory to 24000 for Coqui (local) or 22050 for OpenAI TTS. We currently don't support ElevenLabs TTS on the 01 Light.

## Ubuntu Dependencies

Expand Down
8 changes: 8 additions & 0 deletions software/source/clients/base_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def __init__(self):
self.server_url = ""
self.ctrl_pressed = False
self.tts_service = ""
self.debug = False
self.playback_latency = None

def fetch_image_from_camera(self, camera_index=CAMERA_DEVICE_INDEX):
"""Captures an image from the specified camera device and saves it to a temporary file. Adds the image to the captured_images list."""
Expand Down Expand Up @@ -164,6 +166,10 @@ async def play_audiosegments(self):
while True:
try:
audio = await self.audiosegments.get()
if self.debug and self.playback_latency and isinstance(audio, bytes):
elapsed_time = time.time() - self.playback_latency
print(f"Time from request to playback: {elapsed_time} seconds")
self.playback_latency = None

if self.tts_service == "elevenlabs":
mpv_process.stdin.write(audio) # type: ignore
Expand Down Expand Up @@ -219,6 +225,8 @@ def record_audio(self):
stream.stop_stream()
stream.close()
print("Recording stopped.")
if self.debug:
self.playback_latency = time.time()

duration = wav_file.getnframes() / RATE
if duration < 0.3:
Expand Down
3 changes: 2 additions & 1 deletion software/source/clients/linux/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
device = Device()


def main(server_url):
def main(server_url, debug):
    """Bind the runtime options onto the shared Device instance and start it.

    Args:
        server_url: WebSocket/HTTP URL of the 01 server to connect to.
        debug: When True, the device prints latency measurements.
    """
    device.debug = debug
    device.server_url = server_url
    device.start()


Expand Down
3 changes: 2 additions & 1 deletion software/source/clients/mac/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
device = Device()


def main(server_url):
def main(server_url, debug):
    """Bind the runtime options onto the shared Device instance and start it.

    Args:
        server_url: WebSocket/HTTP URL of the 01 server to connect to.
        debug: When True, the device prints latency measurements.
    """
    device.debug = debug
    device.server_url = server_url
    device.start()


Expand Down
3 changes: 2 additions & 1 deletion software/source/clients/windows/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
device = Device()


def main(server_url):
def main(server_url, debug):
    """Bind the runtime options onto the shared Device instance and start it.

    Args:
        server_url: WebSocket/HTTP URL of the 01 server to connect to.
        debug: When True, the device prints latency measurements.
    """
    device.debug = debug
    device.server_url = server_url
    device.start()


Expand Down
47 changes: 39 additions & 8 deletions software/source/server/async_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@


class AsyncInterpreter:
def __init__(self, interpreter):
def __init__(self, interpreter, debug):
self.stt_latency = None
self.tts_latency = None
self.interpreter_latency = None
# time from the interpreter's first yield to the first TTS chunk put on the output queue
self.tffytfp = None
self.debug = debug

self.interpreter = interpreter
self.audio_chunks = []

Expand Down Expand Up @@ -126,6 +133,8 @@ def generate(self, message, start_interpreter):
# Experimental: The AI voice sounds better with replacements like these, but it should happen at the TTS layer
# content = content.replace(". ", ". ... ").replace(", ", ", ... ").replace("!", "! ... ").replace("?", "? ... ")
# print("yielding ", content)
if self.tffytfp is None:
self.tffytfp = time.time()

yield content

Expand Down Expand Up @@ -157,6 +166,10 @@ def generate(self, message, start_interpreter):
)

# Send a completion signal
if self.debug:
end_interpreter = time.time()
self.interpreter_latency = end_interpreter - start_interpreter
print("INTERPRETER LATENCY", self.interpreter_latency)
# self.add_to_output_queue_sync({"role": "server","type": "completion", "content": "DONE"})

async def run(self):
Expand All @@ -171,13 +184,20 @@ async def run(self):
while not self._input_queue.empty():
input_queue.append(self._input_queue.get())

message = self.stt.text()

if self.audio_chunks:
audio_bytes = bytearray(b"".join(self.audio_chunks))
wav_file_path = bytes_to_wav(audio_bytes, "audio/raw")
print("wav_file_path ", wav_file_path)
self.audio_chunks = []
if self.debug:
start_stt = time.time()
message = self.stt.text()
end_stt = time.time()
self.stt_latency = end_stt - start_stt
print("STT LATENCY", self.stt_latency)

if self.audio_chunks:
audio_bytes = bytearray(b"".join(self.audio_chunks))
wav_file_path = bytes_to_wav(audio_bytes, "audio/raw")
print("wav_file_path ", wav_file_path)
self.audio_chunks = []
else:
message = self.stt.text()

print(message)

Expand All @@ -204,11 +224,22 @@ async def run(self):
"end": True,
}
)
if self.debug:
end_tts = time.time()
self.tts_latency = end_tts - self.tts.stream_start_time
print("TTS LATENCY", self.tts_latency)
self.tts.stop()

break

async def _on_tts_chunk_async(self, chunk):
# print("adding chunk to queue")
if self.debug and self.tffytfp is not None and self.tffytfp != 0:
print(
"time from first yield to first put is ",
time.time() - self.tffytfp,
)
self.tffytfp = 0
await self._add_to_queue(self._output_queue, chunk)

def on_tts_chunk(self, chunk):
Expand Down
44 changes: 28 additions & 16 deletions software/source/server/async_server.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,18 @@
# import from the profiles directory the interpreter to be served

# add other profiles to the directory to define other interpreter instances and import them here
# {.profiles.fast: optimizes for STT/TTS latency with the fastest models }
# {.profiles.local: uses local models and local STT/TTS }
# {.profiles.default: uses default interpreter settings with optimized TTS latency }

# from .profiles.fast import interpreter as base_interpreter
# from .profiles.local import interpreter as base_interpreter
from .profiles.default import interpreter as base_interpreter

import asyncio
import traceback
import json
from fastapi import FastAPI, WebSocket
from fastapi import FastAPI, WebSocket, Depends
from fastapi.responses import PlainTextResponse
from uvicorn import Config, Server
from .async_interpreter import AsyncInterpreter
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Dict, Any
import os
import importlib.util

os.environ["STT_RUNNER"] = "server"
os.environ["TTS_RUNNER"] = "server"

# interpreter.tts set in the profiles directory!!!!
interpreter = AsyncInterpreter(base_interpreter)

app = FastAPI()

Expand All @@ -37,13 +25,19 @@
)


async def get_debug_flag():
    """FastAPI dependency: return the server-wide debug flag stored on app.state.

    Set by main() before the server starts; injected into endpoints via Depends.
    """
    return app.state.debug


@app.get("/ping")
async def ping():
    """Liveness probe: respond with plain-text "pong"."""
    return PlainTextResponse("pong")


@app.websocket("/")
async def websocket_endpoint(websocket: WebSocket):
async def websocket_endpoint(
websocket: WebSocket, debug: bool = Depends(get_debug_flag)
):
await websocket.accept()

# Send the tts_service value to the client
Expand Down Expand Up @@ -91,7 +85,25 @@ async def send_output():
await websocket.close()


async def main(server_host, server_port):
async def main(server_host, server_port, profile, debug):

app.state.debug = debug

# Load the profile module from the provided path
spec = importlib.util.spec_from_file_location("profile", profile)
profile_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(profile_module)

# Get the interpreter from the profile
interpreter = profile_module.interpreter

if not hasattr(interpreter, 'tts'):
print("Setting TTS provider to default: openai")
interpreter.tts = "openai"

# Make it async
interpreter = AsyncInterpreter(interpreter, debug)

print(f"Starting server on {server_host}:{server_port}")
config = Config(app, host=server_host, port=server_port, lifespan="on")
server = Server(config)
Expand Down
9 changes: 7 additions & 2 deletions software/source/server/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def create_tunnel(
tunnel_method="ngrok", server_host="localhost", server_port=10001, qr=False
tunnel_method="ngrok", server_host="localhost", server_port=10001, qr=False, domain=None
):
print_markdown("Exposing server to the internet...")

Expand Down Expand Up @@ -99,8 +99,13 @@ def create_tunnel(

# If ngrok is installed, start it on the specified port
# process = subprocess.Popen(f'ngrok http {server_port} --log=stdout', shell=True, stdout=subprocess.PIPE)

if domain:
domain = f"--domain={domain}"
else:
domain = ""
process = subprocess.Popen(
f"ngrok http {server_port} --scheme http,https --log=stdout",
f"ngrok http {server_port} --scheme http,https {domain} --log=stdout",
shell=True,
stdout=subprocess.PIPE,
)
Expand Down
57 changes: 54 additions & 3 deletions software/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import importlib
from source.server.tunnel import create_tunnel
from source.server.async_server import main
import subprocess

import signal

Expand Down Expand Up @@ -41,6 +42,25 @@ def run(
qr: bool = typer.Option(
False, "--qr", help="Display QR code to scan to connect to the server"
),
domain: str = typer.Option(
None, "--domain", help="Connect ngrok to a custom domain"
),
profiles: bool = typer.Option(
False,
"--profiles",
help="Opens the folder where this script is contained",
),
profile: str = typer.Option(
"default.py", # default
"--profile",
help="Specify the path to the profile, or the name of the file if it's in the `profiles` directory (run `--profiles` to open the profiles directory)",
),
debug: bool = typer.Option(
False,
"--debug",
help="Print latency measurements and save microphone recordings locally for manual playback.",
),

):
_run(
server=server,
Expand All @@ -52,6 +72,10 @@ def run(
server_url=server_url,
client_type=client_type,
qr=qr,
debug=debug,
domain=domain,
profiles=profiles,
profile=profile,
)


Expand All @@ -65,8 +89,34 @@ def _run(
server_url: str = None,
client_type: str = "auto",
qr: bool = False,
debug: bool = False,
domain = None,
profiles = None,
profile = None,
):

profiles_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "source", "server", "profiles")

if profiles:
if platform.system() == "Windows":
subprocess.Popen(['explorer', profiles_dir])
elif platform.system() == "Darwin":
subprocess.Popen(['open', profiles_dir])
elif platform.system() == "Linux":
subprocess.Popen(['xdg-open', profiles_dir])
else:
subprocess.Popen(['open', profiles_dir])
exit(0)

if profile:
if not os.path.isfile(profile):
profile = os.path.join(profiles_dir, profile)
if not os.path.isfile(profile):
profile += ".py"
if not os.path.isfile(profile):
print(f"Invalid profile path: {profile}")
exit(1)

system_type = platform.system()
if system_type == "Windows":
server_host = "localhost"
Expand All @@ -84,7 +134,6 @@ def handle_exit(signum, frame):
signal.signal(signal.SIGINT, handle_exit)

if server:
# print(f"Starting server with mobile = {mobile}")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server_thread = threading.Thread(
Expand All @@ -93,14 +142,16 @@ def handle_exit(signum, frame):
main(
server_host,
server_port,
profile,
debug,
),
),
)
server_thread.start()

if expose:
tunnel_thread = threading.Thread(
target=create_tunnel, args=[tunnel_service, server_host, server_port, qr]
target=create_tunnel, args=[tunnel_service, server_host, server_port, qr, domain]
)
tunnel_thread.start()

Expand All @@ -125,7 +176,7 @@ def handle_exit(signum, frame):
f".clients.{client_type}.device", package="source"
)

client_thread = threading.Thread(target=module.main, args=[server_url])
client_thread = threading.Thread(target=module.main, args=[server_url, debug])
client_thread.start()

try:
Expand Down

0 comments on commit dbb920b

Please sign in to comment.