diff --git a/pglookout/pgutil.py b/pglookout/pgutil.py index 3b8562b..7d5f692 100644 --- a/pglookout/pgutil.py +++ b/pglookout/pgutil.py @@ -5,16 +5,80 @@ Copyright (c) 2015 Ohmu Ltd See LICENSE for details """ +from __future__ import annotations + +from typing import cast, Literal, TypedDict from urllib.parse import parse_qs, urlparse # pylint: disable=no-name-in-module, import-error import psycopg2.extensions -def create_connection_string(connection_info): - return psycopg2.extensions.make_dsn(**connection_info) +class DsnDictBase(TypedDict, total=False): + user: str + password: str + host: str + port: str | int + + +class DsnDict(DsnDictBase, total=False): + dbname: str + + +class DsnDictDeprecated(DsnDictBase, total=False): + database: str + + +class ConnectionParameterKeywords(TypedDict, total=False): + """Parameter Keywords for Connection. + See: + https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS + """ -def mask_connection_info(info): + host: str + hostaddr: str + port: str + dbname: str + user: str + password: str + passfile: str + channel_binding: Literal["require", "prefer", "disable"] + connect_timeout: str + client_encoding: str + options: str + application_name: str + fallback_application_name: str + keepalives: Literal["0", "1"] + keepalives_idle: str + keepalives_interval: str + keepalives_count: str + tcp_user_timeout: str + replication: Literal["true", "on", "yes", "1", "database", "false", "off", "no", "0"] + gssencmode: Literal["disable", "prefer", "require"] + sslmode: Literal["disable", "allow", "prefer", "require", "verify-ca", "verify-full"] + requiressl: Literal["0", "1"] + sslcompression: Literal["0", "1"] + sslcert: str + sslkey: str + sslpassword: str + sslrootcert: str + sslcrl: str + sslcrldir: str + sslsni: Literal["0", "1"] + requirepeer: str + ssl_min_protocol_version: Literal["TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"] + ssl_max_protocol_version: Literal["TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"] + krbsrvname: str + gsslib: str + service: str + target_session_attrs: Literal["any", "read-write", "read-only", "primary", "standby", "prefer-standby"] + + +def create_connection_string(connection_info: DsnDict | DsnDictDeprecated | ConnectionParameterKeywords) -> str: + return str(psycopg2.extensions.make_dsn(**connection_info)) + + +def mask_connection_info(info: str) -> str: masked_info = get_connection_info(info) password = masked_info.pop("password", None) connection_string = create_connection_string(masked_info) @@ -22,24 +86,29 @@ def mask_connection_info(info): return f"{connection_string}; {message}" -def get_connection_info_from_config_line(line): +def get_connection_info_from_config_line(line: str) -> ConnectionParameterKeywords: _, value = line.split("=", 1) value = value.strip()[1:-1].replace("''", "'") return get_connection_info(value) -def get_connection_info(info): - """turn a connection info object into a dict or return it if it was a - dict already. supports both the traditional libpq format and the new - url format""" +def get_connection_info( + info: str | DsnDict | DsnDictDeprecated | ConnectionParameterKeywords, +) -> ConnectionParameterKeywords: + """Get a normalized connection info dict from a connection string or a dict. + + Supports both the traditional libpq format and the new url format. + """ if isinstance(info, dict): - return info.copy() + # Potentially, we might clean deprecated DSN dicts: `database` -> `dbname`. + # Also, psycopg2 will validate the keys and values. + return parse_connection_string_libpq(create_connection_string(info)) if info.startswith("postgres://") or info.startswith("postgresql://"): return parse_connection_string_url(info) return parse_connection_string_libpq(info) -def parse_connection_string_url(url): +def parse_connection_string_url(url: str) -> ConnectionParameterKeywords: # drop scheme from the url as some versions of urlparse don't handle # query and path properly for urls with a non-http scheme schemeless_url = url.split(":", 1)[1] @@ -57,12 +126,15 @@ def parse_connection_string_url(url): fields["dbname"] = p.path[1:] for k, v in parse_qs(p.query).items(): fields[k] = v[-1] - return fields + return cast(ConnectionParameterKeywords, fields) + +def parse_connection_string_libpq(connection_string: str) -> ConnectionParameterKeywords: + """Parse a postgresql connection string. -def parse_connection_string_libpq(connection_string): - """parse a postgresql connection string as defined in - http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING""" + See: + http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING + """ fields = {} while True: connection_string = connection_string.strip() @@ -92,5 +164,8 @@ def parse_connection_string_libpq(connection_string): value, connection_string = res else: value, connection_string = rem, "" + # This one is case-insensitive. To continue benefiting from mypy, we make it lowercase. + if key == "replication": + value = value.lower() fields[key] = value - return fields + return cast(ConnectionParameterKeywords, fields) diff --git a/pyproject.toml b/pyproject.toml index 16ce87c..76ed844 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,6 @@ exclude = [ 'pglookout/current_master.py', 'pglookout/logutil.py', 'pglookout/pglookout.py', - 'pglookout/pgutil.py', 'pglookout/statsd.py', 'pglookout/version.py', 'pglookout/webserver.py', @@ -45,7 +44,6 @@ exclude = [ 'test/test_cluster_monitor.py', 'test/test_common.py', 'test/test_lookout.py', - 'test/test_pgutil.py', 'test/test_webserver.py', # Other. 'setup.py', diff --git a/test/test_pgutil.py b/test/test_pgutil.py index d70e6f5..7dc969f 100644 --- a/test/test_pgutil.py +++ b/test/test_pgutil.py @@ -6,14 +6,14 @@ See LICENSE for details """ -from pglookout.pgutil import create_connection_string, get_connection_info, mask_connection_info +from pglookout.pgutil import ConnectionParameterKeywords, create_connection_string, get_connection_info, mask_connection_info from pytest import raises -def test_connection_info(): +def test_connection_info() -> None: url = "postgres://hannu:secret@dbhost.local:5555/abc?replication=true&sslmode=foobar&sslmode=require" cs = "host=dbhost.local user='hannu' dbname='abc'\nreplication=true password=secret sslmode=require port=5555" - ci = { + ci: ConnectionParameterKeywords = { "host": "dbhost.local", "port": "5555", "user": "hannu", @@ -39,7 +39,7 @@ def test_connection_info(): get_connection_info("foo=bar bar='x") -def test_mask_connection_info(): +def test_mask_connection_info() -> None: url = "postgres://michael:secret@dbhost.local:5555/abc?replication=true&sslmode=foobar&sslmode=require" cs = "host=dbhost.local user='michael' dbname='abc'\nreplication=true password=secret sslmode=require port=5555" ci = get_connection_info(cs)