Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mypy: pglookout/pgutil.py [BF-1560] #98

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 90 additions & 15 deletions pglookout/pgutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,110 @@
Copyright (c) 2015 Ohmu Ltd
See LICENSE for details
"""
from __future__ import annotations

from typing import cast, Literal, TypedDict
from urllib.parse import parse_qs, urlparse # pylint: disable=no-name-in-module, import-error

import psycopg2.extensions


def create_connection_string(connection_info):
return psycopg2.extensions.make_dsn(**connection_info)
class DsnDictBase(TypedDict, total=False):
user: str
password: str
host: str
port: str | int


class DsnDict(DsnDictBase, total=False):
dbname: str


class DsnDictDeprecated(DsnDictBase, total=False):
database: str


class ConnectionParameterKeywords(TypedDict, total=False):
"""Parameter Keywords for Connection.

See:
https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
"""

def mask_connection_info(info):
host: str
hostaddr: str
port: str
dbname: str
user: str
password: str
passfile: str
channel_binding: Literal["require", "prefer", "disable"]
connect_timeout: str
client_encoding: str
options: str
application_name: str
fallback_application_name: str
keepalives: Literal["0", "1"]
keepalives_idle: str
keepalives_interval: str
keepalives_count: str
tcp_user_timeout: str
replication: Literal["true", "on", "yes", "1", "database", "false", "off", "no", "0"]
gssencmode: Literal["disable", "prefer", "require"]
sslmode: Literal["disable", "allow", "prefer", "require", "verify-ca", "verify-full"]
requiressl: Literal["0", "1"]
sslcompression: Literal["0", "1"]
sslcert: str
sslkey: str
sslpassword: str
sslrootcert: str
sslcrl: str
sslcrldir: str
sslsni: Literal["0", "1"]
requirepeer: str
ssl_min_protocol_version: Literal["TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"]
ssl_max_protocol_version: Literal["TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"]
krbsrvname: str
gsslib: str
service: str
target_session_attrs: Literal["any", "read-write", "read-only", "primary", "standby", "prefer-standby"]


def create_connection_string(connection_info: DsnDict | DsnDictDeprecated | ConnectionParameterKeywords) -> str:
return str(psycopg2.extensions.make_dsn(**connection_info))


def mask_connection_info(info: str) -> str:
masked_info = get_connection_info(info)
password = masked_info.pop("password", None)
connection_string = create_connection_string(masked_info)
message = "no password" if password is None else "hidden password"
return f"{connection_string}; {message}"


def get_connection_info_from_config_line(line):
def get_connection_info_from_config_line(line: str) -> ConnectionParameterKeywords:
_, value = line.split("=", 1)
value = value.strip()[1:-1].replace("''", "'")
return get_connection_info(value)


def get_connection_info(info):
"""turn a connection info object into a dict or return it if it was a
dict already. supports both the traditional libpq format and the new
url format"""
def get_connection_info(
info: str | DsnDict | DsnDictDeprecated | ConnectionParameterKeywords,
) -> ConnectionParameterKeywords:
"""Get a normalized connection info dict from a connection string or a dict.

Supports both the traditional libpq format and the new url format.
"""
if isinstance(info, dict):
return info.copy()
# Potentially, we might clean deprecated DSN dicts: `database` -> `dbname`.
# Also, psycopg2 will validate the keys and values.
return parse_connection_string_libpq(create_connection_string(info))
if info.startswith("postgres://") or info.startswith("postgresql://"):
return parse_connection_string_url(info)
return parse_connection_string_libpq(info)


def parse_connection_string_url(url):
def parse_connection_string_url(url: str) -> ConnectionParameterKeywords:
# drop scheme from the url as some versions of urlparse don't handle
# query and path properly for urls with a non-http scheme
schemeless_url = url.split(":", 1)[1]
Expand All @@ -57,12 +126,15 @@ def parse_connection_string_url(url):
fields["dbname"] = p.path[1:]
for k, v in parse_qs(p.query).items():
fields[k] = v[-1]
return fields
return cast(ConnectionParameterKeywords, fields)


def parse_connection_string_libpq(connection_string: str) -> ConnectionParameterKeywords:
"""Parse a postgresql connection string.

def parse_connection_string_libpq(connection_string):
"""parse a postgresql connection string as defined in
http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING"""
See:
http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
"""
fields = {}
while True:
connection_string = connection_string.strip()
Expand Down Expand Up @@ -92,5 +164,8 @@ def parse_connection_string_libpq(connection_string):
value, connection_string = res
else:
value, connection_string = rem, ""
# This one is case-insensitive. To continue benefiting from mypy, we make it lowercase.
if key == "replication":
value = value.lower()
fields[key] = value
return fields
return cast(ConnectionParameterKeywords, fields)
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ exclude = [
'pglookout/current_master.py',
'pglookout/logutil.py',
'pglookout/pglookout.py',
'pglookout/pgutil.py',
'pglookout/statsd.py',
'pglookout/version.py',
'pglookout/webserver.py',
Expand All @@ -45,7 +44,6 @@ exclude = [
'test/test_cluster_monitor.py',
'test/test_common.py',
'test/test_lookout.py',
'test/test_pgutil.py',
'test/test_webserver.py',
# Other.
'setup.py',
Expand Down
8 changes: 4 additions & 4 deletions test/test_pgutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
See LICENSE for details
"""

from pglookout.pgutil import create_connection_string, get_connection_info, mask_connection_info
from pglookout.pgutil import ConnectionParameterKeywords, create_connection_string, get_connection_info, mask_connection_info
from pytest import raises


def test_connection_info():
def test_connection_info() -> None:
url = "postgres://hannu:[email protected]:5555/abc?replication=true&sslmode=foobar&sslmode=require"
cs = "host=dbhost.local user='hannu' dbname='abc'\nreplication=true password=secret sslmode=require port=5555"
ci = {
ci: ConnectionParameterKeywords = {
"host": "dbhost.local",
"port": "5555",
"user": "hannu",
Expand All @@ -39,7 +39,7 @@ def test_connection_info():
get_connection_info("foo=bar bar='x")


def test_mask_connection_info():
def test_mask_connection_info() -> None:
url = "postgres://michael:[email protected]:5555/abc?replication=true&sslmode=foobar&sslmode=require"
cs = "host=dbhost.local user='michael' dbname='abc'\nreplication=true password=secret sslmode=require port=5555"
ci = get_connection_info(cs)
Expand Down