-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from digidem/luisotee/audio-transcription
Audio transcription in supervisor route #12
- Loading branch information
Showing
9 changed files
with
369 additions
and
146 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
import os | ||
import tempfile | ||
from typing import Any, Dict, List, Optional | ||
|
||
from fastapi import APIRouter, File, Form, UploadFile | ||
from langchain.chains import LLMChain | ||
from langchain.prompts import PromptTemplate | ||
from langchain_groq import ChatGroq | ||
from onboarding.crew import OnboardingCrew | ||
from opportunity_finder.crew import OpportunityFinderCrew | ||
from proposal_writer.crew import ProposalWriterCrew | ||
|
||
from eda_ai_api.models.classifier import ClassifierResponse | ||
from eda_ai_api.utils.audio_converter import convert_ogg | ||
from eda_ai_api.utils.transcriber import transcribe_audio | ||
|
||
# Router that the classifier endpoint below registers itself on.
router = APIRouter()

# Maps an audio MIME type to the file extension used as the suffix of the
# temporary file that process_audio writes before transcription. Types not
# listed here fall back to "mp3" at the lookup site.
ALLOWED_FORMATS = {
    "audio/mpeg": "mp3",
    "audio/mp4": "mp4",
    "audio/mpga": "mpga",
    "audio/wav": "wav",
    "audio/webm": "webm",
    "audio/ogg": "ogg",
}
|
||
# Setup LLM and prompt
# Single Groq chat model shared by all chains in this module. The API key is
# read from the GROQ_API_KEY environment variable when the module is imported;
# os.environ.get yields None if the variable is unset.
llm = ChatGroq(
    model_name="llama3-groq-70b-8192-tool-use-preview",
    api_key=os.environ.get("GROQ_API_KEY"),
    temperature=0.5,
)
|
||
# Routing prompt: asks the model to answer with exactly one service name
# (discovery / proposal / onboarding / heartbeat) for the given message.
ROUTER_TEMPLATE = """
Given a user message, determine the appropriate service to handle the request.
Choose between:
- discovery: For finding grant opportunities
- proposal: For writing grant proposals
- onboarding: For getting help using the system
- heartbeat: For checking system health
User message: {message}
Return only one word (discovery/proposal/onboarding/heartbeat):"""

# Topic prompt: asks for at most five comma-separated research topics.
TOPIC_EXTRACTOR_TEMPLATE = """
Extract up to 5 most relevant topics for grant opportunity research from the user message.
Return only a comma-separated list of topics (maximum 5), no other text.
User message: {message}
Topics:"""

# Proposal prompt: asks for "project_name|grant_name", with "unknown" standing
# in for whichever part cannot be determined from the message.
PROPOSAL_EXTRACTOR_TEMPLATE = """
Extract the community project name and grant program name from the user message.
Return in format: project_name|grant_name
If either cannot be determined, use "unknown" as placeholder.
User message: {message}
Output:"""
|
||
# Create prompt templates and chains
# Every template takes a single "message" input variable.
router_prompt = PromptTemplate(input_variables=["message"], template=ROUTER_TEMPLATE)
topic_prompt = PromptTemplate(
    input_variables=["message"], template=TOPIC_EXTRACTOR_TEMPLATE
)
proposal_prompt = PromptTemplate(
    input_variables=["message"], template=PROPOSAL_EXTRACTOR_TEMPLATE
)

# One chain per prompt, all backed by the shared module-level `llm`.
router_chain = LLMChain(llm=llm, prompt=router_prompt)
topic_chain = LLMChain(llm=llm, prompt=topic_prompt)
proposal_chain = LLMChain(llm=llm, prompt=proposal_prompt)
|
||
|
||
def detect_content_type(file: UploadFile) -> Optional[str]:
    """Return the bare MIME type of an uploaded file, or None if unknown.

    Resolution order:
      1. ``file.content_type``, if present and non-empty;
      2. ``file.mime_type``, if present and non-empty;
      3. a lookup on the (lower-cased) filename extension.

    A declared type is normalised before being returned: parameters after
    ``;`` are dropped and the value is lower-cased, so
    ``"audio/ogg; codecs=opus"`` is reported as ``"audio/ogg"`` and can be
    compared against plain MIME-type strings.

    Args:
        file: The upload to inspect. Only the ``content_type``, ``mime_type``
            and ``filename`` attributes are read, and each may be missing or
            ``None``.

    Returns:
        The MIME type string, or ``None`` when no type is declared and the
        extension is missing or not recognised.
    """
    for attr in ("content_type", "mime_type"):
        declared = getattr(file, attr, None)
        if declared:
            normalized = declared.split(";", 1)[0].strip().lower()
            if normalized:
                return normalized

    # An upload may carry no filename at all; treat that as "no extension"
    # instead of letting os.path.splitext raise on None.
    filename = getattr(file, "filename", None) or ""
    ext = os.path.splitext(filename)[1].lower()
    return {
        ".mp3": "audio/mpeg",
        ".mp4": "audio/mp4",
        ".mpeg": "audio/mpeg",
        ".mpga": "audio/mpga",
        ".m4a": "audio/mp4",
        ".wav": "audio/wav",
        ".webm": "audio/webm",
        ".ogg": "audio/ogg",
    }.get(ext)
