Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lloesche committed Oct 4, 2023
1 parent 4f884c1 commit d4be8d5
Show file tree
Hide file tree
Showing 10 changed files with 518 additions and 32 deletions.
8 changes: 5 additions & 3 deletions fixca/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sys
import resotolib.proc
from signal import SIGTERM
from tempfile import TemporaryDirectory
from resotolib.logger import log, setup_logger, add_args as logging_add_args
from resotolib.web import WebServer
Expand All @@ -9,19 +10,20 @@
from .args import parse_args
from .ca import CA, WebApp, CaApp
from threading import Event
from typing import Any


shutdown_event = Event()


def shutdown(event) -> None:
def shutdown(even: Any) -> None:
log.info("Shutting down")
shutdown_event.set()


def main() -> None:
setup_logger("fixca")
args = parse_args([logging_add_args])
args = parse_args([logging_add_args]) # type: ignore
log.info(f"Starting FIX CA on port {args.port}")
resotolib.proc.initializer()
resotolib.proc.parent_pid = os.getpid()
Expand Down Expand Up @@ -63,7 +65,7 @@ def main() -> None:
shutdown_event.wait()
web_server.shutdown()

resotolib.proc.kill_children(resotolib.proc.SIGTERM, ensure_death=True)
resotolib.proc.kill_children(SIGTERM, ensure_death=True)
log.info("Shutdown complete")
sys.exit(0)

Expand Down
5 changes: 3 additions & 2 deletions fixca/args.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from argparse import ArgumentParser, Namespace
from typing import Callable, List
from resotolib.args import ArgumentParser as ResotoArgumentParser
from typing import Callable, List, Union


