From 1fff97b1372bcdb62bc4cc6ff8eaf61b61b7689a Mon Sep 17 00:00:00 2001
From: Ajay Raj
Date: Fri, 12 Jul 2024 14:17:52 -0700
Subject: [PATCH 1/8] creates adaptiveobject abstraction

Co-authored-by: srhinos <6531393+srhinos@users.noreply.github.com>
---
 vocode/streaming/models/adaptive_object.py | 41 +++++++++++++++++++++
 1 file changed, 41 insertions(+)
 create mode 100644 vocode/streaming/models/adaptive_object.py

diff --git a/vocode/streaming/models/adaptive_object.py b/vocode/streaming/models/adaptive_object.py
new file mode 100644
index 0000000000..f8eb60b1f6
--- /dev/null
+++ b/vocode/streaming/models/adaptive_object.py
@@ -0,0 +1,41 @@
+from abc import ABC
+from typing import Any
+
+from pydantic import BaseModel, ValidationError, model_validator
+
+
+class AdaptiveObject(BaseModel, ABC):
+    """An abstract object that may be one of several concrete types."""
+
+    @model_validator(mode="wrap")
+    @classmethod
+    def _resolve_adaptive_object(cls, data: dict, handler) -> Any:
+        if not isinstance(data, dict):
+            return handler(data)
+        # if cls is not abstract, there's nothing to do
+        if ABC not in cls.__bases__:
+            return handler(data)
+
+        # try to validate the data for each possible type
+        for subcls in cls._find_all_possible_types():
+            try:
+                # return the first successful validation
+                return subcls.model_validate(data)
+            except ValidationError:
+                continue
+
+        # raising ValueError inside a validator lets pydantic surface it
+        # to the caller as a proper ValidationError
+        raise ValueError("unable to resolve input to a concrete type")
+
+    @classmethod
+    def _find_all_possible_types(cls):
+        """Recursively generate all possible types for this object."""
+
+        # any concrete class is a possible type
+        if ABC not in cls.__bases__:
+            yield cls
+
+        # continue looking for possible types in subclasses
+        for subclass in cls.__subclasses__():
+            yield from subclass._find_all_possible_types()

