Skip to content

Commit

Permalink
Integrate Mem0ConversationManager for enhanced session management and update API key handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Luisotee committed Jan 11, 2025
1 parent 3ba2471 commit 70b1436
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 24 deletions.
30 changes: 23 additions & 7 deletions apps/ai_api/eda_ai_api/api/routes/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from eda_ai_api.models.classifier import ClassifierResponse, MessageHistory
from eda_ai_api.utils.audio_utils import process_audio_file
from eda_ai_api.utils.memory import SupabaseMemory
from eda_ai_api.utils.memory import SupabaseMemory, Mem0ConversationManager
from eda_ai_api.utils.prompts import (
ROUTER_TEMPLATE,
INSUFFICIENT_TEMPLATES,
Expand All @@ -19,11 +19,12 @@
)

router = APIRouter()
mem0_manager = Mem0ConversationManager()

# Setup LLM
llm = ChatGroq(
model="llama-3.3-70b-versatile",
    api_key="gsk_***REDACTED-LEAKED-KEY***",  # NOTE(review): secret was committed in plain text — rotate this key; redacted here
api_key=os.environ.get("GROQ_API_KEY"), # Use environment variable
temperature=0.5,
)

Expand Down Expand Up @@ -76,6 +77,11 @@ async def classifier_route(
current_session = session_id or str(uuid.uuid4())
logger.info(f"New request - Session: {current_session}")

# Initialize/get Mem0 session
current_session = await mem0_manager.get_or_create_session(
session_id=current_session
)

# Process inputs
combined_message = []
if audio:
Expand All @@ -87,23 +93,26 @@ async def classifier_route(
if not combined_message:
return ClassifierResponse(result="Error: No valid input provided")

# Get conversation history
# Get conversation history from both systems
history = []
if message_history:
try:
history = [MessageHistory(**msg) for msg in json.loads(message_history)]
except json.JSONDecodeError:
logger.warning("Invalid message_history JSON format")

# Format context and route message
context = SupabaseMemory.format_history(history)
# Get long-term context from Mem0
mem0_history = await mem0_manager.get_conversation_history(current_session)

# Combine contexts
supabase_context = SupabaseMemory.format_history(history)
final_message = "\n".join(combined_message)
decision = await route_message(final_message, context)
decision = await route_message(final_message, supabase_context)
logger.info(f"Routing decision: {decision}")

# Process based on route
if decision == "discovery":
response = await process_discovery(final_message, context)
response = await process_discovery(final_message, supabase_context)
elif decision == "heartbeat":
response = {"result": "*Yes, I'm here! 🟢*\n_Ready to help you!_"}
else:
Expand All @@ -113,6 +122,13 @@ async def classifier_route(
if len(result) > 2499:
result = result[:2499]

# Store interaction in Mem0
await mem0_manager.add_conversation(
session_id=current_session,
user_message=final_message,
assistant_response=result,
)

return ClassifierResponse(result=result, session_id=current_session)

except Exception as e:
Expand Down
36 changes: 19 additions & 17 deletions apps/ai_api/eda_ai_api/utils/memory.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,53 @@
import uuid
import os
from typing import Dict, List, Optional
from dotenv import load_dotenv
from eda_ai_api.models.classifier import MessageHistory
from loguru import logger
from mem0 import Memory

# Load environment variables
load_dotenv()


class Mem0ConversationManager:
def __init__(self):
# Validate API keys
groq_api_key = os.getenv("GROQ_API_KEY")
huggingface_api_key = os.getenv("HUGGINGFACE_API_KEY")

if not huggingface_api_key:
raise ValueError("HUGGINGFACE_API_KEY environment variable is not set")
if not groq_api_key:
raise ValueError("GROQ_API_KEY environment variable is not set")

config = {
"version": "v1.1", # Required version field
"version": "v1.1",
"graph_store": {
"provider": "neo4j",
"config": {
"url": "bolt://localhost:7687",
"username": "neo4j",
"password": "password",
},
"llm": { # Graph store specific LLM
"provider": "groq",
"config": {
"model": "llama-3.3-70b-versatile",
"api_key": os.environ.get("GROQ_API_KEY"),
"temperature": 0.0,
},
},
},
"embedder": {
"provider": "huggingface",
"config": {
"model": "sentence-transformers/all-mpnet-base-v2",
"api_key": os.environ.get("HUGGINGFACE_API_KEY"),
"api_key": huggingface_api_key,
},
},
"llm": { # Main LLM
"llm": {
"provider": "groq",
"config": {
"model": "llama-3.3-70b-versatile",
"api_key": os.environ.get("GROQ_API_KEY"),
"temperature": 0,
"max_tokens": 8000,
"api_key": groq_api_key,
"max_tokens": 1000,
},
},
}
self.memory = Memory.from_config(
config_dict=config
) # Changed to use config_dict parameter
self.memory = Memory.from_config(config_dict=config)

async def get_or_create_session(
self, user_id: Optional[str] = None, session_id: Optional[str] = None
Expand Down
1 change: 1 addition & 0 deletions apps/ai_api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"neo4j>=5.27.0",
"langchain>=0.3.4",
"langgraph>=0.2.35",
"anthropic>=0.42.0",
]

[project.optional-dependencies]
Expand Down
20 changes: 20 additions & 0 deletions apps/ai_api/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 70b1436

Please sign in to comment.