def parse_args(add_args: List[Callable]) -> Namespace:
def parse_args(add_args: List[Callable[[ArgumentParser], None]]) -> Namespace:
parser = ArgumentParser(prog="fixca", description="FIX Certification Authority")
parser.add_argument("--psk", dest="psk", help="Pre-shared-key", default=os.environ.get("FIXCA_PSK"))
parser.add_argument(
Expand Down
31 changes: 17 additions & 14 deletions fixca/ca.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import wraps
from prometheus_client.exposition import generate_latest, CONTENT_TYPE_LATEST
from typing import Optional, Dict, Callable, Tuple, Union, Any, List
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.x509.base import Certificate, CertificateSigningRequest
from resotolib.logger import log
from resotolib.x509 import (
Expand All @@ -25,10 +25,10 @@


class CertificateAuthority:
def __init__(self):
self.cert = None
self.__key = None
self.__initialized = False
def __init__(self) -> None:
self.cert: Optional[Certificate] = None
self.__key: Optional[RSAPrivateKey] = None
self.__initialized: bool = False

@staticmethod
def requires_initialized(func: Callable[..., Any]) -> Callable[..., Any]:
Expand All @@ -42,6 +42,7 @@ def wrapper(ca_instance: "CertificateAuthority", *args: Any, **kwargs: Any) -> A

@requires_initialized
def sign(self, csr: CertificateSigningRequest) -> Certificate:
assert self.__key is not None and self.cert is not None
return sign_csr(csr, self.__key, self.cert)

def initialize(self, namespace: str = "cert-manager", secret_name: str = "fix-ca", dummy_ca: bool = False) -> None:
Expand All @@ -62,7 +63,7 @@ def __load_ca_data(
log.info("Loading CA data")
ca_secret = get_secret(namespace=namespace, secret_name=secret_name)

if isinstance(ca_secret, dict) and (not "tls.key" in ca_secret or not "tls.crt" in ca_secret):
if isinstance(ca_secret, dict) and ("tls.key" not in ca_secret or "tls.crt" not in ca_secret):
ca_secret = None
log.error("CA secret is missing key or cert")

Expand Down Expand Up @@ -126,6 +127,7 @@ def store_secret(
include_ca_bundle: bool = False,
) -> None:
log.info(f"Storing certificate {cert_crt.subject.rfc4514_string()} in {namespace}/{secret_name}")
assert self.cert is not None
secret = {
key_cert: cert_to_bytes(cert_crt).decode("utf-8"),
key_key: key_to_bytes(cert_key).decode("utf-8"),
Expand All @@ -143,10 +145,10 @@ def store_secret(


CA: CertificateAuthority = CertificateAuthority()
PSK: Optional[Union[str, Certificate, RSAPublicKey]] = None
PSK: Optional[str] = None


def jwt_check():
def jwt_check() -> None:
headers = cherrypy.request.headers
assert PSK is not None

Expand Down Expand Up @@ -203,23 +205,23 @@ def metrics(self) -> bytes:


class CaApp:
def __init__(self, ca: CertificateAuthority, psk_or_cert: Union[str, Certificate, RSAPublicKey]) -> None:
def __init__(self, ca: CertificateAuthority, psk: str) -> None:
global PSK
self.ca = ca
self.psk_or_cert = psk_or_cert
self.psk = psk
self.config = {"/": {"tools.gzip.on": False}}
PSK = self.psk_or_cert
PSK = self.psk

@cherrypy.expose
@cherrypy.tools.allow(methods=["GET"])
def cert(self) -> bytes:
assert self.psk_or_cert is not None
assert self.psk is not None and self.ca.cert is not None
fingerprint = cert_fingerprint(self.ca.cert)
cherrypy.response.headers["Content-Type"] = "application/x-pem-file"
cherrypy.response.headers["SHA256-Fingerprint"] = fingerprint
cherrypy.response.headers["Content-Disposition"] = 'attachment; filename="fix_root_ca.pem"'
cherrypy.response.headers["Authorization"] = "Bearer " + encode_jwt(
{"sha256_fingerprint": fingerprint}, self.psk_or_cert
{"sha256_fingerprint": fingerprint}, self.psk
)
return cert_to_bytes(self.ca.cert)

Expand Down Expand Up @@ -247,8 +249,9 @@ def sign(self) -> bytes:
@cherrypy.tools.json_in()
@cherrypy.tools.allow(methods=["POST"])
@cherrypy.tools.jwt_check()
def generate(self) -> bytes:
def generate(self) -> Dict[str, Any]:
try:
assert self.ca.cert is not None
request_json = cherrypy.request.json
remote_addr = cherrypy.request.remote.ip
include_ca_cert = str_to_bool(request_json.get("include_ca_cert", False))
Expand Down
19 changes: 10 additions & 9 deletions fixca/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
from functools import wraps
from typing import Callable, Any, Tuple, Dict, Union, TypeVar
from typing import Callable, Any, Tuple, Dict, Union, TypeVar, Type


def str_to_bool(s: Union[str, bool]) -> bool:
Expand All @@ -10,15 +10,16 @@ def str_to_bool(s: Union[str, bool]) -> bool:
RT = TypeVar("RT")


def memoize(ttl: int = 60, cleanup_interval: int = 600) -> Callable:
state = {"last_cleanup": 0}
cache: Dict[Tuple[Callable, Tuple, frozenset], Tuple[RT, float]] = {}
def memoize(
ttl: int = 60, cleanup_interval: int = 600, time_fn: Callable[[], float] = time.time
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
last_cleanup: float = 0.0
cache: Dict[Tuple[Callable[..., RT], Tuple[Any, ...], frozenset[Tuple[str, Any]]], Tuple[RT, float]] = {}

def decorating_function(user_function: Callable[..., RT]) -> Callable[..., RT]:
@wraps(user_function)
def wrapper(*args: Any, **kwargs: Any) -> RT:
nonlocal cache
now = time.time()
now = time_fn()
key = (user_function, args, frozenset(kwargs.items()))
if key in cache:
result, timestamp = cache[key]
Expand All @@ -28,11 +29,11 @@ def wrapper(*args: Any, **kwargs: Any) -> RT:
result = user_function(*args, **kwargs)
cache[key] = (result, now)

nonlocal state
if now - state["last_cleanup"] > cleanup_interval:
nonlocal last_cleanup
if now - last_cleanup > cleanup_interval:
for k in [k for k, v in cache.items() if now - v[1] >= ttl]:
cache.pop(k)
state["last_cleanup"] = now
last_cleanup = now

return result

Expand Down
4 changes: 4 additions & 0 deletions genreq.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/bash
pip-compile --resolver=backtracking --upgrade --allow-unsafe --no-header --unsafe-package n/a --output-file requirements.txt
pip-compile --extra test --resolver=backtracking --upgrade --allow-unsafe --no-header --unsafe-package n/a --output-file requirements-test.txt

Loading

0 comments on commit d4be8d5

Please sign in to comment.