diff --git a/README.md b/README.md index a3a579c..c3a2720 100644 --- a/README.md +++ b/README.md @@ -36,3 +36,48 @@ Alternatively export the following environment variables: - `FIXCA_SECRET` Only the pre-shared-key is mandatory. The other options have sensible defaults. + +## K8s cluster issuer + +When using [cert-manager](https://cert-manager.io/) to issue certificates for your services you can use the following cluster issuer: + +```yaml +apiVersion: cert-manager.io/v1 +kind: ClusterIssuer +metadata: + name: fix-ca-issuer + namespace: cert-manager +spec: + ca: + secretName: fix-ca +``` + +### Example Certificate + +```yaml +apiVersion: cert-manager.io/v1 +kind: Certificate +metadata: + name: lukas-test-cert + namespace: fix +spec: + secretName: lukas-test + duration: 2160h # 90d + renewBefore: 360h # 15d + commonName: lukas.test + privateKey: + algorithm: RSA + encoding: PKCS1 + size: 2048 + usages: + - server auth + - client auth + dnsNames: + - redis.fix + issuerRef: + name: fix-ca-issuer + group: cert-manager.io + kind: ClusterIssuer +``` + +Check the [cert-manager documentation](https://cert-manager.io/docs/usage/certificate/) for more information. diff --git a/fixca/__init__.py b/fixca/__init__.py index 7a5cf99..919d037 100644 --- a/fixca/__init__.py +++ b/fixca/__init__.py @@ -1,13 +1,13 @@ """ -FIX Certification Authority -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +FIX Certificate Authority +~~~~~~~~~~~~~~~~~~~~~~~~~ Runs a web server that issues certificates for FIX components. :copyright: © 2023 Some Engineering Inc. :license: Apache 2.0, see LICENSE for more details. """ __title__ = "fixca" -__description__ = "FIX Certification Authority" +__description__ = "FIX Certificate Authority" __author__ = "Some Engineering Inc." __license__ = "Apache 2.0" __copyright__ = "Copyright © 2023 Some Engineering Inc." diff --git a/fixca/__main__.py b/fixca/__main__.py index 9932cc7..b61e106 100644 --- a/fixca/__main__.py +++ b/fixca/__main__.py @@ -7,7 +7,7 @@ from resotolib.event import EventType, add_event_listener from resotolib.x509 import gen_csr, gen_rsa_key, write_cert_to_file, write_key_to_file from .args import parse_args -from .ca import get_ca, WebApp, CaApp +from .ca import CA, WebApp, CaApp from threading import Event @@ -28,7 +28,7 @@ def main() -> None: add_event_listener(EventType.SHUTDOWN, shutdown) - CA = get_ca(namespace=args.namespace, secret_name=args.secret) + CA.initialize(namespace=args.namespace, secret_name=args.secret, dummy_ca=args.dummy_ca) common_name = "ca.fix" cert_key = gen_rsa_key() @@ -52,8 +52,11 @@ def main() -> None: web_port=args.port, ssl_cert=cert_path, ssl_key=key_path, + extra_config={ + "tools.proxy.on": True, + }, ) - web_server.mount("/ca", CaApp(get_ca(), args.psk)) + web_server.mount("/ca", CaApp(CA, args.psk)) web_server.daemon = True web_server.start() diff --git a/fixca/args.py b/fixca/args.py index a4a3c61..9220a98 100644 --- a/fixca/args.py +++ b/fixca/args.py @@ -16,8 +16,8 @@ def parse_args(add_args: List[Callable]) -> Namespace: parser.add_argument( "--namespace", dest="namespace", - help="K8s namespace (default: fix)", - default=os.environ.get("FIXCA_NAMESPACE", "fix"), + help="K8s namespace (default: cert-manager)", + default=os.environ.get("FIXCA_NAMESPACE", "cert-manager"), ) parser.add_argument( "--secret", @@ -25,6 +25,13 @@ def parse_args(add_args: List[Callable]) -> Namespace: help="Secret name (default: fix-ca)", default=os.environ.get("FIXCA_SECRET", "fix-ca"), ) + parser.add_argument( + "--dummy", + help="Start a dummy CA that does not persist its state", + dest="dummy_ca", + action="store_true", + default=False, + ) for add_arg in add_args: add_arg(parser) diff --git a/fixca/ca.py b/fixca/ca.py index fd6d11e..06ff19b 100644 --- a/fixca/ca.py +++ b/fixca/ca.py @@ -1,7 +1,8 @@ import os import cherrypy +from functools import wraps from prometheus_client.exposition import generate_latest, CONTENT_TYPE_LATEST -from typing import Optional, Dict, Callable, Tuple, Union +from typing import Optional, Dict, Callable, Tuple, Union, Any, List from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey from cryptography.x509.base import Certificate, CertificateSigningRequest from resotolib.logger import log @@ -14,54 +15,135 @@ load_csr_from_bytes, load_cert_from_bytes, load_key_from_bytes, + gen_rsa_key, + gen_csr, + gen_ca_bundle_bytes, ) from resotolib.jwt import encode_jwt, decode_jwt_from_headers from .k8s import get_secret, set_secret +from .utils import str_to_bool -CA: Optional["CertificateAuthority"] = None -PSK: Optional[Union[str, Certificate, RSAPublicKey]] = None +class CertificateAuthority: + def __init__(self): + self.cert = None + self.__key = None + self.__initialized = False + @staticmethod + def requires_initialized(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + def wrapper(ca_instance: "CertificateAuthority", *args: Any, **kwargs: Any) -> Any: + if not ca_instance.initialized: + raise Exception("CA not initialized") + return func(ca_instance, *args, **kwargs) -class CertificateAuthority: - def __init__(self, ca_key: RSAPrivateKey, ca_cert: Certificate): - self.ca_key = ca_key - self.ca_cert = ca_cert + return wrapper + @requires_initialized def sign(self, csr: CertificateSigningRequest) -> Certificate: - return sign_csr(csr, self.ca_key, self.ca_cert) + return sign_csr(csr, self.__key, self.cert) + + def initialize(self, namespace: str = "cert-manager", secret_name: str = "fix-ca", dummy_ca: bool = False) -> None: + if dummy_ca: + self.__key, self.cert = bootstrap_ca(common_name="FIX Certificate Authority") + else: + self.__key, self.cert = self.__load_ca_data(namespace=namespace, secret_name=secret_name) + self.__initialized = True + + @property + def initialized(self) -> bool: + return self.__initialized + @staticmethod + def __load_ca_data( + namespace: str = "cert-manager", secret_name: str = "fix-ca" + ) -> Tuple[RSAPrivateKey, Certificate]: + log.info("Loading CA data") + ca_secret = get_secret(namespace=namespace, secret_name=secret_name) -def load_ca_data(namespace: str = "fix", secret_name: str = "fix-ca") -> Tuple[RSAPrivateKey, Certificate]: - log.info("Loading CA data") - ca_secret = get_secret(namespace=namespace, secret_name=secret_name) + if isinstance(ca_secret, dict) and (not "tls.key" in ca_secret or not "tls.crt" in ca_secret): + ca_secret = None + log.error("CA secret is missing key or cert") - if isinstance(ca_secret, dict) and (not "key" in ca_secret or not "cert" in ca_secret): - ca_secret = None - log.error("CA secret is missing key or cert") + if ca_secret is None: + log.debug("Bootstrapping a new CA") + key, cert = bootstrap_ca(common_name="FIX Certificate Authority") + ca_secret = { + "tls.key": key_to_bytes(key).decode("utf-8"), + "tls.crt": cert_to_bytes(cert).decode("utf-8"), + } + set_secret(namespace=namespace, secret_name=secret_name, data=ca_secret) + else: + log.debug("Loading existing CA") + key_bytes, cert_bytes = ca_secret["tls.key"].encode(), ca_secret["tls.crt"].encode() + key = load_key_from_bytes(key_bytes) + cert = load_cert_from_bytes(cert_bytes) + + return key, cert + + @requires_initialized + def generate( + self, + common_name: str, + san_dns_names: Optional[List[str]] = None, + san_ip_addresses: Optional[List[str]] = None, + ) -> Tuple[RSAPrivateKey, Certificate]: + if san_dns_names is None: + san_dns_names = [] + elif isinstance(san_dns_names, str): + san_dns_names = [san_dns_names] + if san_ip_addresses is None: + san_ip_addresses = [] + elif isinstance(san_ip_addresses, str): + san_ip_addresses = [san_ip_addresses] + + cert_key = gen_rsa_key() + cert_csr = gen_csr( + cert_key, + common_name=common_name, + san_dns_names=san_dns_names, + san_ip_addresses=san_ip_addresses, + include_loopback=False, + connect_to_ips=None, + discover_local_dns_names=False, + discover_local_ip_addresses=False, + ) + cert_crt = self.sign(cert_csr) + return cert_key, cert_crt - if ca_secret is None: - log.debug("Bootstrapping a new CA") - key, cert = bootstrap_ca(common_name="FIX Certification Authority") - ca_secret = { - "key": key_to_bytes(key).decode("utf-8"), - "cert": cert_to_bytes(cert).decode("utf-8"), + def store_secret( + self, + cert_key: RSAPrivateKey, + cert_crt: Certificate, + namespace: str, + secret_name: str, + key_cert: str = "cert.pem", + key_key: str = "cert.key", + key_ca: str = "ca.pem", + key_ca_bundle: str = "ca.bundle.pem", + include_ca_cert: bool = False, + include_ca_bundle: bool = False, + ) -> None: + log.info(f"Storing certificate {cert_crt.subject.rfc4514_string()} in {namespace}/{secret_name}") + secret = { + key_cert: cert_to_bytes(cert_crt).decode("utf-8"), + key_key: key_to_bytes(cert_key).decode("utf-8"), } - set_secret(namespace=namespace, secret_name=secret_name, data=ca_secret) - else: - log.debug("Loading existing CA") - key_bytes, cert_bytes = ca_secret["key"].encode(), ca_secret["cert"].encode() - key = load_key_from_bytes(key_bytes) - cert = load_cert_from_bytes(cert_bytes) + if include_ca_cert: + secret[key_ca] = cert_to_bytes(self.cert).decode("utf-8") + if include_ca_bundle: + secret[key_ca_bundle] = gen_ca_bundle_bytes(self.cert).decode("utf-8") - return key, cert + set_secret( + namespace=namespace, + secret_name=secret_name, + data=secret, + ) -def get_ca(namespace: str = "fix", secret_name: str = "fix-ca") -> CertificateAuthority: - global CA - if CA is None: - CA = CertificateAuthority(*load_ca_data(namespace=namespace, secret_name=secret_name)) - return CA +CA: CertificateAuthority = CertificateAuthority() +PSK: Optional[Union[str, Certificate, RSAPublicKey]] = None def jwt_check(): @@ -94,7 +176,7 @@ def __init__( "tools.staticdir.on": True, "tools.staticdir.dir": f"{local_path}/static", } - self.ca = get_ca() + self.ca = CA self.config = {"/": config} self.health_conditions = health_conditions if health_conditions is not None else {} if self.mountpoint not in ("/", ""): @@ -132,14 +214,14 @@ def __init__(self, ca: CertificateAuthority, psk_or_cert: Union[str, Certificate @cherrypy.tools.allow(methods=["GET"]) def cert(self) -> bytes: assert self.psk_or_cert is not None - fingerprint = cert_fingerprint(self.ca.ca_cert) + fingerprint = cert_fingerprint(self.ca.cert) cherrypy.response.headers["Content-Type"] = "application/x-pem-file" cherrypy.response.headers["SHA256-Fingerprint"] = fingerprint cherrypy.response.headers["Content-Disposition"] = 'attachment; filename="fix_root_ca.pem"' cherrypy.response.headers["Authorization"] = "Bearer " + encode_jwt( {"sha256_fingerprint": fingerprint}, self.psk_or_cert ) - return cert_to_bytes(self.ca.ca_cert) + return cert_to_bytes(self.ca.cert) @cherrypy.expose @cherrypy.tools.allow(methods=["POST"]) @@ -159,3 +241,39 @@ def sign(self) -> bytes: cherrypy.response.headers["SHA256-Fingerprint"] = cert_fingerprint(crt) cherrypy.response.headers["Content-Disposition"] = f'attachment; filename="{filename}"' return cert_to_bytes(crt) + + @cherrypy.expose + @cherrypy.tools.json_out() + @cherrypy.tools.json_in() + @cherrypy.tools.allow(methods=["POST"]) + @cherrypy.tools.jwt_check() + def generate(self) -> bytes: + try: + request_json = cherrypy.request.json + remote_addr = cherrypy.request.remote.ip + include_ca_cert = str_to_bool(request_json.get("include_ca_cert", False)) + include_ca_bundle = str_to_bool(request_json.get("include_ca_bundle", False)) + common_name = request_json.get("common_name", remote_addr) + san_dns_name = request_json.get("common_name", "localhost") + cert_key, cert_crt = self.ca.generate( + common_name=common_name, + san_dns_names=[san_dns_name], + san_ip_addresses=[remote_addr], + ) + secret_key_cert = request_json.get("key_cert", "cert.pem") + secret_key_key = request_json.get("key_key", "cert.key") + secret_key_ca = request_json.get("key_ca", "ca.pem") + secret_key_ca_bundle = request_json.get("key_ca_bundle", "ca.bundle.pem") + secret = { + secret_key_cert: cert_to_bytes(cert_crt).decode("utf-8"), + secret_key_key: key_to_bytes(cert_key).decode("utf-8"), + } + if include_ca_cert: + secret[secret_key_ca] = cert_to_bytes(self.ca.cert).decode("utf-8") + if include_ca_bundle: + secret[secret_key_ca_bundle] = gen_ca_bundle_bytes(self.ca.cert).decode("utf-8") + except Exception: + cherrypy.response.status = 400 + return {"error": "Invalid request"} + + return secret diff --git a/fixca/k8s.py b/fixca/k8s.py index ab8d53a..d62f17d 100644 --- a/fixca/k8s.py +++ b/fixca/k8s.py @@ -4,9 +4,16 @@ from resotolib.logger import log from kubernetes import client, config from kubernetes.client.exceptions import ApiException +from .utils import memoize def k8s_client() -> client.CoreV1Api: + k8s_config_load() + return client.CoreV1Api() + + +@memoize() +def k8s_config_load() -> None: try: config.load_incluster_config() except config.config_exception.ConfigException: @@ -15,7 +22,6 @@ def k8s_client() -> client.CoreV1Api: except config.config_exception.ConfigException as e: log.critical(f"Failed to load Kubernetes config: {e}") sys.exit(1) - return client.CoreV1Api() def get_secret(namespace: str, secret_name: str) -> Optional[dict[str, str]]: diff --git a/fixca/static/index.html b/fixca/static/index.html index 1bdb539..acb9498 100644 --- a/fixca/static/index.html +++ b/fixca/static/index.html @@ -7,10 +7,10 @@ - FIX Certification Authority + FIX Certificate Authority -