|
||
|
||
async def process_audio(audio: UploadFile) -> str:
    """Transcribe an uploaded audio file and return the transcription.

    The upload is read fully into memory and materialised on disk before
    being handed to ``transcribe_audio``:

    * ``audio/ogg`` content goes through ``convert_ogg`` (requested output
      format: mp3), whose return value is used as the path to transcribe;
    * everything else is written unchanged to a temporary file whose suffix
      comes from ``ALLOWED_FORMATS`` (``.mp3`` for unlisted types).

    When the content type cannot be detected it is assumed to be
    ``audio/mpeg``. The on-disk file is removed afterwards whether or not
    transcription succeeds.

    Args:
        audio: The uploaded audio file; it is read from its current position.

    Returns:
        Whatever ``transcribe_audio`` returns for the file.
    """
    content_type = detect_content_type(audio) or "audio/mpeg"
    content = await audio.read()
    audio_path = ""

    try:
        if content_type == "audio/ogg":
            audio_path = convert_ogg(content, output_format="mp3")
        else:
            suffix = f".{ALLOWED_FORMATS.get(content_type, 'mp3')}"
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
                # Record the path before writing: the file is created with
                # delete=False, so if the write fails the cleanup below is the
                # only thing that removes it.
                audio_path = temp_file.name
                temp_file.write(content)

        return transcribe_audio(audio_path)
    finally:
        if audio_path and os.path.exists(audio_path):
            os.unlink(audio_path)
|
||
|
||
def extract_topics(message: str) -> List[str]:
    """Ask the topic chain for research topics and return at most five.

    The chain output is split on commas; surrounding whitespace is trimmed
    and empty entries are discarded. If nothing usable remains, the fixed
    fallback ``["AI", "Technology"]`` is returned.
    """
    raw_output = topic_chain.run(message=message)

    cleaned: List[str] = []
    for piece in raw_output.split(","):
        candidate = piece.strip()
        if candidate:
            cleaned.append(candidate)

    top_five = cleaned[:5]
    if top_five:
        return top_five
    return ["AI", "Technology"]
|
||
|
||
def extract_proposal_details(message: str) -> tuple[str, str]:
    """Extract the community project name and grant program name.

    The proposal chain is expected to answer ``project_name|grant_name``.
    Only the first two ``|``-separated fields are used, each trimmed of
    surrounding whitespace.

    Args:
        message: The user message to extract details from.

    Returns:
        ``(community_project, grant_call)``. A field that is missing or
        empty after trimming is reported as ``"unknown"``, matching the
        placeholder the prompt itself asks the model to use.
    """
    fields = [field.strip() for field in proposal_chain.run(message=message).split("|")]

    # str.split always yields at least one element, so fields[0] exists; it
    # may still be empty (e.g. the model answered "" or "|grant").
    community_project = fields[0] or "unknown"
    grant_call = (fields[1] if len(fields) > 1 else "") or "unknown"
    return community_project, grant_call
|
||
|
||
def process_decision(decision: str, message: str) -> Dict[str, Any]:
    """Run the service selected by the router and return its result.

    The decision is matched after trimming surrounding whitespace, quotes and
    punctuation and lower-casing it, so model replies such as ``"Discovery."``
    or ``'"heartbeat"'`` still select the intended service.

    Args:
        decision: The router's answer; expected to name one of
            ``discovery``, ``proposal``, ``onboarding`` or ``heartbeat``.
        message: The user input the decision was made for. It feeds topic
            extraction (discovery) and project/grant extraction (proposal).

    Returns:
        * ``discovery`` / ``proposal`` / ``onboarding``: whatever the
          corresponding crew's ``kickoff()`` returns.
        * ``heartbeat``: ``{"is_alive": True}``.
        * anything else: ``{"error": "Unknown decision type: <decision>"}``,
          quoting the decision exactly as it was received.
    """
    print("\n==================================================")
    print(f" DECISION: {decision}")
    print("==================================================\n")

    # Tolerate decoration the model may add around the single expected word.
    route = decision.strip(" \t\r\n\"'`*.,:;!").lower()

    if route == "discovery":
        topics = extract_topics(message)
        print("\n==================================================")
        print(f" EXTRACTED TOPICS: {topics}")
        print("==================================================\n")
        return (
            OpportunityFinderCrew().crew().kickoff(inputs={"topics": ", ".join(topics)})
        )

    if route == "proposal":
        community_project, grant_call = extract_proposal_details(message)
        print(f" PROJECT NAME: {community_project}")
        print(f" GRANT PROGRAM: {grant_call}")
        return (
            ProposalWriterCrew(
                community_project=community_project, grant_call=grant_call
            )
            .crew()
            .kickoff()
        )

    if route == "heartbeat":
        return {"is_alive": True}

    if route == "onboarding":
        return OnboardingCrew().crew().kickoff()

    return {"error": f"Unknown decision type: {decision}"}
