Commit
Showing 7 changed files with 307 additions and 248 deletions.
@@ -1,262 +1,120 @@
 import json
 import os
-from typing import Optional, List, Dict
 import uuid
+from typing import Any, Dict, Optional
 
 from fastapi import APIRouter, File, Form, UploadFile
-from llama_index.core import PromptTemplate
-from llama_index.llms.groq import Groq
+from langchain_groq import ChatGroq
+from langchain_core.messages import HumanMessage
 from loguru import logger
 import httpx
 
 from eda_ai_api.models.classifier import ClassifierResponse, MessageHistory
 from eda_ai_api.utils.audio_utils import process_audio_file
-from eda_ai_api.utils.memory import ZepConversationManager
+from eda_ai_api.utils.memory import SupabaseMemory
 from eda_ai_api.utils.prompts import (
     ROUTER_TEMPLATE,
     INSUFFICIENT_TEMPLATES,
     PROPOSAL_TEMPLATE,
     RESPONSE_PROCESSOR_TEMPLATE,
     TOPIC_TEMPLATE,
 )
 
 router = APIRouter()
 
 # Setup LLM
-llm = Groq(
+llm = ChatGroq(
     model="llama-3.3-70b-versatile",
-    api_key=os.environ.get("GROQ_API_KEY"),
+    api_key="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",
     temperature=0.5,
 )
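The setup above swaps llama_index's `Groq` completion client for `langchain_groq.ChatGroq`, which changes the call convention everywhere below: `llm.complete(prompt).text` becomes `llm.invoke([...]).content`. A minimal sketch of the new pattern, assuming `GROQ_API_KEY` is read from the environment (as the removed setup did) rather than hardcoded:

```python
import os

from langchain_core.messages import HumanMessage
from langchain_groq import ChatGroq

# Same model and temperature as the diff; reading the key from the
# environment is an assumption here, not what the commit does.
llm = ChatGroq(
    model="llama-3.3-70b-versatile",
    api_key=os.environ.get("GROQ_API_KEY"),
    temperature=0.5,
)

# ChatGroq takes a list of chat messages and returns an AIMessage,
# so callers read `.content` where llama_index exposed `.text`.
reply = llm.invoke([HumanMessage(content="Return one word: discovery or proposal")])
print(reply.content)
```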
 
 
-async def extract_topics(message: str, history: list) -> list[str]:
-    """Extract topics with conversation context"""
-    context = "\n".join([f"{msg['role']}: {msg['content']}" for msg in history])
-    response = llm.complete(TOPIC_TEMPLATE.format(context=context, message=message))
-    if response.text.strip() == "INSUFFICIENT_CONTEXT":
-        return ["INSUFFICIENT_CONTEXT"]
-    topics = [t.strip() for t in response.text.split(",") if t.strip()][:5]
-    return topics if topics else ["INSUFFICIENT_CONTEXT"]
-
-
-async def extract_proposal_details(message: str, history: list) -> tuple[str, str]:
-    """Extract project and grant details with conversation context"""
-    context = "\n".join([f"{msg['role']}: {msg['content']}" for msg in history])
-    response = llm.complete(PROPOSAL_TEMPLATE.format(context=context, message=message))
-    extracted = response.text.split("|")
-    community_project = extracted[0].strip() if len(extracted) > 0 else "unknown"
-    grant_call = extracted[1].strip() if len(extracted) > 1 else "unknown"
-    return community_project, grant_call
-
-
-async def process_llm_response(message: str, response: str) -> str:
-    processed = llm.complete(
-        RESPONSE_PROCESSOR_TEMPLATE.format(original_message=message, response=response)
-    )
-    logger.info(f"Processed response: {processed.text}")
-    return processed.text
-
-
-async def process_decision(
-    decision: str,
-    message: str,
-    zep_history: list,
-    supabase_history: list[MessageHistory] = [],
-) -> Dict[str, Any]:
-    """Process routing decision with conversation context from both sources"""
-    logger.info(f"Processing decision: {decision} for message: {message}")
-
-    # Combine histories for context
-    context_parts = []
-
-    if supabase_history:
-        supabase_context = format_supabase_history(supabase_history)
-        context_parts.append(f"Recent conversation:\n{supabase_context}")
-
-    if zep_history:
-        zep_context = "\n".join(
-            [f"{msg['role']}: {msg['content']}" for msg in zep_history]
-        )
-        context_parts.append(f"Long-term memory:\n{zep_context}")
-
-    combined_context = "\n\n".join(context_parts)
-
-    async with httpx.AsyncClient() as client:
-        if decision == "discovery":
-            topics = await extract_topics(message, zep_history)
-            logger.info(f"Extracted topics: {topics}")
-
-            if topics == ["INSUFFICIENT_CONTEXT"]:
-                response = llm.complete(
-                    INSUFFICIENT_TEMPLATES["discovery"].format(
-                        context=combined_context, message=message
-                    )
-                )
-                processed_response = await process_llm_response(message, response.text)
-                return {"response": processed_response}
-
-            # Call discovery API instead of crew directly
-            api_response = await client.post(
-                "http://127.0.0.1:8083/api/grant/discovery", json={"topics": topics}
-            )
-            result = api_response.json()
-            processed_response = await process_llm_response(message, str(result))
-            return {"response": processed_response}
-
-        elif decision == "proposal":
-            community_project, grant_call = await extract_proposal_details(
-                message, zep_history
-            )
-            logger.info(f"Project: {community_project}, Grant: {grant_call}")
-
-            if community_project == "unknown" or grant_call == "unknown":
-                response = llm.complete(
-                    INSUFFICIENT_TEMPLATES["proposal"].format(
-                        context=combined_context, message=message
-                    )
-                )
-                processed_response = await process_llm_response(message, response.text)
-                return {"response": processed_response}
-
-            # Call proposal API instead of crew directly
-            api_response = await client.post(
-                "http://127.0.0.1:8083/api/grant/proposal",
-                json={"project": community_project, "grant": grant_call},
-            )
-            result = api_response.json()
-            processed_response = await process_llm_response(message, str(result))
-            return {"response": processed_response}
-
-        elif decision == "heartbeat":
-            processed_response = await process_llm_response(
-                message, str({"is_alive": True})
-            )
-            return {"response": processed_response}
-
-        elif decision == "onboarding":
-            # Use existing guide endpoint instead of creating new one
-            api_response = await client.get(
-                "http://127.0.0.1:8083/api/onboarding/guide"
-            )
-            result = api_response.json()
-            processed_response = await process_llm_response(message, str(result))
-            return {"response": processed_response}
-
-        else:
-            return {"error": f"Unknown decision type: {decision}"}
-
-
-def format_supabase_history(history: list[MessageHistory]) -> str:
-    """Format last 10 Supabase messages into conversation format"""
-    if not history:
-        return ""
-
-    # Get last 10 messages
-    limited_history = history[-10:]
-
-    formatted = []
-    for msg in limited_history:
-        formatted.extend([f"human: {msg.human}", f"assistant: {msg.ai}"])
-
-    return "\n".join(formatted[-10:])  # Take last 10 messages total
+async def process_discovery(message: str, context: str) -> Dict[str, str]:
+    """Process discovery requests"""
+    response = llm.invoke(
+        [HumanMessage(content=TOPIC_TEMPLATE.format(context=context, message=message))]
+    )
+    topics = [t.strip() for t in response.content.split(",") if t.strip()][:5]
+    logger.info(f"Extracted Topics: {topics}")
+
+    if not topics or topics == ["INSUFFICIENT_CONTEXT"]:
+        return {
+            "result": INSUFFICIENT_TEMPLATES["discovery"].format(
+                context=context, message=message
+            )
+        }
+
+    async with httpx.AsyncClient() as client:
+        logger.info("Calling discovery API...")
+        api_response = await client.post(
+            "http://127.0.0.1:8083/api/grant/discovery", json={"topics": topics}
+        )
+        logger.info(f"Discovery API Response: {api_response.json()}")
+        return {"result": str(api_response.json())}
+
+
+async def route_message(message: str, context: str) -> str:
+    """Route message to appropriate handler"""
+    logger.info(f"\n=== Routing Message ===\nContext: {context}\nMessage: {message}")
+
+    response = llm.invoke(
+        [HumanMessage(content=ROUTER_TEMPLATE.format(message=message))]
+    )
+    decision = response.content.lower().strip()
+    logger.info(f"Router Decision: {decision}")
+    return decision
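Taken together, `route_message` picks a one-word service and the caller dispatches on it. A small driver sketch, runnable only alongside this module with the grant API from the diff listening on port 8083; the sample message and context strings are hypothetical:

```python
import asyncio

async def demo() -> None:
    # Hypothetical formatted history and user message, for illustration only.
    context = "human: hi\nassistant: hello, how can I help?"
    message = "Find grants for a community garden project"

    decision = await route_message(message, context)  # e.g. "discovery"
    if decision == "discovery":
        result = await process_discovery(message, context)
        print(result["result"])

asyncio.run(demo())
```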
 
 
 @router.post("/classify", response_model=ClassifierResponse)
 async def classifier_route(
     message: Optional[str] = Form(default=None),
     audio: Optional[UploadFile] = File(default=None),
     session_id: Optional[str] = Form(default=None),
-    message_history: Optional[str] = Form(default=None),  # JSON string
+    message_history: Optional[str] = Form(default=None),
 ) -> ClassifierResponse:
-    """Main route handler with conversation memory"""
+    """Main route handler"""
     try:
-        # Generate a default session_id if none provided
-        current_session_id = session_id or str(uuid.uuid4())
-        user_id = f"{current_session_id}_{uuid.uuid4().hex}"
-
-        logger.info(f"New request - Session: {current_session_id}, User: {user_id}")
-
-        # Initialize both history sources
-        zep = ZepConversationManager()
-        zep_session_id = await zep.get_or_create_session(
-            user_id=user_id, session_id=current_session_id
-        )
+        current_session = session_id or str(uuid.uuid4())
+        logger.info(f"New request - Session: {current_session}")
 
-        # Process inputs
-        combined_parts = []
-
-        if audio is not None:
+        combined_message = []
+        if audio:
             transcription = await process_audio_file(audio)
             logger.info(f"Audio transcription: {transcription}")
-            combined_parts.append(f'Transcription: "{transcription}"')
-
+            combined_message.append(f'Transcription: "{transcription}"')
         if message:
             logger.info(f"Text message: {message}")
-            combined_parts.append(f'Message: "{message}"')
+            combined_message.append(f'Message: "{message}"')
 
-        combined_message = "\n".join(combined_parts)
         if not combined_message:
             return ClassifierResponse(result="Error: No valid input provided")
 
-        # Get both conversation histories
-        zep_history = await zep.get_conversation_history(zep_session_id)
-        supabase_history = []
+        # Get conversation history
+        history = []
         if message_history:
             try:
-                supabase_history = [
-                    MessageHistory(**msg) for msg in json.loads(message_history)
-                ]
+                history = [MessageHistory(**msg) for msg in json.loads(message_history)]
             except json.JSONDecodeError:
                 logger.warning("Invalid message_history JSON format")
 
-        # Combine both histories for context
-        zep_context = "\n".join(
-            [f"{msg['role']}: {msg['content']}" for msg in zep_history]
-        )
-        supabase_context = format_supabase_history(supabase_history)
-
-        combined_context = f"""Recent conversation:\n{supabase_context}\n\nLong-term memory:\n{zep_context}"""
-
-        logger.info(f"Combined context:\n{combined_context}")
-
-        # Use combined context in router prompt
-        router_prompt = PromptTemplate(
-            """Previous conversation:\n{context}\n\n"""
-            """Given the current user message, determine the appropriate service:"""
-            """\n{message}\n\n"""
-            """Return only one word (discovery/proposal/onboarding/heartbeat):"""
-        )
-
-        response = llm.complete(
-            router_prompt.format(context=combined_context, message=combined_message)
-        )
-        decision = response.text.strip().lower()
-
-        # Process decision using combined context
-        result = await process_decision(
-            decision, combined_message, zep_history, supabase_history
-        )
-
-        # Process final result if it's not already processed
-        if isinstance(result.get("response"), str):
-            final_result = await process_llm_response(combined_message, str(result))
-        else:
-            final_result = str(result)
-
-        # Truncate if result exceeds character limit
-        if len(final_result) > 2499:
-            logger.warning(
-                f"Result exceeded 2499 characters (was {len(final_result)}). Truncating..."
-            )
-            final_result = final_result[:2499]
-
-        # Log both result and character count
-        logger.info(f"Final result ({len(final_result)} chars): {final_result}")
-
-        await zep.add_conversation(zep_session_id, combined_message, final_result)
-        return ClassifierResponse(result=final_result, session_id=zep_session_id)
+        # Format context and route message
+        context = SupabaseMemory.format_history(history)
+        final_message = "\n".join(combined_message)
+        decision = await route_message(final_message, context)
+        logger.info(f"Routing decision: {decision}")
+
+        # Process based on route
+        if decision == "discovery":
+            response = await process_discovery(final_message, context)
+        elif decision == "heartbeat":
+            response = {"result": "*Yes, I'm here! 🟢*\n_Ready to help you!_"}
+        else:
+            response = {"result": f"Service '{decision}' not implemented yet"}
+
+        result = response["result"]
+        if len(result) > 2499:
+            result = result[:2499]
+
+        return ClassifierResponse(result=result, session_id=current_session)
 
     except Exception as e:
         logger.error(f"Error in classifier route: {str(e)}")
-        error_msg = await process_llm_response(combined_message, f"Error: {str(e)}")
-        return ClassifierResponse(result=error_msg)
+        return ClassifierResponse(result=f"Error: {str(e)}")
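For reference, a hedged client-side sketch of exercising the endpoint. The `/classify` path and the form fields come from the handler above; the host, port, and any router mount prefix are assumptions, and the `message_history` entries use the `human`/`ai` fields that `MessageHistory` carries per the removed formatter:

```python
import json

import httpx

# Host/port and mount prefix are assumptions; adjust to your deployment.
payload = {
    "message": "Are you alive?",
    "session_id": "demo-session",
    "message_history": json.dumps([{"human": "hi", "ai": "hello"}]),
}

resp = httpx.post("http://127.0.0.1:8080/classify", data=payload)
print(resp.json())  # expected shape: {"result": "...", "session_id": "demo-session"}
```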