FIX Certification Authority


+

FIX Certificate Authority


diff --git a/fixca/utils.py b/fixca/utils.py new file mode 100644 index 0000000..7518325 --- /dev/null +++ b/fixca/utils.py @@ -0,0 +1,41 @@ +import time +from functools import wraps +from typing import Callable, Any, Tuple, Dict, Union, TypeVar + + +def str_to_bool(s: Union[str, bool]) -> bool: + return str(s).lower() in ("true", "1", "yes") + + +RT = TypeVar("RT") + + +def memoize(ttl: int = 60, cleanup_interval: int = 600) -> Callable: + state = {"last_cleanup": 0} + cache: Dict[Tuple[Callable, Tuple, frozenset], Tuple[RT, float]] = {} + + def decorating_function(user_function: Callable[..., RT]) -> Callable[..., RT]: + @wraps(user_function) + def wrapper(*args: Any, **kwargs: Any) -> RT: + nonlocal cache + now = time.time() + key = (user_function, args, frozenset(kwargs.items())) + if key in cache: + result, timestamp = cache[key] + if now - timestamp < ttl: + return result + + result = user_function(*args, **kwargs) + cache[key] = (result, now) + + nonlocal state + if now - state["last_cleanup"] > cleanup_interval: + for k in [k for k, v in cache.items() if now - v[1] >= ttl]: + cache.pop(k) + state["last_cleanup"] = now + + return result + + return wrapper + + return decorating_function