From b0c47818d2ba2dd3f1a22ff83e21dcaa9998b001 Mon Sep 17 00:00:00 2001
From: Ajay Raj
Date: Fri, 12 Jul 2024 14:18:50 -0700
Subject: [PATCH 2/8] onboards
adaptiveobject --- apps/telephony_app/speller_agent.py | 8 +- vocode/streaming/action/dtmf.py | 8 +- vocode/streaming/action/end_conversation.py | 10 +- .../action/execute_external_action.py | 9 +- .../action/external_actions_requester.py | 2 +- vocode/streaming/action/record_email.py | 8 +- vocode/streaming/action/transfer_call.py | 5 +- vocode/streaming/action/wait.py | 8 +- vocode/streaming/agent/base_agent.py | 57 ++++++----- .../agent/restful_user_implemented_agent.py | 10 +- .../streaming/client_backend/conversation.py | 5 +- vocode/streaming/models/actions.py | 10 +- vocode/streaming/models/agent.py | 93 ++++++++---------- vocode/streaming/models/client_backend.py | 3 +- vocode/streaming/models/events.py | 28 ++++-- vocode/streaming/models/message.py | 21 ++-- vocode/streaming/models/model.py | 2 +- vocode/streaming/models/synthesizer.py | 97 +++++++------------ vocode/streaming/models/telephony.py | 20 ++-- vocode/streaming/models/transcriber.py | 68 ++++++------- vocode/streaming/models/transcript.py | 8 +- vocode/streaming/models/vector_db.py | 16 ++- vocode/streaming/models/websocket.py | 30 +++--- vocode/streaming/models/websocket_agent.py | 29 +++--- .../synthesizer/coqui_synthesizer.py | 1 - .../config_manager/redis_config_manager.py | 1 + vocode/streaming/telephony/server/base.py | 2 +- .../transcriber/deepgram_transcriber.py | 12 +-- .../utils/redis_conversation_message_queue.py | 2 +- 29 files changed, 274 insertions(+), 299 deletions(-) diff --git a/apps/telephony_app/speller_agent.py b/apps/telephony_app/speller_agent.py index fc4f441aa2..90d498cb85 100644 --- a/apps/telephony_app/speller_agent.py +++ b/apps/telephony_app/speller_agent.py @@ -1,14 +1,16 @@ -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple from vocode.streaming.agent.abstract_factory import AbstractAgentFactory from vocode.streaming.agent.base_agent import BaseAgent, RespondAgent from vocode.streaming.agent.chat_gpt_agent import ChatGPTAgent -from 
vocode.streaming.models.agent import AgentConfig, AgentType, ChatGPTAgentConfig +from vocode.streaming.models.agent import AgentConfig, ChatGPTAgentConfig -class SpellerAgentConfig(AgentConfig, type="agent_speller"): +class SpellerAgentConfig(AgentConfig): """Configuration for SpellerAgent. Inherits from AgentConfig.""" + type: Literal["agent_speller"] = "agent_speller" + pass diff --git a/vocode/streaming/action/dtmf.py b/vocode/streaming/action/dtmf.py index 33c3a23db3..5af324f6e1 100644 --- a/vocode/streaming/action/dtmf.py +++ b/vocode/streaming/action/dtmf.py @@ -1,7 +1,7 @@ -from typing import List, Optional, Type +from typing import List, Literal, Optional, Type from loguru import logger -from pydantic.v1 import BaseModel, Field +from pydantic import BaseModel, Field from vocode.streaming.action.phone_call_action import ( TwilioPhoneConversationAction, @@ -25,7 +25,9 @@ class DTMFResponse(BaseModel): message: Optional[str] = None -class DTMFVocodeActionConfig(VocodeActionConfig, type="action_dtmf"): # type: ignore +class DTMFVocodeActionConfig(VocodeActionConfig): + type: Literal["action_dtmf"] = "action_dtmf" + def action_attempt_to_string(self, input: ActionInput) -> str: assert isinstance(input.params, DTMFParameters) return "Attempting to press numbers: " f"{list(input.params.buttons)}" diff --git a/vocode/streaming/action/end_conversation.py b/vocode/streaming/action/end_conversation.py index 7175ad8461..1b96afbd69 100644 --- a/vocode/streaming/action/end_conversation.py +++ b/vocode/streaming/action/end_conversation.py @@ -1,7 +1,7 @@ -from typing import Type +from typing import Literal, Type from loguru import logger -from pydantic.v1 import BaseModel +from pydantic import BaseModel from vocode.streaming.action.base_action import BaseAction from vocode.streaming.models.actions import ActionConfig as VocodeActionConfig @@ -24,9 +24,9 @@ class EndConversationResponse(BaseModel): success: bool -class EndConversationVocodeActionConfig( - 
VocodeActionConfig, type="action_end_conversation" # type: ignore -): +class EndConversationVocodeActionConfig(VocodeActionConfig): + type: Literal["action_end_conversation"] = "action_end_conversation" + def action_attempt_to_string(self, input: ActionInput) -> str: assert isinstance(input.params, EndConversationParameters) return "Attempting to end conversation" diff --git a/vocode/streaming/action/execute_external_action.py b/vocode/streaming/action/execute_external_action.py index c689b206d6..ca628e061f 100644 --- a/vocode/streaming/action/execute_external_action.py +++ b/vocode/streaming/action/execute_external_action.py @@ -1,7 +1,7 @@ import json -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Literal, Optional, Type -from pydantic.v1 import BaseModel +from pydantic import BaseModel from vocode.streaming.action.base_action import BaseAction from vocode.streaming.action.external_actions_requester import ( @@ -13,9 +13,8 @@ from vocode.streaming.models.message import BaseMessage -class ExecuteExternalActionVocodeActionConfig( - VocodeActionConfig, type="action_external" # type: ignore -): +class ExecuteExternalActionVocodeActionConfig(VocodeActionConfig): + type: Literal["action_external"] = "action_external" processing_mode: ExternalActionProcessingMode name: str description: str diff --git a/vocode/streaming/action/external_actions_requester.py b/vocode/streaming/action/external_actions_requester.py index 658192e7b4..9520190ae5 100644 --- a/vocode/streaming/action/external_actions_requester.py +++ b/vocode/streaming/action/external_actions_requester.py @@ -6,7 +6,7 @@ import httpx from loguru import logger -from pydantic.v1 import BaseModel +from pydantic import BaseModel class ExternalActionValueError(ValueError): diff --git a/vocode/streaming/action/record_email.py b/vocode/streaming/action/record_email.py index 98619ef121..651379b903 100644 --- a/vocode/streaming/action/record_email.py +++ 
b/vocode/streaming/action/record_email.py @@ -1,7 +1,7 @@ import re -from typing import Optional, Type +from typing import Literal, Optional, Type -from pydantic.v1 import BaseModel, Field +from pydantic import BaseModel, Field from vocode.streaming.action.base_action import BaseAction from vocode.streaming.models.actions import ActionConfig, ActionInput, ActionOutput @@ -9,8 +9,8 @@ EMAIL_REGEX = r"^(?!\.)(?!.*\.\.)[a-zA-Z0-9._%+-]+(? str: diff --git a/vocode/streaming/models/agent.py b/vocode/streaming/models/agent.py index aa9b63c8fd..3bfa209cba 100644 --- a/vocode/streaming/models/agent.py +++ b/vocode/streaming/models/agent.py @@ -1,12 +1,13 @@ +from abc import ABC from enum import Enum -from typing import List, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Union -from pydantic.v1 import validator +from pydantic import BaseModel, model_validator from vocode.streaming.models.actions import ActionConfig +from vocode.streaming.models.adaptive_object import AdaptiveObject from vocode.streaming.models.message import BaseMessage -from .model import BaseModel, TypedModel from .vector_db import VectorDBConfig FILLER_AUDIO_DEFAULT_SILENCE_THRESHOLD_SECONDS = 0.5 @@ -40,36 +41,18 @@ InterruptSensitivity = Literal["low", "high"] -class AgentType(str, Enum): - BASE = "agent_base" - LLM = "agent_llm" - CHAT_GPT_ALPHA = "agent_chat_gpt_alpha" - CHAT_GPT = "agent_chat_gpt" - ANTHROPIC = "agent_anthropic" - CHAT_VERTEX_AI = "agent_chat_vertex_ai" - ECHO = "agent_echo" - GPT4ALL = "agent_gpt4all" - LLAMACPP = "agent_llamacpp" - GROQ = "agent_groq" - INFORMATION_RETRIEVAL = "agent_information_retrieval" - RESTFUL_USER_IMPLEMENTED = "agent_restful_user_implemented" - WEBSOCKET_USER_IMPLEMENTED = "agent_websocket_user_implemented" - ACTION = "agent_action" - LANGCHAIN = "agent_langchain" - - class FillerAudioConfig(BaseModel): silence_threshold_seconds: float = FILLER_AUDIO_DEFAULT_SILENCE_THRESHOLD_SECONDS use_phrases: bool = True use_typing_noise: 
bool = False - @validator("use_typing_noise") - def typing_noise_excludes_phrases(cls, v, values): - if v and values.get("use_phrases"): - values["use_phrases"] = False - if not v and not values.get("use_phrases"): + @model_validator(mode="after") + def typing_noise_excludes_phrases(self): + if self.use_typing_noise and self.use_phrases: + self.use_phrases = False + if not self.use_typing_noise and not self.use_phrases: raise ValueError("must use either typing noise or phrases for filler audio") - return v + return self class WebhookConfig(BaseModel): @@ -90,7 +73,8 @@ class CutOffResponse(BaseModel): messages: List[BaseMessage] = [BaseMessage(text="Sorry?")] -class AgentConfig(TypedModel, type=AgentType.BASE.value): # type: ignore +class AgentConfig(AdaptiveObject, ABC): + type: Any initial_message: Optional[BaseMessage] = None generate_responses: bool = True allowed_idle_time_seconds: Optional[float] = None @@ -106,7 +90,8 @@ class AgentConfig(TypedModel, type=AgentType.BASE.value): # type: ignore cut_off_response: Optional[CutOffResponse] = None -class LLMAgentConfig(AgentConfig, type=AgentType.LLM.value): # type: ignore +class LLMAgentConfig(AgentConfig): + type: Literal["agent_llm"] = "agent_llm" prompt_preamble: str model_name: str = LLM_AGENT_DEFAULT_MODEL_NAME temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE @@ -118,7 +103,8 @@ class LLMFallback(BaseModel): model_name: str -class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT.value): # type: ignore +class ChatGPTAgentConfig(AgentConfig): + type: Literal["agent_chat_gpt"] = "agent_chat_gpt" openai_api_key: Optional[str] = None prompt_preamble: str model_name: str = CHAT_GPT_AGENT_DEFAULT_MODEL_NAME @@ -134,14 +120,16 @@ class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT.value): # type: i llm_fallback: Optional[LLMFallback] = None -class AnthropicAgentConfig(AgentConfig, type=AgentType.ANTHROPIC.value): # type: ignore +class AnthropicAgentConfig(AgentConfig): + type: 
Literal["agent_anthropic"] = "agent_anthropic" prompt_preamble: str model_name: str = CHAT_ANTHROPIC_DEFAULT_MODEL_NAME max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE -class LangchainAgentConfig(AgentConfig, type=AgentType.LANGCHAIN.value): # type: ignore +class LangchainAgentConfig(AgentConfig): + type: Literal["agent_langchain"] = "agent_langchain" prompt_preamble: str model_name: str provider: str @@ -149,13 +137,15 @@ class LangchainAgentConfig(AgentConfig, type=AgentType.LANGCHAIN.value): # type max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS -class ChatVertexAIAgentConfig(AgentConfig, type=AgentType.CHAT_VERTEX_AI.value): # type: ignore +class ChatVertexAIAgentConfig(AgentConfig): + type: Literal["agent_chat_vertex_ai"] = "agent_chat_vertex_ai" prompt_preamble: str model_name: str = CHAT_VERTEX_AI_DEFAULT_MODEL_NAME generate_responses: bool = False # Google Vertex AI doesn't support streaming -class GroqAgentConfig(AgentConfig, type=AgentType.GROQ.value): # type: ignore +class GroqAgentConfig(AgentConfig): + type: Literal["agent_groq"] = "agent_groq" groq_api_key: Optional[str] = None prompt_preamble: str model_name: str = GROQ_DEFAULT_MODEL_NAME @@ -168,9 +158,8 @@ class GroqAgentConfig(AgentConfig, type=AgentType.GROQ.value): # type: ignore first_response_filler_message: Optional[str] = None -class InformationRetrievalAgentConfig( - AgentConfig, type=AgentType.INFORMATION_RETRIEVAL.value # type: ignore -): +class InformationRetrievalAgentConfig(AgentConfig): + type: Literal["agent_information_retrieval"] = "agent_information_retrieval" recipient_descriptor: str caller_descriptor: str goal_description: str @@ -178,23 +167,23 @@ class InformationRetrievalAgentConfig( # TODO: add fields for IVR, voicemail -class EchoAgentConfig(AgentConfig, type=AgentType.ECHO.value): # type: ignore - pass +class EchoAgentConfig(AgentConfig): + type: Literal["agent_echo"] = "agent_echo" -class GPT4AllAgentConfig(AgentConfig, 
type=AgentType.GPT4ALL.value): # type: ignore +class GPT4AllAgentConfig(AgentConfig): + type: Literal["agent_gpt4all"] = "agent_gpt4all" prompt_preamble: str model_path: str generate_responses: bool = False -class RESTfulUserImplementedAgentConfig( - AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED.value # type: ignore -): +class RESTfulUserImplementedAgentConfig(AgentConfig): class EndpointConfig(BaseModel): url: str method: str = "POST" + type: Literal["agent_restful_user_implemented"] = "agent_restful_user_implemented" respond: EndpointConfig generate_responses: bool = False # generate_response: Optional[EndpointConfig] @@ -205,19 +194,13 @@ class RESTfulAgentInput(BaseModel): human_input: str -class RESTfulAgentOutputType(str, Enum): - BASE = "restful_agent_base" - TEXT = "restful_agent_text" - END = "restful_agent_end" - - -class RESTfulAgentOutput(TypedModel, type=RESTfulAgentOutputType.BASE): # type: ignore - pass +class RESTfulAgentText(BaseModel): + type: Literal["restful_agent_text"] = "restful_agent_text" + response: str -class RESTfulAgentText(RESTfulAgentOutput, type=RESTfulAgentOutputType.TEXT): # type: ignore - response: str +class RESTfulAgentEnd(BaseModel): + type: Literal["restful_agent_end"] = "restful_agent_end" -class RESTfulAgentEnd(RESTfulAgentOutput, type=RESTfulAgentOutputType.END): # type: ignore - pass +RESTfulAgentOutput = Union[RESTfulAgentText, RESTfulAgentEnd] diff --git a/vocode/streaming/models/client_backend.py b/vocode/streaming/models/client_backend.py index e655f32b39..2499d24ab4 100644 --- a/vocode/streaming/models/client_backend.py +++ b/vocode/streaming/models/client_backend.py @@ -1,7 +1,8 @@ from typing import Optional +from pydantic import BaseModel + from vocode.streaming.models.audio import AudioEncoding -from vocode.streaming.models.model import BaseModel class InputAudioConfig(BaseModel): diff --git a/vocode/streaming/models/events.py b/vocode/streaming/models/events.py index 901fd3a333..8d9e878640 100644 --- 
a/vocode/streaming/models/events.py +++ b/vocode/streaming/models/events.py @@ -1,7 +1,13 @@ +from abc import ABC from enum import Enum -from typing import Optional +from typing import TYPE_CHECKING, Any, Literal, Optional, Union -from vocode.streaming.models.model import TypedModel +from pydantic import BaseModel + +from vocode.streaming.models.adaptive_object import AdaptiveObject + +if TYPE_CHECKING: + from vocode.streaming.models.transcript import Transcript class Sender(str, Enum): @@ -22,27 +28,33 @@ class EventType(str, Enum): ACTION = "event_action" -class Event(TypedModel): +class Event(AdaptiveObject, ABC): + type: Any conversation_id: str -class PhoneCallConnectedEvent(Event, type=EventType.PHONE_CALL_CONNECTED): # type: ignore +class PhoneCallConnectedEvent(Event): + type: Literal["event_phone_call_connected"] = "event_phone_call_connected" to_phone_number: str from_phone_number: str -class PhoneCallEndedEvent(Event, type=EventType.PHONE_CALL_ENDED): # type: ignore +class PhoneCallEndedEvent(Event): + type: Literal["event_phone_call_ended"] = "event_phone_call_ended" conversation_minutes: float = 0 -class PhoneCallDidNotConnectEvent(Event, type=EventType.PHONE_CALL_DID_NOT_CONNECT): # type: ignore +class PhoneCallDidNotConnectEvent(Event): + type: Literal["event_phone_call_did_not_connect"] = "event_phone_call_did_not_connect" telephony_status: str -class RecordingEvent(Event, type=EventType.RECORDING): # type: ignore +class RecordingEvent(Event): + type: Literal["event_recording"] = "event_recording" recording_url: str -class ActionEvent(Event, type=EventType.ACTION): # type: ignore +class ActionEvent(Event): + type: Literal["event_action"] = "event_action" action_input: Optional[dict] = None action_output: Optional[dict] = None diff --git a/vocode/streaming/models/message.py b/vocode/streaming/models/message.py index db32f9e29a..1bb5624fd3 100644 --- a/vocode/streaming/models/message.py +++ b/vocode/streaming/models/message.py @@ -1,7 +1,7 @@ from 
enum import Enum -from typing import Optional +from typing import Literal, Optional -from .model import TypedModel +from pydantic import BaseModel class MessageType(str, Enum): @@ -11,22 +11,27 @@ class MessageType(str, Enum): LLM_TOKEN = "llm_token" -class BaseMessage(TypedModel, type=MessageType.BASE): # type: ignore +MessageType = Literal["message_base", "message_ssml", "bot_backchannel", "llm_token"] + + +class BaseMessage(BaseModel): + type: MessageType = "message_base" text: str trailing_silence_seconds: float = 0.0 cache_phrase: Optional[str] = None -class SSMLMessage(BaseMessage, type=MessageType.SSML): # type: ignore +class SSMLMessage(BaseMessage): + type: Literal["message_ssml"] = "message_ssml" ssml: str -class BotBackchannel(BaseMessage, type=MessageType.BOT_BACKCHANNEL): # type: ignore - pass +class BotBackchannel(BaseMessage): + type: Literal["bot_backchannel"] = "bot_backchannel" -class LLMToken(BaseMessage, type=MessageType.LLM_TOKEN): # type: ignore - pass +class LLMToken(BaseMessage): + type: Literal["llm_token"] = "llm_token" class SilenceMessage(BotBackchannel): diff --git a/vocode/streaming/models/model.py b/vocode/streaming/models/model.py index 05ec13f6e6..d6e21d18cc 100644 --- a/vocode/streaming/models/model.py +++ b/vocode/streaming/models/model.py @@ -1,6 +1,6 @@ from typing import Any, List, Tuple -from pydantic.v1 import BaseModel as Pydantic1BaseModel +from pydantic import BaseModel as Pydantic1BaseModel class BaseModel(Pydantic1BaseModel): diff --git a/vocode/streaming/models/synthesizer.py b/vocode/streaming/models/synthesizer.py index ce8af4babb..76e44465e6 100644 --- a/vocode/streaming/models/synthesizer.py +++ b/vocode/streaming/models/synthesizer.py @@ -1,30 +1,16 @@ +from abc import ABC from enum import Enum from typing import Any, Dict, List, Literal, Optional +from pydantic import BaseModel, field_validator, model_validator from pydantic.v1 import validator +from vocode.streaming.models.adaptive_object import AdaptiveObject 
from vocode.streaming.models.client_backend import OutputAudioConfig from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice from vocode.streaming.telephony.constants import DEFAULT_AUDIO_ENCODING, DEFAULT_SAMPLING_RATE from .audio import AudioEncoding, SamplingRate -from .model import BaseModel, TypedModel - - -class SynthesizerType(str, Enum): - BASE = "synthesizer_base" - AZURE = "synthesizer_azure" - GOOGLE = "synthesizer_google" - ELEVEN_LABS = "synthesizer_eleven_labs" - RIME = "synthesizer_rime" - PLAY_HT = "synthesizer_play_ht" - GTTS = "synthesizer_gtts" - STREAM_ELEMENTS = "synthesizer_stream_elements" - COQUI_TTS = "synthesizer_coqui_tts" - COQUI = "synthesizer_coqui" - BARK = "synthesizer_bark" - POLLY = "synthesizer_polly" - CARTESIA = "synthesizer_cartesia" class SentimentConfig(BaseModel): @@ -37,7 +23,8 @@ def emotions_must_not_be_empty(cls, v): return v -class SynthesizerConfig(TypedModel, type=SynthesizerType.BASE.value): # type: ignore +class SynthesizerConfig(AdaptiveObject, ABC): + type: Any sampling_rate: int audio_encoding: AudioEncoding should_encode_as_wav: bool = False @@ -75,7 +62,8 @@ def from_output_audio_config(cls, output_audio_config: OutputAudioConfig, **kwar AZURE_SYNTHESIZER_DEFAULT_RATE = 15 -class AzureSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.AZURE.value): # type: ignore +class AzureSynthesizerConfig(SynthesizerConfig): + type: Literal["synthesizer_azure"] = "synthesizer_azure" voice_name: str = AZURE_SYNTHESIZER_DEFAULT_VOICE_NAME pitch: int = AZURE_SYNTHESIZER_DEFAULT_PITCH rate: int = AZURE_SYNTHESIZER_DEFAULT_RATE @@ -88,7 +76,8 @@ class AzureSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.AZURE.value DEFAULT_GOOGLE_SPEAKING_RATE = 1.2 -class GoogleSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.GOOGLE.value): # type: ignore +class GoogleSynthesizerConfig(SynthesizerConfig): + type: Literal["synthesizer_google"] = "synthesizer_google" language_code: str = 
DEFAULT_GOOGLE_LANGUAGE_CODE voice_name: str = DEFAULT_GOOGLE_VOICE_NAME pitch: float = DEFAULT_GOOGLE_PITCH @@ -98,9 +87,8 @@ class GoogleSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.GOOGLE.val ELEVEN_LABS_ADAM_VOICE_ID = "pNInz6obpgDQGcFmaJgB" -class ElevenLabsSynthesizerConfig( - SynthesizerConfig, type=SynthesizerType.ELEVEN_LABS.value # type: ignore -): +class ElevenLabsSynthesizerConfig(SynthesizerConfig): + type: Literal["synthesizer_eleven_labs"] = "synthesizer_eleven_labs" api_key: Optional[str] = None voice_id: Optional[str] = ELEVEN_LABS_ADAM_VOICE_ID optimize_streaming_latency: Optional[int] @@ -111,24 +99,26 @@ class ElevenLabsSynthesizerConfig( experimental_websocket: bool = False backchannel_amplitude_factor: float = 0.5 - @validator("voice_id") + @field_validator("voice_id", mode="after") + @classmethod def set_name(cls, voice_id): return voice_id or ELEVEN_LABS_ADAM_VOICE_ID - @validator("similarity_boost", always=True) - def stability_and_similarity_boost_check(cls, similarity_boost, values): - stability = values.get("stability") - if (stability is None) != (similarity_boost is None): + @model_validator(mode="after") + def stability_and_similarity_boost_check(self): + if (self.stability is None) != (self.similarity_boost is None): raise ValueError("Both stability and similarity_boost must be set or not set.") - return similarity_boost + return self - @validator("optimize_streaming_latency") + @field_validator("optimize_streaming_latency", mode="after") + @classmethod def optimize_streaming_latency_check(cls, optimize_streaming_latency): if optimize_streaming_latency is not None and not (0 <= optimize_streaming_latency <= 4): raise ValueError("optimize_streaming_latency must be between 0 and 4.") return optimize_streaming_latency - @validator("backchannel_amplitude_factor") + @field_validator("backchannel_amplitude_factor", mode="after") + @classmethod def backchannel_amplitude_factor_check(cls, backchannel_amplitude_factor): if 
backchannel_amplitude_factor is not None and not (0 < backchannel_amplitude_factor <= 1): raise ValueError( @@ -146,7 +136,8 @@ def backchannel_amplitude_factor_check(cls, backchannel_amplitude_factor): RimeModelId = Literal["mist", "v1"] -class RimeSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.RIME.value): # type: ignore +class RimeSynthesizerConfig(SynthesizerConfig): + type: Literal["synthesizer_rime"] = "synthesizer_rime" base_url: str = RIME_DEFAULT_BASE_URL model_id: Optional[Literal[RimeModelId]] = RIME_DEFAULT_MODEL_ID speaker: str = RIME_DEFAULT_SPEAKER @@ -155,26 +146,11 @@ class RimeSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.RIME.value): reduce_latency: Optional[bool] = RIME_DEFAULT_REDUCE_LATENCY -COQUI_DEFAULT_SPEAKER_ID = "ebe2db86-62a6-49a1-907a-9a1360d4416e" - - -class CoquiSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.COQUI.value): # type: ignore - api_key: Optional[str] = None - voice_id: Optional[str] = COQUI_DEFAULT_SPEAKER_ID - voice_prompt: Optional[str] = None - use_xtts: Optional[bool] = True - - @validator("voice_id", always=True) - def override_voice_id_with_prompt(cls, voice_id, values): - if values.get("voice_prompt"): - return None - return voice_id or COQUI_DEFAULT_SPEAKER_ID - - PlayHtVoiceVersionType = Literal["1", "2"] -class PlayHtSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.PLAY_HT.value): # type: ignore +class PlayHtSynthesizerConfig(SynthesizerConfig): + type: Literal["synthesizer_play_ht"] = "synthesizer_play_ht" voice_id: str api_key: Optional[str] = None user_id: Optional[str] = None @@ -192,28 +168,27 @@ class PlayHtSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.PLAY_HT.va experimental_remove_silence: bool = False -class CoquiTTSSynthesizerConfig( - SynthesizerConfig, type=SynthesizerType.COQUI_TTS.value # type: ignore -): +class CoquiTTSSynthesizerConfig(SynthesizerConfig): + type: Literal["synthesizer_coqui_tts"] = "synthesizer_coqui_tts" tts_kwargs: dict = {} 
speaker: Optional[str] = None language: Optional[str] = None -class GTTSSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.GTTS.value): # type: ignore - pass +class GTTSSynthesizerConfig(SynthesizerConfig): + type: Literal["synthesizer_gtts"] = "synthesizer_gtts" STREAM_ELEMENTS_SYNTHESIZER_DEFAULT_VOICE = "Brian" -class StreamElementsSynthesizerConfig( - SynthesizerConfig, type=SynthesizerType.STREAM_ELEMENTS.value # type: ignore -): +class StreamElementsSynthesizerConfig(SynthesizerConfig): + type: Literal["synthesizer_stream_elements"] = "synthesizer_stream_elements" voice: str = STREAM_ELEMENTS_SYNTHESIZER_DEFAULT_VOICE -class BarkSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.BARK.value): # type: ignore +class BarkSynthesizerConfig(SynthesizerConfig): + type: Literal["synthesizer_bark"] = "synthesizer_bark" preload_kwargs: Dict[str, Any] = {} generate_kwargs: Dict[str, Any] = {} @@ -223,7 +198,8 @@ class BarkSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.BARK.value): DEFAULT_POLLY_SAMPLING_RATE = SamplingRate.RATE_16000.value -class PollySynthesizerConfig(SynthesizerConfig, type=SynthesizerType.POLLY.value): # type: ignore +class PollySynthesizerConfig(SynthesizerConfig): + type: Literal["synthesizer_polly"] = "synthesizer_polly" language_code: str = DEFAULT_POLLY_LANGUAGE_CODE voice_id: str = DEFAULT_POLLY_VOICE_ID sampling_rate: int = DEFAULT_POLLY_SAMPLING_RATE @@ -233,7 +209,8 @@ class PollySynthesizerConfig(SynthesizerConfig, type=SynthesizerType.POLLY.value DEFAULT_CARTESIA_VOICE_ID = "5345cf08-6f37-424d-a5d9-8ae1101b9377" -class CartesiaSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.CARTESIA.value): # type: ignore +class CartesiaSynthesizerConfig(SynthesizerConfig): + type: Literal["synthesizer_cartesia"] = "synthesizer_cartesia" api_key: Optional[str] = None model_id: str = DEFAULT_CARTESIA_MODEL_ID voice_id: str = DEFAULT_CARTESIA_VOICE_ID diff --git a/vocode/streaming/models/telephony.py 
b/vocode/streaming/models/telephony.py index cd01b3ced5..bbff9f0830 100644 --- a/vocode/streaming/models/telephony.py +++ b/vocode/streaming/models/telephony.py @@ -1,8 +1,11 @@ +from abc import ABC from enum import Enum from typing import Any, Dict, Literal, Optional, Union +from pydantic import BaseModel + +from vocode.streaming.models.adaptive_object import AdaptiveObject from vocode.streaming.models.agent import AgentConfig -from vocode.streaming.models.model import BaseModel, TypedModel from vocode.streaming.models.synthesizer import AzureSynthesizerConfig, SynthesizerConfig from vocode.streaming.models.transcriber import ( DeepgramTranscriberConfig, @@ -85,16 +88,11 @@ class DialIntoZoomCall(BaseModel): twilio_config: Optional[TwilioConfig] = None -class CallConfigType(str, Enum): - BASE = "call_config_base" - TWILIO = "call_config_twilio" - VONAGE = "call_config_vonage" - - PhoneCallDirection = Literal["inbound", "outbound"] -class BaseCallConfig(TypedModel, type=CallConfigType.BASE.value): # type: ignore +class BaseCallConfig(AdaptiveObject, ABC): + type: Any transcriber_config: TranscriberConfig agent_config: AgentConfig synthesizer_config: SynthesizerConfig @@ -114,7 +112,8 @@ def default_synthesizer_config(): raise NotImplementedError -class TwilioCallConfig(BaseCallConfig, type=CallConfigType.TWILIO.value): # type: ignore +class TwilioCallConfig(BaseCallConfig): + type: Literal["call_config_twilio"] = "call_config_twilio" twilio_config: TwilioConfig twilio_sid: str @@ -137,7 +136,8 @@ def default_synthesizer_config(): ) -class VonageCallConfig(BaseCallConfig, type=CallConfigType.VONAGE.value): # type: ignore +class VonageCallConfig(BaseCallConfig): + type: Literal["call_config_vonage"] = "call_config_vonage" vonage_config: VonageConfig vonage_uuid: str output_to_speaker: bool = False diff --git a/vocode/streaming/models/transcriber.py b/vocode/streaming/models/transcriber.py index c27d865197..a48d502a9d 100644 --- 
a/vocode/streaming/models/transcriber.py +++ b/vocode/streaming/models/transcriber.py @@ -1,10 +1,12 @@ +from abc import ABC from enum import Enum -from typing import List, Optional +from typing import Any, List, Literal, Optional -from pydantic.v1 import validator +from pydantic import field_validator import vocode.streaming.livekit.constants as LiveKitConstants from vocode.streaming.input_device.base_input_device import BaseInputDevice +from vocode.streaming.models.adaptive_object import AdaptiveObject from vocode.streaming.models.client_backend import InputAudioConfig from vocode.streaming.models.model import BaseModel from vocode.streaming.telephony.constants import ( @@ -14,44 +16,27 @@ ) from .audio import AudioEncoding -from .model import TypedModel AZURE_DEFAULT_LANGUAGE = "en-US" DEEPGRAM_API_WS_URL = "wss://api.deepgram.com" -class TranscriberType(str, Enum): - BASE = "transcriber_base" - DEEPGRAM = "transcriber_deepgram" - GOOGLE = "transcriber_google" - ASSEMBLY_AI = "transcriber_assembly_ai" - WHISPER_CPP = "transcriber_whisper_cpp" - REV_AI = "transcriber_rev_ai" - AZURE = "transcriber_azure" - GLADIA = "transcriber_gladia" +class EndpointingConfig(AdaptiveObject, ABC): + type: Any -class EndpointingType(str, Enum): - BASE = "endpointing_base" - TIME_BASED = "endpointing_time_based" - PUNCTUATION_BASED = "endpointing_punctuation_based" - - -class EndpointingConfig(TypedModel, type=EndpointingType.BASE): # type: ignore - pass - - -class TimeEndpointingConfig(EndpointingConfig, type=EndpointingType.TIME_BASED): # type: ignore +class TimeEndpointingConfig(EndpointingConfig): + type: Literal["endpointing_time_based"] = "endpointing_time_based" time_cutoff_seconds: float = 0.4 -class PunctuationEndpointingConfig( - EndpointingConfig, type=EndpointingType.PUNCTUATION_BASED # type: ignore -): +class PunctuationEndpointingConfig(EndpointingConfig): + type: Literal["endpointing_punctuation_based"] = "endpointing_punctuation_based" time_cutoff_seconds: float = 
0.4 -class TranscriberConfig(TypedModel, type=TranscriberType.BASE.value): # type: ignore +class TranscriberConfig(AdaptiveObject, ABC): + type: Any sampling_rate: int audio_encoding: AudioEncoding chunk_size: int @@ -60,7 +45,8 @@ class TranscriberConfig(TypedModel, type=TranscriberType.BASE.value): # type: i min_interrupt_confidence: Optional[float] = None mute_during_speech: bool = False - @validator("min_interrupt_confidence") + @field_validator("min_interrupt_confidence", mode="after") + @classmethod def min_interrupt_confidence_must_be_between_0_and_1(cls, v): if v is not None and (v < 0 or v > 1): raise ValueError("must be between 0 and 1") @@ -116,7 +102,8 @@ def from_livekit_input_device(cls, **kwargs): ) -class DeepgramTranscriberConfig(TranscriberConfig, type=TranscriberType.DEEPGRAM.value): # type: ignore +class DeepgramTranscriberConfig(TranscriberConfig): + type: Literal["transcriber_deepgram"] = "transcriber_deepgram" language: Optional[str] = None model: Optional[str] = "nova" tier: Optional[str] = None @@ -127,38 +114,39 @@ class DeepgramTranscriberConfig(TranscriberConfig, type=TranscriberType.DEEPGRAM ws_url: str = DEEPGRAM_API_WS_URL -class GladiaTranscriberConfig(TranscriberConfig, type=TranscriberType.GLADIA.value): # type: ignore +class GladiaTranscriberConfig(TranscriberConfig): + type: Literal["transcriber_gladia"] = "transcriber_gladia" buffer_size_seconds: float = 0.1 -class GoogleTranscriberConfig(TranscriberConfig, type=TranscriberType.GOOGLE.value): # type: ignore +class GoogleTranscriberConfig(TranscriberConfig): + type: Literal["transcriber_google"] = "transcriber_google" model: Optional[str] = None language_code: str = "en-US" -class AzureTranscriberConfig(TranscriberConfig, type=TranscriberType.AZURE.value): # type: ignore +class AzureTranscriberConfig(TranscriberConfig): + type: Literal["transcriber_azure"] = "transcriber_azure" language: str = AZURE_DEFAULT_LANGUAGE candidate_languages: Optional[List[str]] = None -class 
AssemblyAITranscriberConfig( - TranscriberConfig, type=TranscriberType.ASSEMBLY_AI.value # type: ignore -): +class AssemblyAITranscriberConfig(TranscriberConfig): + type: Literal["transcriber_assembly_ai"] = "transcriber_assembly_ai" buffer_size_seconds: float = 0.1 word_boost: Optional[List[str]] = None end_utterance_silence_threshold_milliseconds: Optional[int] = None -class WhisperCPPTranscriberConfig( - TranscriberConfig, type=TranscriberType.WHISPER_CPP.value # type: ignore -): +class WhisperCPPTranscriberConfig(TranscriberConfig): + type: Literal["transcriber_whisper_cpp"] = "transcriber_whisper_cpp" buffer_size_seconds: float = 1 libname: str fname_model: str -class RevAITranscriberConfig(TranscriberConfig, type=TranscriberType.REV_AI.value): # type: ignore - pass +class RevAITranscriberConfig(TranscriberConfig): + type: Literal["transcriber_rev_ai"] = "transcriber_rev_ai" class Transcription(BaseModel): diff --git a/vocode/streaming/models/transcript.py b/vocode/streaming/models/transcript.py index 380047d3a9..0f4806dbfd 100644 --- a/vocode/streaming/models/transcript.py +++ b/vocode/streaming/models/transcript.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import List, Literal, Optional -from pydantic.v1 import BaseModel, Field +from pydantic import BaseModel, Field from vocode.streaming.models.actions import ActionInput, ActionOutput from vocode.streaming.models.events import ActionEvent, Event, EventType, Sender @@ -105,7 +105,8 @@ def to_string(self, include_timestamp: bool = False, include_sender: bool = True return msg_string -class TranscriptEvent(Event, type=EventType.TRANSCRIPT): # type: ignore +class TranscriptEvent(Event): + type: Literal["event_transcript"] = "event_transcript" text: str sender: Sender timestamp: float @@ -273,5 +274,6 @@ def was_last_message_interrupted(self): return False -class TranscriptCompleteEvent(Event, type=EventType.TRANSCRIPT_COMPLETE): # type: ignore +class TranscriptCompleteEvent(Event): + type: 
Literal["event_transcript_complete"] = "event_transcript_complete" transcript: Transcript diff --git a/vocode/streaming/models/vector_db.py b/vocode/streaming/models/vector_db.py index 44fdf64c6d..bb81a73a52 100644 --- a/vocode/streaming/models/vector_db.py +++ b/vocode/streaming/models/vector_db.py @@ -1,21 +1,19 @@ +from abc import ABC from enum import Enum -from typing import Optional +from typing import Any, Literal, Optional -from .model import TypedModel +from vocode.streaming.models.adaptive_object import AdaptiveObject DEFAULT_EMBEDDINGS_MODEL = "text-embedding-ada-002" -class VectorDBType(str, Enum): - BASE = "vector_db_base" - PINECONE = "vector_db_pinecone" - - -class VectorDBConfig(TypedModel, type=VectorDBType.BASE.value): # type: ignore +class VectorDBConfig(AdaptiveObject, ABC): + type: Any embeddings_model: str = DEFAULT_EMBEDDINGS_MODEL -class PineconeConfig(VectorDBConfig, type=VectorDBType.PINECONE.value): # type: ignore +class PineconeConfig(VectorDBConfig): + type: Literal["vector_db_pinecone"] = "vector_db_pinecone" index: str api_key: Optional[str] api_environment: Optional[str] diff --git a/vocode/streaming/models/websocket.py b/vocode/streaming/models/websocket.py index 0629ea19a4..31fc69c956 100644 --- a/vocode/streaming/models/websocket.py +++ b/vocode/streaming/models/websocket.py @@ -1,7 +1,9 @@ import base64 +from abc import ABC from enum import Enum -from typing import Optional +from typing import Any, Literal, Optional +from vocode.streaming.models.adaptive_object import AdaptiveObject from vocode.streaming.models.client_backend import InputAudioConfig, OutputAudioConfig from .agent import AgentConfig @@ -22,11 +24,12 @@ class WebSocketMessageType(str, Enum): AUDIO_CONFIG_START = "websocket_audio_config_start" -class WebSocketMessage(TypedModel, type=WebSocketMessageType.BASE): # type: ignore - pass +class WebSocketMessage(AdaptiveObject, ABC): + type: Any -class AudioMessage(WebSocketMessage, type=WebSocketMessageType.AUDIO): # 
type: ignore +class AudioMessage(WebSocketMessage): + type: Literal["websocket_audio"] = "websocket_audio" data: str @classmethod @@ -37,7 +40,8 @@ def get_bytes(self) -> bytes: return base64.b64decode(self.data) -class TranscriptMessage(WebSocketMessage, type=WebSocketMessageType.TRANSCRIPT): # type: ignore +class TranscriptMessage(WebSocketMessage): + type: Literal["websocket_transcript"] = "websocket_transcript" text: str sender: Sender timestamp: float @@ -47,25 +51,25 @@ def from_event(cls, event: TranscriptEvent): return cls(text=event.text, sender=event.sender, timestamp=event.timestamp) -class StartMessage(WebSocketMessage, type=WebSocketMessageType.START): # type: ignore +class StartMessage(WebSocketMessage): + type: Literal["websocket_start"] = "websocket_start" transcriber_config: TranscriberConfig agent_config: AgentConfig synthesizer_config: SynthesizerConfig conversation_id: Optional[str] = None -class AudioConfigStartMessage( - WebSocketMessage, type=WebSocketMessageType.AUDIO_CONFIG_START # type: ignore -): +class AudioConfigStartMessage(WebSocketMessage): + type: Literal["websocket_audio_config_start"] = "websocket_audio_config_start" input_audio_config: InputAudioConfig output_audio_config: OutputAudioConfig conversation_id: Optional[str] = None subscribe_transcript: Optional[bool] = None -class ReadyMessage(WebSocketMessage, type=WebSocketMessageType.READY): # type: ignore - pass +class ReadyMessage(WebSocketMessage): + type: Literal["websocket_ready"] = "websocket_ready" -class StopMessage(WebSocketMessage, type=WebSocketMessageType.STOP): # type: ignore - pass +class StopMessage(WebSocketMessage): + type: Literal["websocket_stop"] = "websocket_stop" diff --git a/vocode/streaming/models/websocket_agent.py b/vocode/streaming/models/websocket_agent.py index f9d06e59e7..baea63ab6b 100644 --- a/vocode/streaming/models/websocket_agent.py +++ b/vocode/streaming/models/websocket_agent.py @@ -1,8 +1,11 @@ +from abc import ABC from enum import Enum -from 
typing import Optional +from typing import Any, Literal, Optional -from vocode.streaming.models.agent import AgentConfig, AgentType -from vocode.streaming.models.model import BaseModel, TypedModel +from pydantic import BaseModel + +from vocode.streaming.models.adaptive_object import AdaptiveObject +from vocode.streaming.models.agent import AgentConfig class WebSocketAgentMessageType(str, Enum): @@ -11,13 +14,14 @@ class WebSocketAgentMessageType(str, Enum): STOP = "websocket_agent_stop" -class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE): # type: ignore +class WebSocketAgentMessage(AdaptiveObject, ABC): + type: Any conversation_id: Optional[str] = None -class WebSocketAgentTextMessage( - WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT # type: ignore -): +class WebSocketAgentTextMessage(WebSocketAgentMessage): + type: Literal["websocket_agent_text"] = "websocket_agent_text" + class Payload(BaseModel): text: str @@ -28,16 +32,13 @@ def from_text(cls, text: str, conversation_id: Optional[str] = None): return cls(data=cls.Payload(text=text), conversation_id=conversation_id) -class WebSocketAgentStopMessage( - WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP # type: ignore -): - pass +class WebSocketAgentStopMessage(WebSocketAgentMessage): + type: Literal["websocket_agent_stop"] = "websocket_agent_stop" -class WebSocketUserImplementedAgentConfig( - AgentConfig, type=AgentType.WEBSOCKET_USER_IMPLEMENTED.value # type: ignore -): +class WebSocketUserImplementedAgentConfig(AgentConfig): class RouteConfig(BaseModel): url: str + type: Literal["agent_websocket_user_implemented"] = "agent_websocket_user_implemented" respond: RouteConfig diff --git a/vocode/streaming/synthesizer/coqui_synthesizer.py b/vocode/streaming/synthesizer/coqui_synthesizer.py index 2feedb10d7..b2e322f5cc 100644 --- a/vocode/streaming/synthesizer/coqui_synthesizer.py +++ b/vocode/streaming/synthesizer/coqui_synthesizer.py @@ -5,7 +5,6 @@ from vocode import 
getenv from vocode.streaming.models.message import BaseMessage -from vocode.streaming.models.synthesizer import CoquiSynthesizerConfig from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer, SynthesisResult from vocode.streaming.utils.async_requester import AsyncRequestor diff --git a/vocode/streaming/telephony/config_manager/redis_config_manager.py b/vocode/streaming/telephony/config_manager/redis_config_manager.py index 33f4aabcf9..d0c861002d 100644 --- a/vocode/streaming/telephony/config_manager/redis_config_manager.py +++ b/vocode/streaming/telephony/config_manager/redis_config_manager.py @@ -18,6 +18,7 @@ async def _set_with_one_day_expiration(self, *args, **kwargs): async def save_config(self, conversation_id: str, config: BaseCallConfig): logger.debug(f"Saving config for {conversation_id}") + print(config) await self._set_with_one_day_expiration(conversation_id, config.json()) async def get_config(self, conversation_id) -> Optional[BaseCallConfig]: diff --git a/vocode/streaming/telephony/server/base.py b/vocode/streaming/telephony/server/base.py index 5e04c3c05c..4ea6fb6aa2 100644 --- a/vocode/streaming/telephony/server/base.py +++ b/vocode/streaming/telephony/server/base.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Form, Request, Response from loguru import logger -from pydantic.v1 import BaseModel, Field +from pydantic import BaseModel, Field from vocode.streaming.agent.abstract_factory import AbstractAgentFactory from vocode.streaming.agent.default_factory import DefaultAgentFactory diff --git a/vocode/streaming/transcriber/deepgram_transcriber.py b/vocode/streaming/transcriber/deepgram_transcriber.py index 7f64f4a976..c8875ac5d3 100644 --- a/vocode/streaming/transcriber/deepgram_transcriber.py +++ b/vocode/streaming/transcriber/deepgram_transcriber.py @@ -1,13 +1,13 @@ import asyncio import json from datetime import datetime, timezone -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, 
Tuple, Union from urllib.parse import urlencode import sentry_sdk import websockets from loguru import logger -from pydantic.v1 import BaseModel, Field +from pydantic import BaseModel, Field from websockets.client import WebSocketClientProtocol from vocode import getenv @@ -37,14 +37,14 @@ class TimeSilentConfig(BaseModel): post_punctuation_time_seconds: float = 0.5 -class InternalPunctuationEndpointingConfig( # type: ignore - EndpointingConfig, type="internal_punctuation_based" -): +class InternalPunctuationEndpointingConfig(EndpointingConfig): # type: ignore + type: Literal["internal_punctuation_based"] = "internal_punctuation_based" time_silent_config: TimeSilentConfig = Field(default_factory=TimeSilentConfig) use_single_utterance_endpointing_for_first_utterance: bool = False -class DeepgramEndpointingConfig(EndpointingConfig, type="deepgram"): # type: ignore +class DeepgramEndpointingConfig(EndpointingConfig): + type: Literal["deepgram"] = "deepgram" vad_threshold_ms: int = 500 utterance_cutoff_ms: int = 1000 time_silent_config: Optional[TimeSilentConfig] = Field(default_factory=TimeSilentConfig) diff --git a/vocode/streaming/utils/redis_conversation_message_queue.py b/vocode/streaming/utils/redis_conversation_message_queue.py index a4c6e00caa..56cfe51e12 100644 --- a/vocode/streaming/utils/redis_conversation_message_queue.py +++ b/vocode/streaming/utils/redis_conversation_message_queue.py @@ -1,7 +1,7 @@ from typing import AsyncGenerator from loguru import logger -from pydantic.v1 import BaseModel, parse_obj_as +from pydantic import BaseModel, parse_obj_as from redis.asyncio import Redis from vocode.streaming.utils.redis import initialize_redis From eeb3dad805e61497d3c5f4481be80b265e0f70b2 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Fri, 12 Jul 2024 11:38:07 -0700 Subject: [PATCH 3/8] require that adaptiveobjects be serialized as any --- vocode/streaming/models/adaptive_object.py | 10 +++++++--- .../telephony/config_manager/redis_config_manager.py | 1 - 2 
files changed, 7 insertions(+), 4 deletions(-) diff --git a/vocode/streaming/models/adaptive_object.py b/vocode/streaming/models/adaptive_object.py index f8eb60b1f6..4238dd2f57 100644 --- a/vocode/streaming/models/adaptive_object.py +++ b/vocode/streaming/models/adaptive_object.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any +from typing import Any, Dict from pydantic import BaseModel, ValidationError, model_validator @@ -17,9 +17,7 @@ def _resolve_adaptive_object(cls, data: dict, handler) -> Any: return handler(data) # try to validate the data for each possible type - print(data) for subcls in cls._find_all_possible_types(): - print(subcls) try: # return the first successful validation return subcls.model_validate(data) @@ -42,3 +40,9 @@ def _find_all_possible_types(cls): # continue looking for possible types in subclasses for subclass in cls.__subclasses__(): yield from subclass._find_all_possible_types() + + def model_dump(self, **kwargs) -> Dict[str, Any]: + return super().model_dump(serialize_as_any=True, **kwargs) + + def model_dump_json(self, **kwargs) -> str: + return super().model_dump_json(serialize_as_any=True, **kwargs) diff --git a/vocode/streaming/telephony/config_manager/redis_config_manager.py b/vocode/streaming/telephony/config_manager/redis_config_manager.py index d0c861002d..33f4aabcf9 100644 --- a/vocode/streaming/telephony/config_manager/redis_config_manager.py +++ b/vocode/streaming/telephony/config_manager/redis_config_manager.py @@ -18,7 +18,6 @@ async def _set_with_one_day_expiration(self, *args, **kwargs): async def save_config(self, conversation_id: str, config: BaseCallConfig): logger.debug(f"Saving config for {conversation_id}") - print(config) await self._set_with_one_day_expiration(conversation_id, config.json()) async def get_config(self, conversation_id) -> Optional[BaseCallConfig]: From e10249822508343b173847d67c0fb32b3cfef5f6 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Fri, 12 Jul 2024 11:54:34 -0700 Subject: 
[PATCH 4/8] fix mypy in vocode/ + tests --- tests/fixtures/synthesizer.py | 5 +++- tests/fixtures/transcriber.py | 5 +++- .../streaming/action/test_end_conversation.py | 2 +- tests/streaming/agent/test_openai_utils.py | 27 +++++++++++++++---- tests/streaming/agent/test_streaming_utils.py | 2 +- .../streaming/test_streaming_conversation.py | 2 +- tests/streaming/utils/test_events_manager.py | 22 +++++++-------- vocode/streaming/agent/base_agent.py | 4 +-- .../agent/restful_user_implemented_agent.py | 5 +++- vocode/streaming/models/agent.py | 4 +-- vocode/streaming/models/message.py | 8 ------ .../synthesizer/coqui_synthesizer.py | 4 +-- 12 files changed, 54 insertions(+), 36 deletions(-) diff --git a/tests/fixtures/synthesizer.py b/tests/fixtures/synthesizer.py index 6181db2a20..8792dcb979 100644 --- a/tests/fixtures/synthesizer.py +++ b/tests/fixtures/synthesizer.py @@ -1,5 +1,6 @@ import wave from io import BytesIO +from typing import Literal from vocode.streaming.models.message import BaseMessage from vocode.streaming.models.synthesizer import SynthesizerConfig @@ -17,9 +18,11 @@ def create_fake_audio(message: str, synthesizer_config: SynthesizerConfig): return file -class TestSynthesizerConfig(SynthesizerConfig, type="synthesizer_test"): +class TestSynthesizerConfig(SynthesizerConfig): __test__ = False + type: Literal["synthesizer_test"] = "synthesizer_test" + class TestSynthesizer(BaseSynthesizer[TestSynthesizerConfig]): """Accepts text and creates a SynthesisResult containing audio data which is the same as the text as bytes.""" diff --git a/tests/fixtures/transcriber.py b/tests/fixtures/transcriber.py index 4c1b2e9f49..e78aad15b9 100644 --- a/tests/fixtures/transcriber.py +++ b/tests/fixtures/transcriber.py @@ -1,12 +1,15 @@ import asyncio +from typing import Literal from vocode.streaming.models.transcriber import TranscriberConfig from vocode.streaming.transcriber.base_transcriber import BaseAsyncTranscriber, Transcription -class 
TestTranscriberConfig(TranscriberConfig, type="transcriber_test"): +class TestTranscriberConfig(TranscriberConfig): __test__ = False + type: Literal["transcriber_test"] = "transcriber_test" + class TestAsyncTranscriber(BaseAsyncTranscriber[TestTranscriberConfig]): """Accepts fake audio chunks and sends out transcriptions which are the same as the audio chunks.""" diff --git a/tests/streaming/action/test_end_conversation.py b/tests/streaming/action/test_end_conversation.py index 312c72285b..357ed307f4 100644 --- a/tests/streaming/action/test_end_conversation.py +++ b/tests/streaming/action/test_end_conversation.py @@ -4,7 +4,7 @@ from uuid import UUID import pytest -from pydantic.v1 import BaseModel +from pydantic import BaseModel from pytest_mock import MockerFixture from tests.fakedata.id import generate_uuid diff --git a/tests/streaming/agent/test_openai_utils.py b/tests/streaming/agent/test_openai_utils.py index 4858c92943..eec3b1d075 100644 --- a/tests/streaming/agent/test_openai_utils.py +++ b/tests/streaming/agent/test_openai_utils.py @@ -1,3 +1,7 @@ +from typing import Literal + +from pydantic import BaseModel + from vocode.streaming.agent.openai_utils import format_openai_chat_messages_from_transcript from vocode.streaming.models.actions import ( ACTION_FINISHED_FORMAT_STRING, @@ -11,7 +15,15 @@ from vocode.streaming.models.transcript import ActionFinish, ActionStart, Message, Transcript -class WeatherActionConfig(ActionConfig, type="weather"): +class WeatherActionConfig(ActionConfig): + type: Literal["weather"] = "weather" + + +class WeatherActionParams(BaseModel): + pass + + +class WeatherActionResponse(BaseModel): pass @@ -23,12 +35,12 @@ def test_format_openai_chat_messages_from_transcript(): test_action_input_nophrase = ActionInput( action_config=WeatherActionConfig(), conversation_id="asdf", - params={}, + params=WeatherActionParams(), ) test_action_input_phrase = ActionInput( 
action_config=WeatherActionConfig(action_trigger=create_fake_vocode_phrase_trigger()), conversation_id="asdf", - params={}, + params=WeatherActionParams(), ) test_cases = [ @@ -85,7 +97,9 @@ def test_format_openai_chat_messages_from_transcript(): ActionFinish( action_type="weather", action_input=test_action_input_nophrase, - action_output=ActionOutput(action_type="weather", response={}), + action_output=ActionOutput( + action_type="weather", response=WeatherActionResponse() + ), ), ] ), @@ -127,7 +141,9 @@ def test_format_openai_chat_messages_from_transcript(): ActionFinish( action_type="weather", action_input=test_action_input_phrase, - action_output=ActionOutput(action_type="weather", response={}), + action_output=ActionOutput( + action_type="weather", response=WeatherActionResponse() + ), ), ] ), @@ -301,3 +317,4 @@ def test_format_openai_chat_messages_from_transcript_context_limit(): for params, expected_output in test_cases: assert format_openai_chat_messages_from_transcript(*params) == expected_output + assert format_openai_chat_messages_from_transcript(*params) == expected_output diff --git a/tests/streaming/agent/test_streaming_utils.py b/tests/streaming/agent/test_streaming_utils.py index db46c4a2a9..58e3a57966 100644 --- a/tests/streaming/agent/test_streaming_utils.py +++ b/tests/streaming/agent/test_streaming_utils.py @@ -2,7 +2,7 @@ import pytest from openai.types.chat.chat_completion_chunk import ChatCompletionChunk -from pydantic.v1 import BaseModel +from pydantic import BaseModel from vocode.streaming.agent.openai_utils import openai_get_tokens from vocode.streaming.agent.streaming_utils import collate_response_async diff --git a/tests/streaming/test_streaming_conversation.py b/tests/streaming/test_streaming_conversation.py index 89149994ab..0feba811ec 100644 --- a/tests/streaming/test_streaming_conversation.py +++ b/tests/streaming/test_streaming_conversation.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest -from pydantic.v1 
import BaseModel +from pydantic import BaseModel from pytest_mock import MockerFixture from tests.fakedata.conversation import ( diff --git a/tests/streaming/utils/test_events_manager.py b/tests/streaming/utils/test_events_manager.py index 75cb02f32f..4f03cfb41d 100644 --- a/tests/streaming/utils/test_events_manager.py +++ b/tests/streaming/utils/test_events_manager.py @@ -19,9 +19,9 @@ async def test_initialization(): @pytest.mark.asyncio async def test_publish_event(): event = PhoneCallEndedEvent( - conversation_id=CONVERSATION_ID, type=EventType.PHONE_CALL_ENDED + conversation_id=CONVERSATION_ID, type="event_phone_call_ended" ) # Replace with actual Event creation - manager = EventsManager([EventType.PHONE_CALL_ENDED]) + manager = EventsManager(["event_phone_call_ended"]) manager.publish_event(event) assert not manager.queue.empty() @@ -29,16 +29,16 @@ async def test_publish_event(): @pytest.mark.asyncio async def test_handle_event_default_implementation(): event = PhoneCallEndedEvent( - conversation_id=CONVERSATION_ID, type=EventType.PHONE_CALL_ENDED + conversation_id=CONVERSATION_ID, type="event_phone_call_ended" ) # Replace with actual Event creation - manager = EventsManager([EventType.PHONE_CALL_ENDED]) + manager = EventsManager(["event_phone_call_ended"]) await manager.handle_event(event) @pytest.mark.asyncio async def test_handle_event_non_async_override(mocker): - event = PhoneCallEndedEvent(conversation_id=CONVERSATION_ID, type=EventType.PHONE_CALL_ENDED) - manager = EventsManager([EventType.PHONE_CALL_ENDED]) + event = PhoneCallEndedEvent(conversation_id=CONVERSATION_ID, type="event_phone_call_ended") + manager = EventsManager(["event_phone_call_ended"]) manager.publish_event(event) error_logger_mock = mocker.patch("vocode.streaming.utils.events_manager.logger.error") @@ -53,9 +53,9 @@ async def test_handle_event_non_async_override(mocker): @pytest.mark.asyncio async def test_start_and_active_loop(): event = PhoneCallEndedEvent( - 
conversation_id=CONVERSATION_ID, type=EventType.PHONE_CALL_ENDED + conversation_id=CONVERSATION_ID, type="event_phone_call_ended" ) # Replace with actual Event creation - manager = EventsManager([EventType.PHONE_CALL_ENDED]) + manager = EventsManager(["event_phone_call_ended"]) asyncio.create_task(manager.start()) manager.publish_event(event) await asyncio.sleep(0.1) @@ -64,8 +64,8 @@ async def test_start_and_active_loop(): @pytest.mark.asyncio async def test_flush_method(): - event = PhoneCallEndedEvent(conversation_id=CONVERSATION_ID, type=EventType.PHONE_CALL_ENDED) - manager = EventsManager([EventType.PHONE_CALL_ENDED]) + event = PhoneCallEndedEvent(conversation_id=CONVERSATION_ID, type="event_phone_call_ended") + manager = EventsManager(["event_phone_call_ended"]) for _ in range(5): manager.publish_event(event) await manager.flush() @@ -74,6 +74,6 @@ async def test_flush_method(): @pytest.mark.asyncio async def test_queue_empty_and_timeout(): - manager = EventsManager([EventType.TRANSCRIPT]) + manager = EventsManager(["event_transcript"]) await manager.flush() assert manager.queue.empty() diff --git a/vocode/streaming/agent/base_agent.py b/vocode/streaming/agent/base_agent.py index fd926e5903..8eeb8eaed6 100644 --- a/vocode/streaming/agent/base_agent.py +++ b/vocode/streaming/agent/base_agent.py @@ -72,8 +72,8 @@ class _AgentInput(BaseModel): conversation_id: str - vonage_uuid: Optional[str] - twilio_sid: Optional[str] + vonage_uuid: Optional[str] = None + twilio_sid: Optional[str] = None agent_response_tracker: Optional[asyncio.Event] = None class Config: diff --git a/vocode/streaming/agent/restful_user_implemented_agent.py b/vocode/streaming/agent/restful_user_implemented_agent.py index cda7399f6c..870b1535d3 100644 --- a/vocode/streaming/agent/restful_user_implemented_agent.py +++ b/vocode/streaming/agent/restful_user_implemented_agent.py @@ -2,6 +2,7 @@ import aiohttp from loguru import logger +from pydantic import TypeAdapter from 
vocode.streaming.agent.base_agent import RespondAgent from vocode.streaming.models.agent import ( @@ -44,7 +45,9 @@ async def respond( timeout=aiohttp.ClientTimeout(total=15), ) as response: assert response.status == 200 - output: RESTfulAgentOutput = RESTfulAgentOutput.parse_obj(await response.json()) + output: RESTfulAgentOutput = TypeAdapter(RESTfulAgentOutput).validate_python( + await response.json() + ) output_response = None should_stop = False if isinstance(output, RESTfulAgentText): diff --git a/vocode/streaming/models/agent.py b/vocode/streaming/models/agent.py index 3bfa209cba..f06f8a137f 100644 --- a/vocode/streaming/models/agent.py +++ b/vocode/streaming/models/agent.py @@ -195,12 +195,12 @@ class RESTfulAgentInput(BaseModel): class RESTfulAgentText(BaseModel): - type: Literal["restful_agent_text"] = "restful_agent_text" + type: Literal["restful_agent_text"] response: str class RESTfulAgentEnd(BaseModel): - type: Literal["restful_agent_end"] = "restful_agent_end" + type: Literal["restful_agent_end"] RESTfulAgentOutput = Union[RESTfulAgentText, RESTfulAgentEnd] diff --git a/vocode/streaming/models/message.py b/vocode/streaming/models/message.py index 1bb5624fd3..cd6b6765a8 100644 --- a/vocode/streaming/models/message.py +++ b/vocode/streaming/models/message.py @@ -3,14 +3,6 @@ from pydantic import BaseModel - -class MessageType(str, Enum): - BASE = "message_base" - SSML = "message_ssml" - BOT_BACKCHANNEL = "bot_backchannel" - LLM_TOKEN = "llm_token" - - MessageType = Literal["message_base", "message_ssml", "bot_backchannel", "llm_token"] diff --git a/vocode/streaming/synthesizer/coqui_synthesizer.py b/vocode/streaming/synthesizer/coqui_synthesizer.py index b2e322f5cc..489ba54826 100644 --- a/vocode/streaming/synthesizer/coqui_synthesizer.py +++ b/vocode/streaming/synthesizer/coqui_synthesizer.py @@ -13,8 +13,8 @@ COQUI_BASE_URL = "https://app.coqui.ai/api/v2" -class CoquiSynthesizer(BaseSynthesizer[CoquiSynthesizerConfig]): - def __init__(self, 
synthesizer_config: CoquiSynthesizerConfig): +class CoquiSynthesizer(BaseSynthesizer): + def __init__(self, synthesizer_config): super().__init__(synthesizer_config) self.api_key = synthesizer_config.api_key or getenv("COQUI_API_KEY") self.voice_id = synthesizer_config.voice_id From e7b9df340f5d7801e46bd56fcf278118feb48e25 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Fri, 12 Jul 2024 12:04:35 -0700 Subject: [PATCH 5/8] fix mypy --- playground/streaming/agent/chat.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/playground/streaming/agent/chat.py b/playground/streaming/agent/chat.py index dfcede2dec..62dede8a6b 100644 --- a/playground/streaming/agent/chat.py +++ b/playground/streaming/agent/chat.py @@ -3,7 +3,7 @@ import typing from dotenv import load_dotenv -from pydantic.v1 import BaseModel +from pydantic import BaseModel from vocode.streaming.action.abstract_factory import AbstractActionFactory from vocode.streaming.action.base_action import BaseAction @@ -28,8 +28,9 @@ from vocode.streaming.agent import ChatGPTAgent from vocode.streaming.agent.base_agent import ( AgentResponse, + AgentResponseFillerAudio, AgentResponseMessage, - AgentResponseType, + AgentResponseStop, BaseAgent, TranscriptionAgentInput, ) @@ -39,7 +40,8 @@ BACKCHANNELS = ["Got it", "Sure", "Okay", "I understand"] -class ShoutActionConfig(ActionConfig, type="shout"): # type: ignore +class ShoutActionConfig(ActionConfig): + type: typing.Literal["shout"] = "shout" num_exclamation_marks: int @@ -114,16 +116,15 @@ async def receiver(): try: event = await agent_response_queue.get() response = event.payload - if response.type == AgentResponseType.FILLER_AUDIO: + if isinstance(response, AgentResponseFillerAudio): print("Would have sent filler audio") - elif response.type == AgentResponseType.STOP: + elif isinstance(response, AgentResponseStop): print("Agent returned stop") ended = True break - elif response.type == AgentResponseType.MESSAGE: - 
agent_response = typing.cast(AgentResponseMessage, response) + elif isinstance(response, AgentResponseMessage): - if isinstance(agent_response.message, EndOfTurn): + if isinstance(response.message, EndOfTurn): ignore_until_end_of_turn = False if random.random() < backchannel_probability: backchannel = random.choice(BACKCHANNELS) @@ -133,7 +134,7 @@ async def receiver(): conversation_id, is_backchannel=True, ) - elif isinstance(agent_response.message, BaseMessage): + elif isinstance(response.message, BaseMessage): if ignore_until_end_of_turn: continue @@ -141,12 +142,12 @@ async def receiver(): is_final: bool # TODO: consider allowing the user to interrupt the agent manually by responding fast if random.random() < interruption_probability: - stop_idx = random.randint(0, len(agent_response.message.text)) - message_sent = agent_response.message.text[:stop_idx] + stop_idx = random.randint(0, len(response.message.text)) + message_sent = response.message.text[:stop_idx] ignore_until_end_of_turn = True is_final = False else: - message_sent = agent_response.message.text + message_sent = response.message.text is_final = True agent.transcript.add_bot_message( From 2331b1fb12bf788ca802d0ced7f2cec4cd3dbada Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Fri, 12 Jul 2024 13:59:33 -0700 Subject: [PATCH 6/8] remove old Type enums --- apps/langchain_agent/telephony_app.py | 11 +++++------ tests/streaming/utils/test_events_manager.py | 2 +- vocode/streaming/action/default_factory.py | 16 ++++++++-------- .../streaming/client_backend/conversation.py | 9 ++++----- .../livekit/livekit_events_manager.py | 4 ++-- vocode/streaming/models/actions.py | 19 +++++++++---------- vocode/streaming/models/events.py | 17 +++++++++-------- vocode/streaming/models/message.py | 1 - vocode/streaming/models/transcript.py | 2 +- vocode/streaming/models/websocket.py | 10 ---------- vocode/streaming/models/websocket_agent.py | 6 ------ 11 files changed, 39 insertions(+), 58 deletions(-) diff --git 
a/apps/langchain_agent/telephony_app.py b/apps/langchain_agent/telephony_app.py index 3c4a5a15f5..494b919b5e 100644 --- a/apps/langchain_agent/telephony_app.py +++ b/apps/langchain_agent/telephony_app.py @@ -6,7 +6,7 @@ from dotenv import load_dotenv from fastapi import FastAPI -from vocode.streaming.models.events import Event, EventType +from vocode.streaming.models.events import Event from vocode.streaming.models.transcript import TranscriptCompleteEvent from vocode.streaming.telephony.config_manager.redis_config_manager import RedisConfigManager from vocode.streaming.telephony.server.base import TelephonyServer @@ -23,14 +23,13 @@ class EventsManager(events_manager.EventsManager): def __init__(self): - super().__init__(subscriptions=[EventType.TRANSCRIPT_COMPLETE]) + super().__init__(subscriptions=["transcript_complete"]) async def handle_event(self, event: Event): - if event.type == EventType.TRANSCRIPT_COMPLETE: - transcript_complete_event = typing.cast(TranscriptCompleteEvent, event) + if isinstance(event, TranscriptCompleteEvent): add_transcript( - transcript_complete_event.conversation_id, - transcript_complete_event.transcript.to_string(), + event.conversation_id, + event.transcript.to_string(), ) diff --git a/tests/streaming/utils/test_events_manager.py b/tests/streaming/utils/test_events_manager.py index 4f03cfb41d..5ff96f583c 100644 --- a/tests/streaming/utils/test_events_manager.py +++ b/tests/streaming/utils/test_events_manager.py @@ -2,7 +2,7 @@ import pytest -from vocode.streaming.models.events import EventType, PhoneCallEndedEvent +from vocode.streaming.models.events import PhoneCallEndedEvent from vocode.streaming.utils.events_manager import EventsManager CONVERSATION_ID = "1" diff --git a/vocode/streaming/action/default_factory.py b/vocode/streaming/action/default_factory.py index 227ed56c42..b5fb757f7d 100644 --- a/vocode/streaming/action/default_factory.py +++ b/vocode/streaming/action/default_factory.py @@ -15,20 +15,20 @@ from 
vocode.streaming.models.actions import ActionConfig, ActionType CONVERSATION_ACTIONS: Dict[ActionType, Type[BaseAction]] = { - ActionType.END_CONVERSATION: EndConversation, - ActionType.RECORD_EMAIL: RecordEmail, - ActionType.WAIT: Wait, - ActionType.EXECUTE_EXTERNAL_ACTION: ExecuteExternalAction, + "action_end_conversation": EndConversation, + "action_record_email": RecordEmail, + "action_wait": Wait, + "action_external": ExecuteExternalAction, } VONAGE_ACTIONS: Dict[ActionType, Type[VonagePhoneConversationAction]] = { - ActionType.TRANSFER_CALL: VonageTransferCall, - ActionType.DTMF: VonageDTMF, + "action_transfer_call": VonageTransferCall, + "action_dtmf": VonageDTMF, } TWILIO_ACTIONS: Dict[ActionType, Type[TwilioPhoneConversationAction]] = { - ActionType.TRANSFER_CALL: TwilioTransferCall, - ActionType.DTMF: TwilioDTMF, + "action_transfer_call": TwilioTransferCall, + "action_dtmf": TwilioDTMF, } diff --git a/vocode/streaming/client_backend/conversation.py b/vocode/streaming/client_backend/conversation.py index b2bd0f7094..c1b0a45e1d 100644 --- a/vocode/streaming/client_backend/conversation.py +++ b/vocode/streaming/client_backend/conversation.py @@ -6,7 +6,7 @@ from vocode.streaming.agent.base_agent import BaseAgent from vocode.streaming.models.client_backend import InputAudioConfig, OutputAudioConfig -from vocode.streaming.models.events import Event, EventType +from vocode.streaming.models.events import Event from vocode.streaming.models.synthesizer import AzureSynthesizerConfig from vocode.streaming.models.transcriber import ( DeepgramTranscriberConfig, @@ -111,13 +111,12 @@ def __init__( self, output_device: WebsocketOutputDevice, ): - super().__init__(subscriptions=[EventType.TRANSCRIPT]) + super().__init__(subscriptions=["event_transcript"]) self.output_device = output_device async def handle_event(self, event: Event): - if event.type == EventType.TRANSCRIPT: - transcript_event = typing.cast(TranscriptEvent, event) - await 
self.output_device.send_transcript(transcript_event) + if isinstance(event, TranscriptEvent): + await self.output_device.send_transcript(event) # logger.debug(event.dict()) def restart(self, output_device: WebsocketOutputDevice): diff --git a/vocode/streaming/livekit/livekit_events_manager.py b/vocode/streaming/livekit/livekit_events_manager.py index 5daaefcc87..38ac1e0fd8 100644 --- a/vocode/streaming/livekit/livekit_events_manager.py +++ b/vocode/streaming/livekit/livekit_events_manager.py @@ -18,8 +18,8 @@ def __init__( self, subscriptions: List[EventType] = [], ): - if EventType.TRANSCRIPT not in subscriptions: - subscriptions.append(EventType.TRANSCRIPT) + if "event_transcript" not in subscriptions: + subscriptions.append("event_transcript") super().__init__(subscriptions) def attach_conversation(self, conversation: "LiveKitConversation"): diff --git a/vocode/streaming/models/actions.py b/vocode/streaming/models/actions.py index fc8cd90477..dcc7ea0e99 100644 --- a/vocode/streaming/models/actions.py +++ b/vocode/streaming/models/actions.py @@ -52,16 +52,15 @@ class PhraseBasedActionTrigger(_ActionTrigger): ] -class ActionType(str, Enum): - BASE = "action_base" - NYLAS_SEND_EMAIL = "action_nylas_send_email" - WAIT = "action_wait" - RECORD_EMAIL = "action_record_email" - END_CONVERSATION = "action_end_conversation" - EXECUTE_EXTERNAL_ACTION = "action_external" - - TRANSFER_CALL = "action_transfer_call" - DTMF = "action_dtmf" +ActionType = Literal[ + "action_nylas_send_email", + "action_wait", + "action_record_email", + "action_end_conversation", + "action_external", + "action_transfer_call", + "action_dtmf", +] ParametersType = TypeVar("ParametersType", bound=BaseModel) diff --git a/vocode/streaming/models/events.py b/vocode/streaming/models/events.py index 8d9e878640..dabede82ca 100644 --- a/vocode/streaming/models/events.py +++ b/vocode/streaming/models/events.py @@ -18,14 +18,15 @@ class Sender(str, Enum): CONFERENCE = "conference" -class EventType(str, Enum): 
- TRANSCRIPT = "event_transcript" - TRANSCRIPT_COMPLETE = "event_transcript_complete" - PHONE_CALL_CONNECTED = "event_phone_call_connected" - PHONE_CALL_ENDED = "event_phone_call_ended" - PHONE_CALL_DID_NOT_CONNECT = "event_phone_call_did_not_connect" - RECORDING = "event_recording" - ACTION = "event_action" +EventType = Literal[ + "event_transcript", + "event_transcript_complete", + "event_phone_call_connected", + "event_phone_call_ended", + "event_phone_call_did_not_connect", + "event_recording", + "event_action", +] class Event(AdaptiveObject, ABC): diff --git a/vocode/streaming/models/message.py b/vocode/streaming/models/message.py index cd6b6765a8..05abfa52bb 100644 --- a/vocode/streaming/models/message.py +++ b/vocode/streaming/models/message.py @@ -1,4 +1,3 @@ -from enum import Enum from typing import Literal, Optional from pydantic import BaseModel diff --git a/vocode/streaming/models/transcript.py b/vocode/streaming/models/transcript.py index 0f4806dbfd..e365d9dfc5 100644 --- a/vocode/streaming/models/transcript.py +++ b/vocode/streaming/models/transcript.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field from vocode.streaming.models.actions import ActionInput, ActionOutput -from vocode.streaming.models.events import ActionEvent, Event, EventType, Sender +from vocode.streaming.models.events import ActionEvent, Event, Sender from vocode.streaming.utils.events_manager import EventsManager diff --git a/vocode/streaming/models/websocket.py b/vocode/streaming/models/websocket.py index 31fc69c956..18e36fed03 100644 --- a/vocode/streaming/models/websocket.py +++ b/vocode/streaming/models/websocket.py @@ -14,16 +14,6 @@ from .transcript import TranscriptEvent -class WebSocketMessageType(str, Enum): - BASE = "websocket_base" - START = "websocket_start" - AUDIO = "websocket_audio" - TRANSCRIPT = "websocket_transcript" - READY = "websocket_ready" - STOP = "websocket_stop" - AUDIO_CONFIG_START = "websocket_audio_config_start" - - class 
WebSocketMessage(AdaptiveObject, ABC): type: Any diff --git a/vocode/streaming/models/websocket_agent.py b/vocode/streaming/models/websocket_agent.py index baea63ab6b..d6e755400b 100644 --- a/vocode/streaming/models/websocket_agent.py +++ b/vocode/streaming/models/websocket_agent.py @@ -8,12 +8,6 @@ from vocode.streaming.models.agent import AgentConfig -class WebSocketAgentMessageType(str, Enum): - BASE = "websocket_agent_base" - TEXT = "websocket_agent_text" - STOP = "websocket_agent_stop" - - class WebSocketAgentMessage(AdaptiveObject, ABC): type: Any conversation_id: Optional[str] = None From c66a7effa737ccb42f862ede8eba56c030fa2674 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Fri, 12 Jul 2024 15:10:36 -0700 Subject: [PATCH 7/8] delete typedmodel --- vocode/streaming/models/model.py | 60 -------------------------- vocode/streaming/models/transcriber.py | 4 +- vocode/streaming/models/websocket.py | 1 - 3 files changed, 1 insertion(+), 64 deletions(-) delete mode 100644 vocode/streaming/models/model.py diff --git a/vocode/streaming/models/model.py b/vocode/streaming/models/model.py deleted file mode 100644 index d6e21d18cc..0000000000 --- a/vocode/streaming/models/model.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import Any, List, Tuple - -from pydantic import BaseModel as Pydantic1BaseModel - - -class BaseModel(Pydantic1BaseModel): - def __init__(self, **data): - for key, value in data.items(): - if isinstance(value, dict): - if ( - "type" in value and key != "action_trigger" - ): # TODO: this is a quick workaround until we get a vocode object version of action trigger (ajay has approved it) - data[key] = TypedModel.parse_obj(value) - if isinstance(value, list): - for i, v in enumerate(value): - if isinstance(v, dict): - if "type" in v: - value[i] = TypedModel.parse_obj(v) - super().__init__(**data) - - -# Adapted from https://github.com/pydantic/pydantic/discussions/3091 -class TypedModel(BaseModel): - _subtypes_: List[Tuple[Any, Any]] = [] - - def 
__init_subclass__(cls, type=None): # type: ignore - cls._subtypes_.append((type, cls)) - - @classmethod - def get_cls(_cls, type): - for t, cls in _cls._subtypes_: - if t == type: - return cls - raise ValueError(f"Unknown type {type}") - - @classmethod - def get_type(_cls, cls_name): - for t, cls in _cls._subtypes_: - if cls.__name__ == cls_name: - return t - raise ValueError(f"Unknown class {cls_name}") - - @classmethod - def parse_obj(cls, obj): - data_type = obj.get("type") - if data_type is None: - raise ValueError(f"type is required for {cls.__name__}") - - sub = cls.get_cls(data_type) - if sub is None: - raise ValueError(f"Unknown type {data_type}") - return sub(**obj) - - def _iter(self, **kwargs): - yield "type", self.get_type(self.__class__.__name__) - yield from super()._iter(**kwargs) - - @property - def type(self): - return self.get_type(self.__class__.__name__) diff --git a/vocode/streaming/models/transcriber.py b/vocode/streaming/models/transcriber.py index a48d502a9d..fb78531a33 100644 --- a/vocode/streaming/models/transcriber.py +++ b/vocode/streaming/models/transcriber.py @@ -1,14 +1,12 @@ from abc import ABC -from enum import Enum from typing import Any, List, Literal, Optional -from pydantic import field_validator +from pydantic import BaseModel, field_validator import vocode.streaming.livekit.constants as LiveKitConstants from vocode.streaming.input_device.base_input_device import BaseInputDevice from vocode.streaming.models.adaptive_object import AdaptiveObject from vocode.streaming.models.client_backend import InputAudioConfig -from vocode.streaming.models.model import BaseModel from vocode.streaming.telephony.constants import ( DEFAULT_AUDIO_ENCODING, DEFAULT_CHUNK_SIZE, diff --git a/vocode/streaming/models/websocket.py b/vocode/streaming/models/websocket.py index 18e36fed03..a1611a8d4c 100644 --- a/vocode/streaming/models/websocket.py +++ b/vocode/streaming/models/websocket.py @@ -8,7 +8,6 @@ from .agent import AgentConfig from .events 
import Sender -from .model import TypedModel from .synthesizer import SynthesizerConfig from .transcriber import TranscriberConfig from .transcript import TranscriptEvent From 056cecdb7c6992a65a7e264287dd486d68308470 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Fri, 12 Jul 2024 16:39:46 -0700 Subject: [PATCH 8/8] adds tests for adaptiveobject --- .../streaming/models/test_adaptive_object.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 tests/streaming/models/test_adaptive_object.py diff --git a/tests/streaming/models/test_adaptive_object.py b/tests/streaming/models/test_adaptive_object.py new file mode 100644 index 0000000000..b1557ee5cf --- /dev/null +++ b/tests/streaming/models/test_adaptive_object.py @@ -0,0 +1,47 @@ +from abc import ABC +from typing import Any, Literal + +from vocode.streaming.models.adaptive_object import AdaptiveObject + + +class B(AdaptiveObject, ABC): + type: Any + + +class SubB1(B): + type: Literal["sub_b1"] = "sub_b1" + x: int + + +class SubB2(B): + type: Literal["sub_b2"] = "sub_b2" + y: int + + +class A(AdaptiveObject, ABC): + type: Any + b: B + + +class SubA1(A): + type: Literal["sub_a1"] = "sub_a1" + x: int + + +class SubA2(A): + type: Literal["sub_a2"] = "sub_a2" + y: int + + +def test_serialize(): + sub_a1 = SubA1(b=SubB1(x=2), x=1) + assert sub_a1.model_dump() == {"b": {"type": "sub_b1", "x": 2}, "type": "sub_a1", "x": 1} + + +def test_deserialize(): + d = {"b": {"type": "sub_b1", "x": 2}, "type": "sub_a1", "x": 1} + sub_a1 = A.model_validate(d) + assert isinstance(sub_a1, SubA1) + assert isinstance(sub_a1.b, SubB1) + assert sub_a1.b.x == 2 + assert sub_a1.x == 1