|
||
|
||
@router.post("/classify", response_model=ClassifierResponse)
async def classifier_route(
    message: Optional[str] = Form(default=None),
    audio: Optional[UploadFile] = File(default=None),
) -> ClassifierResponse:
    """Classify a text and/or audio request and run the matching service.

    Steps:
      1. If a non-empty audio file is attached, transcribe it.
      2. Combine the transcription and the text message (when present) into
         one prompt, one part per line.
      3. Ask the router chain which service should handle it and run that
         service through ``process_decision``.

    Args:
        message: Optional text from the form body. A blank or
            whitespace-only message is treated as not provided.
        audio: Optional uploaded audio file. An empty upload is ignored.

    Returns:
        A ``ClassifierResponse`` whose ``result`` is the stringified service
        result. Failures are not raised: a missing input or any exception is
        reported as an ``"Error ..."`` string in ``result``.
    """
    try:
        combined_parts = []

        # Process audio if provided
        if audio is not None:
            # Read once to detect an empty upload, then rewind because
            # process_audio reads the file again from its current position.
            audio_content = await audio.read()
            if audio_content:
                await audio.seek(0)
                transcription = await process_audio(audio)
                print("==================================================")
                print(f" TRANSCRIPTION: {transcription}")
                print("==================================================")
                combined_parts.append(
                    f'Transcription of attached audio: "{transcription}"'
                )

        # Add message if provided; whitespace alone carries no request and
        # must not trigger an LLM call on its own.
        if message and message.strip():
            combined_parts.append(f'Message: "{message}"')

        # Combine parts with newlines
        combined_message = "\n".join(combined_parts)

        # Ensure we have some input to process
        if not combined_message:
            return ClassifierResponse(
                result="Error: Neither valid message nor valid audio provided"
            )

        print("==================================================")
        print(f" COMBINED MESSAGE:\n{combined_message}")
        print("==================================================")

        # Process the combined input
        decision = router_chain.run(message=combined_message).strip().lower()
        result = process_decision(decision, combined_message)

        return ClassifierResponse(result=str(result))

    except Exception as e:
        # Top-level boundary: report the failure in the response body rather
        # than letting the exception propagate.
        return ClassifierResponse(result=f"Error processing request: {str(e)}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,9 @@ | ||
from fastapi import APIRouter | ||
|
||
from eda_ai_api.api.routes import grant, heartbeat, onboarding, supervisor | ||
from eda_ai_api.api.routes import classifier, grant, heartbeat, onboarding | ||
|
||
api_router = APIRouter() | ||
api_router.include_router(heartbeat.router, tags=["health"], prefix="/health") | ||
api_router.include_router(grant.router, tags=["discovery"], prefix="/grant") | ||
api_router.include_router(supervisor.router, tags=["supervisor"], prefix="/supervisor") | ||
api_router.include_router(classifier.router, tags=["classifier"], prefix="/classifier") | ||
api_router.include_router(onboarding.router, tags=["onboarding"], prefix="/onboarding") |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from typing import Optional | ||
|
||
from fastapi import UploadFile | ||
from pydantic import BaseModel | ||
|
||
|
||
# Request payload for the classifier: an optional text message and/or an
# optional audio upload. Both fields default to None.
class ClassifierRequest(BaseModel):
    # Free-text user message, if any.
    message: Optional[str] = None
    # Uploaded audio file, if any.
    audio: Optional[UploadFile] = None

    class Config:
        # Presumably required so the UploadFile-typed field is accepted by
        # the model -- TODO confirm against the pydantic version in use.
        arbitrary_types_allowed = True
|
||
|
||
# Response envelope for the classifier endpoint.
class ClassifierResponse(BaseModel):
    # Stringified service result, or an "Error ..." message on failure.
    result: str
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.