Skip to content

Commit

Permalink
fix(auth): handling grist error (#119)
Browse files Browse the repository at this point in the history
Co-authored-by: leoguillaume <[email protected]>
  • Loading branch information
leoguillaume and leoguillaumegouv authored Dec 20, 2024
1 parent 1071cb0 commit e209d26
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 77 deletions.
158 changes: 102 additions & 56 deletions app/clients/_authenticationclient.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,124 @@
import base64
from collections import namedtuple
import datetime as dt
import hashlib
import json
from typing import Optional
import uuid
from typing import Any, Callable
from typing import List, Optional

from grist_api import GristDocAPI
from redis import Redis
import aiohttp
from pydantic import BaseModel
from redis.asyncio import Redis
import requests

from app.schemas.security import Role, User


class AuthenticationClient(GristDocAPI):
CACHE_EXPIRATION = 3600 # 1h
from app.utils.logging import logger


class AsyncGristDocAPI:
def __init__(self, doc_id: str, server: str, api_key: str):
self.doc_id = doc_id
self.server = server
self.api_key = api_key
self.base_url = f"{server}/api"

async def _request(self, method: str, endpoint: str, data: Optional[dict] = None):
headers = {"Authorization": f"Bearer {self.api_key}"}
async with aiohttp.ClientSession() as session:
if method in ["GET"]:
data = {"params": data}
else:
headers["Content-Type"] = "application/json"
data = {"json": data}

async with session.request(method, f"{self.base_url}{endpoint}", headers=headers, **data) as response:
response.raise_for_status()
return await response.json()

def ping(self):
headers = {"Authorization": f"Bearer {self.api_key}"}
endpoint = "/orgs"

try:
response = requests.get(f"{self.base_url}{endpoint}", headers=headers)
return response.status_code == 200
except Exception as e:
return False

async def fetch_table(self, table_name: str, filter: Optional[dict] = None, limit: int = 0) -> List[namedtuple]:
endpoint = f"/docs/{self.doc_id}/tables/{table_name}/records"
data = {"filter": json.dumps(filter), "limit": limit} if filter else {"limit": limit}
results = await self._request(method="GET", endpoint=endpoint, data=data)
results = [dict(id=result["id"], **result["fields"]) for result in results["records"]]
return [namedtuple(table_name, result.keys())(**result) for result in results]

async def update_records(self, table_name: str, record_dicts: List[dict]):
endpoint = f"/docs/{self.doc_id}/tables/{table_name}/records"
data = {"records": [{"id": record.pop("id"), "fields": record} for record in record_dicts]}
result = await self._request(method="PATCH", endpoint=endpoint, data=data)
return result


class AuthenticationClient(AsyncGristDocAPI):
CACHE_EXPIRATION = 172800 # 48h

class GristRecord(BaseModel):
ID2: Optional[str] = None
ROLE: str = Role.USER
EXPIRATION: int = dt.datetime.now().timestamp()
KEY: Optional[str] = "EMPTY"

class Config:
extra = "allow"

def __init__(self, cache: Redis, table_id: str, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.session_id = str(uuid.uuid4())
assert self.ping(), "Grist is not reachable"
self.table_id = table_id
self.redis = cache

def check_api_key(self, key: str) -> Optional[str]:
async def check_api_key(self, key: str) -> Optional[User]:
"""
Check if a key exists in a table of the Grist document.
Get API key details from cache or Grist and return a User object.
Args:
key (str): key to check
Returns:
Optional[str]: role of the key if it exists, None otherwise
"""
keys = self._get_api_keys()
if key in keys:
return User(id=self._api_key_to_user_id(input=key), role=Role[keys[key]["role"]], name=keys[key]["name"])

def cache(func) -> Callable[..., Any]:
"""
Decorator to cache the result of a function in Redis.
"""

def wrapper(self) -> Any:
key = f"auth-{self.session_id}"
result = self.redis.get(key)
if result:
result = json.loads(result)
return result
result = func(self)
self.redis.setex(key, self.CACHE_EXPIRATION, json.dumps(result))

return result

return wrapper

@cache
def _get_api_keys(self) -> dict:
"""
Get all keys from a table in the Grist document.
key (str): API key to look up
Returns:
dict: dictionary of keys and their corresponding access level
Optional[User]: User object if found, None otherwise
"""
records = self.fetch_table(table_name=self.table_id)

keys = dict()
for record in records:
if record.EXPIRATION > dt.datetime.now().timestamp():
keys[record.KEY] = {
"id": self._api_key_to_user_id(input=record.KEY),
"role": Role.get(name=record.ROLE.upper(), default=Role.USER)._name_,
"name": record.USER,
}

return keys
user_id = self._api_key_to_user_id(input=key)
ttl = -2

# fetch from Redis
redis_key = f"{self.table_id}_{user_id}"
cache_user = await self.redis.get(redis_key)

if cache_user:
cache_user = json.loads(cache_user)
user = User(id=cache_user["id"], role=Role.get(cache_user["role"]))
ttl = await self.redis.ttl(redis_key)
if ttl > 300:
return user

try:
# fetch from grist
records = await self.fetch_table(table_name=self.table_id, filter={"KEY": [key]}, limit=1)
record = self.GristRecord(**records[0]._asdict()) if records else self.GristRecord()
if record.ID2 != user_id:
record.ID2 = user_id
await self.update_records(table_name=self.table_id, record_dicts=[record.model_dump()])

if record.KEY == key and record.EXPIRATION > dt.datetime.now().timestamp():
cache_user = {"id": record.ID2, "role": Role.get(name=record.ROLE.upper(), default=Role.USER)._name_}
await self.redis.setex(redis_key, self.CACHE_EXPIRATION, json.dumps(cache_user))
user = User(id=cache_user["id"], role=Role.get(cache_user["role"]))
return user

except Exception as e:
logger.error(f"Error fetching user from Grist: {e}")
if ttl > -2:
await self.redis.setex(redis_key, self.CACHE_EXPIRATION, json.dumps(cache_user))
return user

@staticmethod
def _api_key_to_user_id(input: str) -> str:
Expand Down
4 changes: 2 additions & 2 deletions app/helpers/_clientsmanager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from redis import Redis as CacheManager
from redis.connection import ConnectionPool
from redis.asyncio import Redis as CacheManager
from redis.asyncio.connection import ConnectionPool

from app.clients import AuthenticationClient, ModelClients
from app.clients.internet import DuckDuckGoInternetClient, BraveInternetClient
Expand Down
3 changes: 1 addition & 2 deletions app/schemas/security.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Optional
from typing import Any

from pydantic import BaseModel

Expand All @@ -19,5 +19,4 @@ def get(cls, name: str, default=None) -> Enum | Any:

class User(BaseModel):
id: str
name: Optional[str] = None
role: Role
32 changes: 16 additions & 16 deletions app/utils/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,39 @@

if settings.auth:

def check_admin_api_key(api_key: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key"))]) -> User:
async def check_api_key(api_key: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key"))]) -> User:
"""
Check if the API key is valid and if the user has admin rights.
Check if the API key is valid.
Args:
api_key (Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key")]): The API key to check.
Returns:
User: User object, corresponding to the encoded API key or "no-auth" if no authentication is set in the configuration file.
"""
user = check_api_key(api_key=api_key)
if user.role != Role.ADMIN:
raise InsufficientRightsException()

if api_key.scheme != "Bearer":
raise InvalidAuthenticationSchemeException()

user = await clients.auth.check_api_key(api_key.credentials)
if user is None:
raise InvalidAPIKeyException()

return user

def check_api_key(api_key: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key"))]) -> User:
async def check_admin_api_key(api_key: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key"))]) -> User:
"""
Check if the API key is valid.
Check if the API key is valid and if the user has admin rights.
Args:
api_key (Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key")]): The API key to check.
Returns:
User: User object, corresponding to the encoded API key or "no-auth" if no authentication is set in the configuration file.
"""

if api_key.scheme != "Bearer":
raise InvalidAuthenticationSchemeException()

user = clients.auth.check_api_key(api_key.credentials)
if user is None:
raise InvalidAPIKeyException()
user = await check_api_key(api_key=api_key)
if user.role != Role.ADMIN:
raise InsufficientRightsException()

return user

Expand All @@ -57,7 +57,7 @@ def check_api_key(api_key: Optional[str] = None) -> User:
return User(id="no-auth", role=Role.ADMIN)


def check_rate_limit(request: Request) -> Optional[str]:
async def check_rate_limit(request: Request) -> Optional[str]:
"""
Check the rate limit for the user.
Expand All @@ -71,7 +71,7 @@ def check_rate_limit(request: Request) -> Optional[str]:
authorization = request.headers.get("Authorization")
scheme, credentials = authorization.split(" ") if authorization else ("", "")
api_key = HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
user = check_api_key(api_key=api_key)
user = await check_api_key(api_key=api_key)

if user.role.value > Role.USER.value:
return None
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ app = [
"pydantic-settings==2.6.1",
"prometheus-fastapi-instrumentator==7.0.0",
"pyyaml==6.0.1",
"grist-api==0.1.0",
"six==1.16.0",
"pdfminer.six==20240706",
"beautifulsoup4==4.12.3",
Expand Down

0 comments on commit e209d26

Please sign in to comment.