Skip to content

Commit

Permalink
Merge pull request #336 from ammend/main
Browse files Browse the repository at this point in the history
add Volcengine TTS/ASR Plugin and add tts_parallel config
  • Loading branch information
wzpan authored Oct 25, 2024
2 parents 87e3dbf + 9903564 commit 3fd73e0
Show file tree
Hide file tree
Showing 6 changed files with 535 additions and 20 deletions.
58 changes: 58 additions & 0 deletions robot/AI.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,64 @@ def chat(self, texts, _):
logger.critical("Tongyi robot failed to response for %r", msg, exc_info=True)
return "抱歉, Tongyi回答失败"

class CozeRobot(AbstractRobot):
SLUG = "coze"

def __init__(self, botid, token, **kwargs):
super(self.__class__, self).__init__()
self.botid = botid
self.token = token
self.userid = str(get_mac())[:32]

@classmethod
def get_config(cls):
return config.get("coze", {})

def chat(self, texts, parsed=None):
"""
使用coze聊天
Arguments:
texts -- user input, typically speech, to be parsed by a module
"""
msg = "".join(texts)
msg = utils.stripPunctuation(msg)
try:
url = "https://api.coze.cn/open_api/v2/chat"

body = {
"conversation_id": "123",
"bot_id": self.botid,
"user": self.userid,
"query": msg,
"stream": False
}
headers = {
"Authorization": "Bearer " + self.token,
"Content-Type": "application/json",
"Accept": "*/*",
"Host": "api.coze.cn",
"Connection": "keep-alive"
}
r = requests.post(url, headers=headers, json=body)
respond = json.loads(r.text)
result = ""
logger.info(f"{self.SLUG} 回答:{respond}")
if "messages" in respond:
for m in respond["messages"]:
if m["type"] == "answer":
result = m["content"].replace("\n", "").replace("\r", "")
else:
result = "抱歉,扣子回答失败"
if result == "":
result = "抱歉,扣子回答失败"
logger.info(f"{self.SLUG} 回答:{result}")
return result
except Exception:
logger.critical(
"Tuling robot failed to response for %r", msg, exc_info=True
)
return "抱歉, 扣子回答失败"

def get_unknown_response():
"""
Expand Down
26 changes: 25 additions & 1 deletion robot/ASR.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import json
from aip import AipSpeech
from .sdk import TencentSpeech, AliSpeech, XunfeiSpeech, BaiduSpeech, FunASREngine
from .sdk import TencentSpeech, AliSpeech, XunfeiSpeech, BaiduSpeech, FunASREngine, VolcengineSpeech
from . import utils, config
from robot import logging
from abc import ABCMeta, abstractmethod
Expand Down Expand Up @@ -267,6 +267,30 @@ def transcribe(self, fp):
logger.critical(f"{self.SLUG} 语音识别出错了", stack_info=True)
return ""

class VolcengineASR(AbstractASR):
"""
VolcengineASR 实时语音转写服务软件包
"""

SLUG = "volcengine-asr"

def __init__(self, **kargs):
super(self.__class__, self).__init__()
self.volcengine_asr = VolcengineSpeech.VolcengineASR(**kargs)

@classmethod
def get_config(cls):
return config.get("volcengine-asr", {})

def transcribe(self, fp):
result = self.volcengine_asr.execute(fp)
if result:
logger.info(f"{self.SLUG} 语音识别到了:{result}")
return result
else:
logger.critical(f"{self.SLUG} 语音识别出错了", stack_info=True)
return ""

def get_engine_by_slug(slug=None):
"""
Returns:
Expand Down
2 changes: 1 addition & 1 deletion robot/Conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def _tts(self, lines, cache, onCompleted=None):
pattern = r"http[s]?://.+"
logger.info("_tts")
with self.tts_lock:
with ThreadPoolExecutor(max_workers=5) as pool:
with ThreadPoolExecutor(max_workers=config.get("tts_parallel", 5)) as pool:
all_task = []
index = 0
for line in lines:
Expand Down
28 changes: 27 additions & 1 deletion robot/TTS.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pypinyin import lazy_pinyin
from pydub import AudioSegment
from abc import ABCMeta, abstractmethod
from .sdk import TencentSpeech, AliSpeech, XunfeiSpeech, atc, VITSClient
from .sdk import TencentSpeech, AliSpeech, XunfeiSpeech, atc, VITSClient, VolcengineSpeech
import requests
from xml.etree import ElementTree

Expand Down Expand Up @@ -469,6 +469,32 @@ def get_engine_by_slug(slug=None):
logger.info(f"使用 {engine.SLUG} TTS 引擎")
return engine.get_instance()

class VolcengineTTS(AbstractTTS):
"""
VolcengineTTS 语音合成
"""

SLUG = "volcengine-tts"

def __init__(self, appid, token, cluster, voice_type, **args):
super(self.__class__, self).__init__()
self.engine = VolcengineSpeech.VolcengineTTS(appid=appid, token=token, cluster=cluster, voice_type=voice_type)

@classmethod
def get_config(cls):
# Try to get ali_yuyin config from config
return config.get("volcengine-tts", {})

def get_speech(self, text):
result = self.engine.execute(text)
if result is None:
logger.critical(f"{self.SLUG} 合成失败!", stack_info=True)
else:
tmpfile = os.path.join(constants.TEMP_PATH, uuid.uuid4().hex + ".mp3")
with open(tmpfile, "wb") as f:
f.write(result)
logger.info(f"{self.SLUG} 语音合成成功,合成路径:{tmpfile}")
return tmpfile

def get_engines():
def get_subclasses(cls):
Expand Down
Loading

0 comments on commit 3fd73e0

Please sign in to comment.