Skip to content

Commit

Permalink
Added support for multiple API keys, generating OAI-like tokens by de…
Browse files Browse the repository at this point in the history
…fault. Implements theroyallab#79
  • Loading branch information
kir-gadjello committed May 1, 2024
1 parent 7556dcf commit 61f6403
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,35 @@
application, it should be fine.
"""

import string
import secrets
import yaml
from fastapi import Header, HTTPException
from pydantic import BaseModel
from loguru import logger
from typing import Optional
from typing import Optional, List


class AuthKeys(BaseModel):
"""
This class represents the authentication keys for the application.
It contains two types of keys: 'api_key' and 'admin_key'.
The 'api_key' is used for general API calls, while the 'admin_key'
It contains two types of keys: 'api_key'/'api_keys' and 'admin_key'.
The 'api_key'/'api_keys' are used for general API calls, while the 'admin_key'
is used for administrative tasks. The class also provides a method
to verify if a given key matches the stored 'api_key' or 'admin_key'.
to verify if a given key matches the stored 'api_key'/'api_keys' or 'admin_key'.
"""

api_key: str
api_keys: Optional[List[str]] = None
admin_key: str

def verify_key(self, test_key: str, key_type: str):
"""Verify if a given key matches the stored key."""
if key_type == "admin_key":
return test_key == self.admin_key
if key_type == "api_key":
if isinstance(self.api_keys, list) and test_key in self.api_keys:
return True
# Admin keys are valid for all API calls
return test_key == self.api_key or test_key == self.admin_key
return False
Expand All @@ -38,6 +42,20 @@ def verify_key(self, test_key: str, key_type: str):
DISABLE_AUTH: bool = False


def gen_rand_ascii(length):
chars = string.ascii_letters + string.digits
nchars = len(chars)
secure_token = secrets.token_bytes(length)
return "".join(map(lambda b: chars[(b % nchars)], secure_token))


# some apps check this regexp https://github.com/secretlint/secretlint/issues/676
def gen_oai_like_key():
prefix = "sk-"
suffix = "T3BlbkFJ"
return f"{prefix}{gen_rand_ascii(20)}{suffix}{gen_rand_ascii(20)}"


def load_auth_keys(disable_from_config: bool):
"""Load the authentication keys from api_tokens.yml. If the file does not
exist, generate new keys and save them to api_tokens.yml."""
Expand All @@ -60,15 +78,21 @@ def load_auth_keys(disable_from_config: bool):
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
except FileNotFoundError:
new_auth_keys = AuthKeys(
api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16)
api_keys=[gen_oai_like_key()],
api_key=gen_oai_like_key(),
admin_key=secrets.token_hex(16),
)
AUTH_KEYS = new_auth_keys

with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)

multiple_keys_msg = ""
if isinstance(AUTH_KEYS.api_keys, list):
multiple_keys_msg = f"Your additional API keys are: {'\n'.join(AUTH_KEYS.api_keys)}\n"

logger.info(
f"Your API key is: {AUTH_KEYS.api_key}\n"
f"Your API key is: {AUTH_KEYS.api_key}\n{multiple_keys_msg}"
f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
"If these keys get compromised, make sure to delete api_tokens.yml "
"and restart the server. Have fun!"
Expand Down

0 comments on commit 61f6403

Please sign in to comment.