Support async client (#1980)
Dev-Khant authored Oct 22, 2024
1 parent c5d298e commit fbf1d8c
Showing 11 changed files with 214 additions and 59 deletions.
17 changes: 17 additions & 0 deletions docs/platform/quickstart.mdx
@@ -38,6 +38,23 @@ const client = new MemoryClient('your-api-key');

</CodeGroup>

### 3.1 Instantiate Async Client (Python only)

For asynchronous operations in Python, you can use the `AsyncMemoryClient`:

```python Python
import asyncio

from mem0 import AsyncMemoryClient

client = AsyncMemoryClient(api_key="your-api-key")


async def main():
    response = await client.add("I'm travelling to SF", user_id="john")
    print(response)


asyncio.run(main())
```
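
The other memory operations are awaited the same way; a minimal sketch, assuming the client above (the query and `user_id` are illustrative):

```python Python
import asyncio

from mem0 import AsyncMemoryClient

client = AsyncMemoryClient(api_key="your-api-key")


async def main():
    # Store a memory, then search and list memories for the same user.
    await client.add("I'm travelling to SF", user_id="john")
    print(await client.search("Where is John travelling?", user_id="john"))
    print(await client.get_all(user_id="john"))


asyncio.run(main())
```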

## 4. Memory Operations

Mem0 provides a simple and customizable interface for performing CRUD operations on memory.
2 changes: 1 addition & 1 deletion mem0/__init__.py
@@ -2,5 +2,5 @@

__version__ = importlib.metadata.version("mem0ai")

-from mem0.client.main import MemoryClient # noqa
+from mem0.client.main import MemoryClient, AsyncMemoryClient # noqa
from mem0.memory.main import Memory # noqa
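
With this re-export in place, both clients can be imported from the package root; a quick sketch, assuming the `mem0ai` package is installed:

```python
# Both clients are now importable from the package root.
from mem0 import AsyncMemoryClient, MemoryClient

sync_client = MemoryClient(api_key="your-api-key")  # existing synchronous client
async_client = AsyncMemoryClient(api_key="your-api-key")  # new asynchronous client
```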
133 changes: 124 additions & 9 deletions mem0/client/main.py
@@ -56,12 +56,12 @@ class MemoryClient:
"""

def __init__(
-            self,
-            api_key: Optional[str] = None,
-            host: Optional[str] = None,
-            organization: Optional[str] = None,
-            project: Optional[str] = None
-    ):
+        self,
+        api_key: Optional[str] = None,
+        host: Optional[str] = None,
+        organization: Optional[str] = None,
+        project: Optional[str] = None,
+    ):
"""Initialize the MemoryClient.
Args:
@@ -275,9 +275,7 @@ def delete_users(self) -> Dict[str, str]:
params = {"org_name": self.organization, "project_name": self.project}
entities = self.users()
for entity in entities["results"]:
-            response = self.client.delete(
-                f"/v1/entities/{entity['type']}/{entity['id']}/", params=params
-            )
+            response = self.client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
response.raise_for_status()

capture_client_event("client.delete_users", self)
@@ -362,3 +360,120 @@ def _prepare_params(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
kwargs["run_id"] = kwargs.pop("session_id")

return {k: v for k, v in kwargs.items() if v is not None}


class AsyncMemoryClient:
"""Asynchronous client for interacting with the Mem0 API."""

def __init__(
self,
api_key: Optional[str] = None,
host: Optional[str] = None,
organization: Optional[str] = None,
project: Optional[str] = None,
):
self.sync_client = MemoryClient(api_key, host, organization, project)
self.async_client = httpx.AsyncClient(
base_url=self.sync_client.host,
headers=self.sync_client.client.headers,
timeout=60,
)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.async_client.aclose()

@api_error_handler
async def add(self, messages: Union[str, List[Dict[str, str]]], **kwargs) -> Dict[str, Any]:
payload = self.sync_client._prepare_payload(messages, kwargs)
response = await self.async_client.post("/v1/memories/", json=payload)
response.raise_for_status()
capture_client_event("async_client.add", self.sync_client)
return response.json()

@api_error_handler
async def get(self, memory_id: str) -> Dict[str, Any]:
response = await self.async_client.get(f"/v1/memories/{memory_id}/")
response.raise_for_status()
capture_client_event("async_client.get", self.sync_client)
return response.json()

@api_error_handler
async def get_all(self, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
params = self.sync_client._prepare_params(kwargs)
if version == "v1":
response = await self.async_client.get(f"/{version}/memories/", params=params)
elif version == "v2":
response = await self.async_client.post(f"/{version}/memories/", json=params)
response.raise_for_status()
capture_client_event(
"async_client.get_all", self.sync_client, {"filters": len(params), "limit": kwargs.get("limit", 100)}
)
return response.json()

@api_error_handler
async def search(self, query: str, version: str = "v1", **kwargs) -> List[Dict[str, Any]]:
payload = {"query": query}
payload.update(self.sync_client._prepare_params(kwargs))
response = await self.async_client.post(f"/{version}/memories/search/", json=payload)
response.raise_for_status()
capture_client_event("async_client.search", self.sync_client, {"limit": kwargs.get("limit", 100)})
return response.json()

@api_error_handler
async def update(self, memory_id: str, data: str) -> Dict[str, Any]:
response = await self.async_client.put(f"/v1/memories/{memory_id}/", json={"text": data})
response.raise_for_status()
capture_client_event("async_client.update", self.sync_client)
return response.json()

@api_error_handler
async def delete(self, memory_id: str) -> Dict[str, Any]:
response = await self.async_client.delete(f"/v1/memories/{memory_id}/")
response.raise_for_status()
capture_client_event("async_client.delete", self.sync_client)
return response.json()

@api_error_handler
async def delete_all(self, **kwargs) -> Dict[str, str]:
params = self.sync_client._prepare_params(kwargs)
response = await self.async_client.delete("/v1/memories/", params=params)
response.raise_for_status()
capture_client_event("async_client.delete_all", self.sync_client, {"params": len(params)})
return response.json()

@api_error_handler
async def history(self, memory_id: str) -> List[Dict[str, Any]]:
response = await self.async_client.get(f"/v1/memories/{memory_id}/history/")
response.raise_for_status()
capture_client_event("async_client.history", self.sync_client)
return response.json()

@api_error_handler
async def users(self) -> Dict[str, Any]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
response = await self.async_client.get("/v1/entities/", params=params)
response.raise_for_status()
capture_client_event("async_client.users", self.sync_client)
return response.json()

@api_error_handler
async def delete_users(self) -> Dict[str, str]:
params = {"org_name": self.sync_client.organization, "project_name": self.sync_client.project}
entities = await self.users()
for entity in entities["results"]:
response = await self.async_client.delete(f"/v1/entities/{entity['type']}/{entity['id']}/", params=params)
response.raise_for_status()
capture_client_event("async_client.delete_users", self.sync_client)
return {"message": "All users, agents, and sessions deleted."}

@api_error_handler
async def reset(self) -> Dict[str, str]:
await self.delete_users()
capture_client_event("async_client.reset", self.sync_client)
return {"message": "Client reset successful. All users and memories deleted."}

async def chat(self):
raise NotImplementedError("Chat is not implemented yet")
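
Because `__aenter__` and `__aexit__` are defined above, the client can also be used as an async context manager, which closes the underlying `httpx.AsyncClient` on exit; a minimal usage sketch (the API key and `user_id` are illustrative):

```python
import asyncio

from mem0 import AsyncMemoryClient


async def main():
    # __aexit__ calls aclose() on the wrapped httpx.AsyncClient.
    async with AsyncMemoryClient(api_key="your-api-key") as client:
        await client.add("Prefers window seats", user_id="john")
        print(await client.get_all(user_id="john"))


asyncio.run(main())
```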
4 changes: 3 additions & 1 deletion mem0/configs/base.py
@@ -73,4 +73,6 @@ class AzureConfig(BaseModel):
azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
api_version: str = Field(description="The version of the Azure API being used.", default=None)
-    default_headers: Optional[Dict[str, str]] = Field(description="Headers to include in requests to the Azure API.", default=None)
+    default_headers: Optional[Dict[str, str]] = Field(
+        description="Headers to include in requests to the Azure API.", default=None
+    )
5 changes: 3 additions & 2 deletions mem0/embeddings/gemini.py
@@ -1,5 +1,6 @@
import os
from typing import Optional

import google.generativeai as genai

from mem0.configs.embeddings.base import BaseEmbedderConfig
@@ -9,7 +10,7 @@
class GoogleGenAIEmbedding(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)

self.config.model = self.config.model or "models/text-embedding-004"
self.config.embedding_dims = self.config.embedding_dims or 768

@@ -27,4 +28,4 @@ def embed(self, text):
"""
text = text.replace("\n", " ")
response = genai.embed_content(model=self.config.model, content=text)
-        return response['embedding']
+        return response["embedding"]
4 changes: 2 additions & 2 deletions mem0/embeddings/huggingface.py
@@ -14,7 +14,7 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):

self.model = SentenceTransformer(self.config.model, **self.config.model_kwargs)

        self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()

def embed(self, text):
"""
@@ -26,4 +26,4 @@ def embed(self, text):
Returns:
list: The embedding vector.
"""
-        return self.model.encode(text, convert_to_numpy = True).tolist()
+        return self.model.encode(text, convert_to_numpy=True).tolist()
69 changes: 39 additions & 30 deletions mem0/llms/gemini.py
@@ -6,7 +6,9 @@
from google.generativeai import GenerativeModel
from google.generativeai.types import content_types
except ImportError:
raise ImportError("The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'.")
raise ImportError(
"The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
)

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
@@ -44,16 +46,16 @@ def _parse_response(self, response, tools):
if fn := part.function_call:
processed_response["tool_calls"].append(
{
"name": fn.name,
"arguments": {key:val for key, val in fn.args.items()},
"name": fn.name,
"arguments": {key: val for key, val in fn.args.items()},
}
)

return processed_response
else:
return response.candidates[0].content.parts[0].text

-    def _reformat_messages(self, messages : List[Dict[str, str]]):
+    def _reformat_messages(self, messages: List[Dict[str, str]]):
"""
Reformat messages for Gemini.
@@ -71,9 +73,8 @@ def _reformat_messages(self, messages : List[Dict[str, str]]):

else:
content = message["content"]

new_messages.append({"parts": content,
"role": "model" if message["role"] == "model" else "user"})

new_messages.append({"parts": content, "role": "model" if message["role"] == "model" else "user"})

return new_messages

@@ -89,24 +90,24 @@ def _reformat_tools(self, tools: Optional[List[Dict]]):
"""

def remove_additional_properties(data):
"""Recursively removes 'additionalProperties' from nested dictionaries."""
if isinstance(data, dict):
filtered_dict = {
key: remove_additional_properties(value)
for key, value in data.items()
if not (key == "additionalProperties")
}
return filtered_dict
else:
return data
"""Recursively removes 'additionalProperties' from nested dictionaries."""

if isinstance(data, dict):
filtered_dict = {
key: remove_additional_properties(value)
for key, value in data.items()
if not (key == "additionalProperties")
}
return filtered_dict
else:
return data

new_tools = []
if tools:
for tool in tools:
-                func = tool['function'].copy()
-                new_tools.append({"function_declarations":[remove_additional_properties(func)]})
+                func = tool["function"].copy()
+                new_tools.append({"function_declarations": [remove_additional_properties(func)]})

return new_tools
else:
return None
@@ -142,13 +143,21 @@ def generate_response(
params["response_schema"] = list[response_format]
if tool_choice:
tool_config = content_types.to_tool_config(
{"function_calling_config":
{"mode": tool_choice, "allowed_function_names": [tool['function']['name'] for tool in tools] if tool_choice == "any" else None}
})

response = self.client.generate_content(contents = self._reformat_messages(messages),
tools = self._reformat_tools(tools),
generation_config = genai.GenerationConfig(**params),
tool_config = tool_config)
{
"function_calling_config": {
"mode": tool_choice,
"allowed_function_names": [tool["function"]["name"] for tool in tools]
if tool_choice == "any"
else None,
}
}
)

response = self.client.generate_content(
contents=self._reformat_messages(messages),
tools=self._reformat_tools(tools),
generation_config=genai.GenerationConfig(**params),
tool_config=tool_config,
)

return self._parse_response(response, tools)
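
A standalone sketch of the schema cleanup that `_reformat_tools` applies: `additionalProperties` keys are stripped recursively before each function schema is wrapped in a `function_declarations` entry (the tool definition below is hypothetical):

```python
def remove_additional_properties(data):
    """Recursively removes 'additionalProperties' from nested dictionaries."""
    if isinstance(data, dict):
        return {key: remove_additional_properties(value) for key, value in data.items() if key != "additionalProperties"}
    return data


# Hypothetical OpenAI-style tool definition containing a key Gemini rejects.
tool = {
    "function": {
        "name": "get_weather",
        "parameters": {
            "type": "object",
            "additionalProperties": False,
            "properties": {"city": {"type": "string"}},
        },
    }
}

func = tool["function"].copy()
print({"function_declarations": [remove_additional_properties(func)]})
# {'function_declarations': [{'name': 'get_weather', 'parameters': {'type': 'object', 'properties': {'city': {'type': 'string'}}}}]}
```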
4 changes: 3 additions & 1 deletion mem0/llms/openai.py
@@ -18,7 +18,9 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter
self.client = OpenAI(
api_key=os.environ.get("OPENROUTER_API_KEY"),
-                base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1",
+                base_url=self.config.openrouter_base_url
+                or os.getenv("OPENROUTER_API_BASE")
+                or "https://openrouter.ai/api/v1",
)
else:
api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
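
The OpenRouter base URL now resolves in three steps: the explicit config value, then the `OPENROUTER_API_BASE` environment variable, then the public default. A small sketch of the fallback chain (the helper name is illustrative):

```python
import os
from typing import Optional


def resolve_openrouter_base_url(configured: Optional[str] = None) -> str:
    # Mirrors the `or` chain above: config value, env var, public default.
    return configured or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1"


print(resolve_openrouter_base_url())  # falls back to the public endpoint when neither is set
```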