Skip to content

Commit

Permalink
mypy: pglookout/pgutil.py [BF-1560]
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Giffard committed Mar 16, 2023
1 parent a4f67c2 commit 75b5715
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 21 deletions.
105 changes: 90 additions & 15 deletions pglookout/pgutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,110 @@
Copyright (c) 2015 Ohmu Ltd
See LICENSE for details
"""
from __future__ import annotations

from typing import cast, Literal, TypedDict
from urllib.parse import parse_qs, urlparse # pylint: disable=no-name-in-module, import-error

import psycopg2.extensions


def create_connection_string(connection_info):
return psycopg2.extensions.make_dsn(**connection_info)
class DsnDictBase(TypedDict, total=False):
user: str
password: str
host: str
port: str | int


class DsnDict(DsnDictBase, total=False):
dbname: str


class DsnDictDeprecated(DsnDictBase, total=False):
database: str


class ConnectionParameterKeywords(TypedDict, total=False):
"""Parameter Keywords for Connection.
See:
https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
"""

def mask_connection_info(info):
host: str
hostaddr: str
port: str
dbname: str
user: str
password: str
passfile: str
channel_binding: Literal["require", "prefer", "disable"]
connect_timeout: str
client_encoding: str
options: str
application_name: str
fallback_application_name: str
keepalives: Literal["0", "1"]
keepalives_idle: str
keepalives_interval: str
keepalives_count: str
tcp_user_timeout: str
replication: Literal["true", "on", "yes", "1", "database", "false", "off", "no", "0"]
gssencmode: Literal["disable", "prefer", "require"]
sslmode: Literal["disable", "allow", "prefer", "require", "verify-ca", "verify-full"]
requiressl: Literal["0", "1"]
sslcompression: Literal["0", "1"]
sslcert: str
sslkey: str
sslpassword: str
sslrootcert: str
sslcrl: str
sslcrldir: str
sslsni: Literal["0", "1"]
requirepeer: str
ssl_min_protocol_version: Literal["TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"]
ssl_max_protocol_version: Literal["TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"]
krbsrvname: str
gsslib: str
service: str
target_session_attrs: Literal["any", "read-write", "read-only", "primary", "standby", "prefer-standby"]


def create_connection_string(connection_info: DsnDict | DsnDictDeprecated | ConnectionParameterKeywords) -> str:
return str(psycopg2.extensions.make_dsn(**connection_info))


def mask_connection_info(info: str) -> str:
masked_info = get_connection_info(info)
password = masked_info.pop("password", None)
connection_string = create_connection_string(masked_info)
message = "no password" if password is None else "hidden password"
return f"{connection_string}; {message}"


def get_connection_info_from_config_line(line):
def get_connection_info_from_config_line(line: str) -> ConnectionParameterKeywords:
_, value = line.split("=", 1)
value = value.strip()[1:-1].replace("''", "'")
return get_connection_info(value)


def get_connection_info(info):
"""turn a connection info object into a dict or return it if it was a
dict already. supports both the traditional libpq format and the new
url format"""
def get_connection_info(
info: str | DsnDict | DsnDictDeprecated | ConnectionParameterKeywords,
) -> ConnectionParameterKeywords:
"""Get a normalized connection info dict from a connection string or a dict.
Supports both the traditional libpq format and the new url format.
"""
if isinstance(info, dict):
return info.copy()
# Potentially, we might clean deprecated DSN dicts: `database` -> `dbname`.
# Also, psycopg2 will validate the keys and values.
return parse_connection_string_libpq(create_connection_string(info))
if info.startswith("postgres://") or info.startswith("postgresql://"):
return parse_connection_string_url(info)
return parse_connection_string_libpq(info)


def parse_connection_string_url(url):
def parse_connection_string_url(url: str) -> ConnectionParameterKeywords:
# drop scheme from the url as some versions of urlparse don't handle
# query and path properly for urls with a non-http scheme
schemeless_url = url.split(":", 1)[1]
Expand All @@ -57,12 +126,15 @@ def parse_connection_string_url(url):
fields["dbname"] = p.path[1:]
for k, v in parse_qs(p.query).items():
fields[k] = v[-1]
return fields
return cast(ConnectionParameterKeywords, fields)


def parse_connection_string_libpq(connection_string: str) -> ConnectionParameterKeywords:
"""Parse a postgresql connection string.
def parse_connection_string_libpq(connection_string):
"""parse a postgresql connection string as defined in
http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING"""
See:
http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
"""
fields = {}
while True:
connection_string = connection_string.strip()
Expand Down Expand Up @@ -92,5 +164,8 @@ def parse_connection_string_libpq(connection_string):
value, connection_string = res
else:
value, connection_string = rem, ""
# This one is case-insensitive. To continue benefiting from mypy, we make it lowercase.
if key == "replication":
value = value.lower()
fields[key] = value
return fields
return cast(ConnectionParameterKeywords, fields)
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ exclude = [
'pglookout/current_master.py',
'pglookout/logutil.py',
'pglookout/pglookout.py',
'pglookout/pgutil.py',
'pglookout/statsd.py',
'pglookout/version.py',
'pglookout/webserver.py',
Expand All @@ -45,7 +44,6 @@ exclude = [
'test/test_cluster_monitor.py',
'test/test_common.py',
'test/test_lookout.py',
'test/test_pgutil.py',
'test/test_webserver.py',
# Other.
'setup.py',
Expand Down
8 changes: 4 additions & 4 deletions test/test_pgutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
See LICENSE for details
"""

from pglookout.pgutil import create_connection_string, get_connection_info, mask_connection_info
from pglookout.pgutil import ConnectionParameterKeywords, create_connection_string, get_connection_info, mask_connection_info
from pytest import raises


def test_connection_info():
def test_connection_info() -> None:
url = "postgres://hannu:[email protected]:5555/abc?replication=true&sslmode=foobar&sslmode=require"
cs = "host=dbhost.local user='hannu' dbname='abc'\nreplication=true password=secret sslmode=require port=5555"
ci = {
ci: ConnectionParameterKeywords = {
"host": "dbhost.local",
"port": "5555",
"user": "hannu",
Expand All @@ -39,7 +39,7 @@ def test_connection_info():
get_connection_info("foo=bar bar='x")


def test_mask_connection_info():
def test_mask_connection_info() -> None:
url = "postgres://michael:[email protected]:5555/abc?replication=true&sslmode=foobar&sslmode=require"
cs = "host=dbhost.local user='michael' dbname='abc'\nreplication=true password=secret sslmode=require port=5555"
ci = get_connection_info(cs)
Expand Down

0 comments on commit 75b5715

Please sign in to comment.