Merge pull request #19 from digidem/luisotee/audio-transcription
Audio transcription in supervisor route #12
luandro authored Dec 2, 2024
2 parents d2651e5 + 6214aa3 commit c540a59
Showing 9 changed files with 369 additions and 146 deletions.
223 changes: 223 additions & 0 deletions apps/ai_api/eda_ai_api/api/routes/classifier.py
@@ -0,0 +1,223 @@
import os
import tempfile
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, File, Form, UploadFile
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from onboarding.crew import OnboardingCrew
from opportunity_finder.crew import OpportunityFinderCrew
from proposal_writer.crew import ProposalWriterCrew

from eda_ai_api.models.classifier import ClassifierResponse
from eda_ai_api.utils.audio_converter import convert_ogg
from eda_ai_api.utils.transcriber import transcribe_audio

router = APIRouter()

ALLOWED_FORMATS = {
"audio/mpeg": "mp3",
"audio/mp4": "mp4",
"audio/mpga": "mpga",
"audio/wav": "wav",
"audio/webm": "webm",
"audio/ogg": "ogg",
}

# Setup LLM and prompt
llm = ChatGroq(
model_name="llama3-groq-70b-8192-tool-use-preview",
api_key=os.environ.get("GROQ_API_KEY"),
temperature=0.5,
)

ROUTER_TEMPLATE = """
Given a user message, determine the appropriate service to handle the request.
Choose between:
- discovery: For finding grant opportunities
- proposal: For writing grant proposals
- onboarding: For getting help using the system
- heartbeat: For checking system health
User message: {message}
Return only one word (discovery/proposal/onboarding/heartbeat):"""

TOPIC_EXTRACTOR_TEMPLATE = """
Extract up to 5 most relevant topics for grant opportunity research from the user message.
Return only a comma-separated list of topics (maximum 5), no other text.
User message: {message}
Topics:"""

PROPOSAL_EXTRACTOR_TEMPLATE = """
Extract the community project name and grant program name from the user message.
Return in format: project_name|grant_name
If either cannot be determined, use "unknown" as placeholder.
User message: {message}
Output:"""

# Create prompt templates and chains
router_prompt = PromptTemplate(input_variables=["message"], template=ROUTER_TEMPLATE)
topic_prompt = PromptTemplate(
input_variables=["message"], template=TOPIC_EXTRACTOR_TEMPLATE
)
proposal_prompt = PromptTemplate(
input_variables=["message"], template=PROPOSAL_EXTRACTOR_TEMPLATE
)

router_chain = LLMChain(llm=llm, prompt=router_prompt)
topic_chain = LLMChain(llm=llm, prompt=topic_prompt)
proposal_chain = LLMChain(llm=llm, prompt=proposal_prompt)


def detect_content_type(file: UploadFile) -> Optional[str]:
"""Helper to detect content type from file"""
if hasattr(file, "content_type") and file.content_type:
return file.content_type

if hasattr(file, "mime_type") and file.mime_type:
return file.mime_type

ext = os.path.splitext(file.filename)[1].lower()
return {
".mp3": "audio/mpeg",
".mp4": "audio/mp4",
".mpeg": "audio/mpeg",
".mpga": "audio/mpga",
".m4a": "audio/mp4",
".wav": "audio/wav",
".webm": "audio/webm",
".ogg": "audio/ogg",
}.get(ext)


async def process_audio(audio: UploadFile) -> str:
"""Process audio file and return transcription"""
content_type = detect_content_type(audio)
content = await audio.read()
audio_path = ""

try:
if not content_type:
content_type = "audio/mpeg"

if content_type == "audio/ogg":
audio_path = convert_ogg(content, output_format="mp3")
else:
with tempfile.NamedTemporaryFile(
suffix=f".{ALLOWED_FORMATS.get(content_type, 'mp3')}", delete=False
) as temp_file:
temp_file.write(content)
audio_path = temp_file.name

return transcribe_audio(audio_path)
finally:
if os.path.exists(audio_path):
os.unlink(audio_path)


def extract_topics(message: str) -> List[str]:
"""Extract topics from message"""
topics_raw = topic_chain.run(message=message)
topics = [t.strip() for t in topics_raw.split(",") if t.strip()][:5]
return topics if topics else ["AI", "Technology"]


def extract_proposal_details(message: str) -> tuple[str, str]:
"""Extract project and grant details"""
extracted = proposal_chain.run(message=message).split("|")
community_project = extracted[0].strip() if len(extracted) > 0 else "unknown"
grant_call = extracted[1].strip() if len(extracted) > 1 else "unknown"
return community_project, grant_call


def process_decision(decision: str, message: str) -> Dict[str, Any]:
"""Process routing decision and return result"""
print("\n==================================================")
print(f" DECISION: {decision}")
print("==================================================\n")

if decision == "discovery":
topics = extract_topics(message)
print("\n==================================================")
print(f" EXTRACTED TOPICS: {topics}")
print("==================================================\n")
return (
OpportunityFinderCrew().crew().kickoff(inputs={"topics": ", ".join(topics)})
)
elif decision == "proposal":
community_project, grant_call = extract_proposal_details(message)
print(f" PROJECT NAME: {community_project}")
print(f" GRANT PROGRAM: {grant_call}")
return (
ProposalWriterCrew(
community_project=community_project, grant_call=grant_call
)
.crew()
.kickoff()
)
elif decision == "heartbeat":
return {"is_alive": True}
elif decision == "onboarding":
return OnboardingCrew().crew().kickoff()
else:
return {"error": f"Unknown decision type: {decision}"}


@router.post("/classify", response_model=ClassifierResponse)
async def classifier_route(
message: Optional[str] = Form(default=None),
audio: Optional[UploadFile] = File(default=None),
) -> ClassifierResponse:
"""Main route handler for classifier API"""
try:
combined_parts = []
has_valid_audio = False

# Process audio if provided
if audio is not None:
# Check if audio is not empty
audio_content = await audio.read()
has_valid_audio = len(audio_content) > 0

if has_valid_audio:
await audio.seek(0)
transcription = await process_audio(audio)
print("==================================================")
print(f" TRANSCRIPTION: {transcription}")
print("==================================================")
combined_parts.append(
f'Transcription of attached audio: "{transcription}"'
)

# Add message if provided
if message:
combined_parts.append(f'Message: "{message}"')

# Combine parts with newlines
combined_message = "\n".join(combined_parts)

if combined_message:
print("==================================================")
print(f" COMBINED MESSAGE:\n{combined_message}")
print("==================================================")

# Ensure we have some input to process
if not combined_message:
return ClassifierResponse(
result="Error: Neither valid message nor valid audio provided"
)

# Process the combined input
decision = router_chain.run(message=combined_message).strip().lower()
result = process_decision(decision, combined_message)

return ClassifierResponse(result=str(result))

except Exception as e:
return ClassifierResponse(result=f"Error processing request: {str(e)}")
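
For reference, a minimal client sketch for the new route (not part of this diff; it assumes the service is running locally on port 8000 and that api_router is mounted at the application root, so the full path is /classifier/classify — adjust the base URL and file name to your setup):

# Hypothetical client example (not in this commit): send an optional text
# message and an optional audio file to the classifier endpoint.
import requests

with open("voice_note.ogg", "rb") as f:  # "voice_note.ogg" is a placeholder file
    response = requests.post(
        "http://localhost:8000/classifier/classify",
        data={"message": "Help me find grants for a community radio project"},
        files={"audio": ("voice_note.ogg", f, "audio/ogg")},
    )
print(response.json())  # e.g. {"result": "..."}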
4 changes: 2 additions & 2 deletions apps/ai_api/eda_ai_api/api/routes/router.py
@@ -1,9 +1,9 @@
 from fastapi import APIRouter
 
-from eda_ai_api.api.routes import grant, heartbeat, onboarding, supervisor
+from eda_ai_api.api.routes import classifier, grant, heartbeat, onboarding
 
 api_router = APIRouter()
 api_router.include_router(heartbeat.router, tags=["health"], prefix="/health")
 api_router.include_router(grant.router, tags=["discovery"], prefix="/grant")
-api_router.include_router(supervisor.router, tags=["supervisor"], prefix="/supervisor")
+api_router.include_router(classifier.router, tags=["classifier"], prefix="/classifier")
 api_router.include_router(onboarding.router, tags=["onboarding"], prefix="/onboarding")
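
As context for this router change, a minimal sketch of how api_router might be mounted in the FastAPI application (hypothetical entrypoint; the actual main module is not shown in this diff):

# Hypothetical application entrypoint (not included in this commit).
from fastapi import FastAPI

from eda_ai_api.api.routes.router import api_router

app = FastAPI()
app.include_router(api_router)  # exposes /classifier/classify alongside the other routes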
134 changes: 0 additions & 134 deletions apps/ai_api/eda_ai_api/api/routes/supervisor.py

This file was deleted.

16 changes: 16 additions & 0 deletions apps/ai_api/eda_ai_api/models/classifier.py
@@ -0,0 +1,16 @@
from typing import Optional

from fastapi import UploadFile
from pydantic import BaseModel


class ClassifierRequest(BaseModel):
message: Optional[str] = None
audio: Optional[UploadFile] = None

class Config:
arbitrary_types_allowed = True


class ClassifierResponse(BaseModel):
result: str
9 changes: 0 additions & 9 deletions apps/ai_api/eda_ai_api/models/supervisor.py

This file was deleted.
