Skip to content

Commit

Permalink
mypy: pglookout/pgutil.py [BF-1560]
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Giffard committed Mar 16, 2023
1 parent 75c5aaf commit f486036
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ long_ver = $(shell git describe --long 2>/dev/null || echo $(short_ver)-0-unknow
generated = pglookout/version.py

# Only include files that have been typed.
typed = pglookout/__main__.py
typed = pglookout/pgutil.py test/test_pgutil.py

# Flake8 ignores:
# E722: https://www.flake8rules.com/rules/E722.html Do not use bare except, specify exception instead
Expand Down
90 changes: 81 additions & 9 deletions pglookout/pgutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,110 @@
Copyright (c) 2015 Ohmu Ltd
See LICENSE for details
"""
from __future__ import annotations

from typing import cast
from typing_extensions import Literal, TypedDict
from urllib.parse import parse_qs, urlparse # pylint: disable=no-name-in-module, import-error

import psycopg2.extensions


def create_connection_string(connection_info):
class DsnDictBase(TypedDict):
user: str
password: str
host: str
port: str | int


class DsnDict(DsnDictBase):
dbname: str


class DsnDictDeprecated(DsnDictBase):
database: str


class ConnectionParameterKeywords(TypedDict, total=False):
"""Parameter Keywords for Connection.
See:
https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
"""

host: str
hostaddr: str
port: str
dbname: str
user: str
password: str
passfile: str
channel_binding: Literal["require", "prefer", "disable"]
connect_timeout: str
client_encoding: str
options: str
application_name: str
fallback_application_name: str
keepalives: Literal["0", "1"]
keepalives_idle: str
keepalives_interval: str
keepalives_count: str
tcp_user_timeout: str
replication: Literal["true", "on", "yes", "1", "database", "false", "off", "no", "0"]
gssencmode: Literal["disable", "prefer", "require"]
sslmode: Literal["disable", "allow", "prefer", "require", "verify-ca", "verify-full"]
requiressl: Literal["0", "1"]
sslcompression: Literal["0", "1"]
sslcert: str
sslkey: str
sslpassword: str
sslrootcert: str
sslcrl: str
sslcrldir: str
sslsni: Literal["0", "1"]
requirepeer: str
ssl_min_protocol_version: Literal["TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"]
ssl_max_protocol_version: Literal["TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"]
krbsrvname: str
gsslib: str
service: str
target_session_attrs: Literal["any", "read-write", "read-only", "primary", "standby", "prefer-standby"]


def create_connection_string(connection_info: DsnDict | DsnDictDeprecated | ConnectionParameterKeywords) -> str:
return psycopg2.extensions.make_dsn(**connection_info)


def mask_connection_info(info):
def mask_connection_info(info: str) -> str:
masked_info = get_connection_info(info)
password = masked_info.pop("password", None)
connection_string = create_connection_string(masked_info)
message = "no password" if password is None else "hidden password"
return f"{connection_string}; {message}"


def get_connection_info_from_config_line(line):
def get_connection_info_from_config_line(line: str) -> ConnectionParameterKeywords:
_, value = line.split("=", 1)
value = value.strip()[1:-1].replace("''", "'")
return get_connection_info(value)


def get_connection_info(info):
def get_connection_info(
info: str | DsnDict | DsnDictDeprecated | ConnectionParameterKeywords,
) -> ConnectionParameterKeywords:
"""turn a connection info object into a dict or return it if it was a
dict already. supports both the traditional libpq format and the new
url format"""
if isinstance(info, dict):
return info.copy()
# Potentially, we might clean deprecated DSN dicts: `database` -> `dbname`.
# Also, psycopg2 will validate the keys and values.
return parse_connection_string_libpq(create_connection_string(info))
if info.startswith("postgres://") or info.startswith("postgresql://"):
return parse_connection_string_url(info)
return parse_connection_string_libpq(info)


def parse_connection_string_url(url):
def parse_connection_string_url(url: str) -> ConnectionParameterKeywords:
# drop scheme from the url as some versions of urlparse don't handle
# query and path properly for urls with a non-http scheme
schemeless_url = url.split(":", 1)[1]
Expand All @@ -57,10 +126,10 @@ def parse_connection_string_url(url):
fields["dbname"] = p.path[1:]
for k, v in parse_qs(p.query).items():
fields[k] = v[-1]
return fields
return cast(ConnectionParameterKeywords, fields)


def parse_connection_string_libpq(connection_string):
def parse_connection_string_libpq(connection_string: str) -> ConnectionParameterKeywords:
"""parse a postgresql connection string as defined in
http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING"""
fields = {}
Expand Down Expand Up @@ -92,5 +161,8 @@ def parse_connection_string_libpq(connection_string):
value, connection_string = res
else:
value, connection_string = rem, ""
# This one is case-insensitive. To continue benefiting from mypy, we make it lowercase.
if key == "replication":
value = value.lower()
fields[key] = value
return fields
return cast(ConnectionParameterKeywords, fields)
8 changes: 4 additions & 4 deletions test/test_pgutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
See LICENSE for details
"""

from pglookout.pgutil import create_connection_string, get_connection_info, mask_connection_info
from pglookout.pgutil import ConnectionParameterKeywords, create_connection_string, get_connection_info, mask_connection_info
from pytest import raises


def test_connection_info():
def test_connection_info() -> None:
url = "postgres://hannu:[email protected]:5555/abc?replication=true&sslmode=foobar&sslmode=require"
cs = "host=dbhost.local user='hannu' dbname='abc'\nreplication=true password=secret sslmode=require port=5555"
ci = {
ci: ConnectionParameterKeywords = {
"host": "dbhost.local",
"port": "5555",
"user": "hannu",
Expand All @@ -39,7 +39,7 @@ def test_connection_info():
get_connection_info("foo=bar bar='x")


def test_mask_connection_info():
def test_mask_connection_info() -> None:
url = "postgres://michael:[email protected]:5555/abc?replication=true&sslmode=foobar&sslmode=require"
cs = "host=dbhost.local user='michael' dbname='abc'\nreplication=true password=secret sslmode=require port=5555"
ci = get_connection_info(cs)
Expand Down

0 comments on commit f486036

Please sign in to comment.