diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index c65f68a6..5dd486d9 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -8,7 +8,7 @@ import asyncio import collections -from collections.abc import Callable +from collections.abc import Callable, Sequence import enum import functools import getpass @@ -41,10 +41,10 @@ class SSLMode(enum.IntEnum): verify_full = 5 @classmethod - def parse(cls, sslmode): - if isinstance(sslmode, cls): - return sslmode - return getattr(cls, sslmode.replace('-', '_')) + def parse(cls, sslmode: typing.Union[str, SSLMode]) -> SSLMode: + if isinstance(sslmode, str): + return getattr(cls, sslmode.replace('-', '_')) + return sslmode class SSLNegotiation(compat.StrEnum): @@ -52,20 +52,17 @@ class SSLNegotiation(compat.StrEnum): direct = "direct" -_ConnectionParameters = collections.namedtuple( - 'ConnectionParameters', - [ - 'user', - 'password', - 'database', - 'ssl', - 'sslmode', - 'ssl_negotiation', - 'server_settings', - 'target_session_attrs', - 'krbsrvname', - 'gsslib', - ]) +class _ConnectionParameters(typing.NamedTuple): + user: str + password: typing.Optional[str] + database: str + ssl: typing.Union[ssl_module.SSLContext, bool, str, SSLMode, None] + sslmode: SSLMode + ssl_negotiation: SSLNegotiation + server_settings: typing.Optional[typing.Dict[str, str]] + target_session_attrs: SessionAttribute + krbsrvname: typing.Optional[str] + gsslib: str _ClientConfiguration = collections.namedtuple( @@ -131,11 +128,13 @@ def _read_password_file(passfile: pathlib.Path) \ def _read_password_from_pgpass( - *, passfile: typing.Optional[pathlib.Path], - hosts: typing.List[str], - ports: typing.List[int], - database: str, - user: str): + *, + passfile: pathlib.Path, + hosts: Sequence[str], + ports: typing.List[int], + database: str, + user: str +) -> typing.Optional[str]: """Parse the pgpass file and return the matching password. :return: @@ -167,7 +166,9 @@ def _read_password_from_pgpass( return None -def _validate_port_spec(hosts, port): +def _validate_port_spec( + hosts: Sequence[object], port: typing.Union[int, typing.List[int]] +) -> typing.List[int]: if isinstance(port, list): # If there is a list of ports, its length must # match that of the host list. @@ -181,15 +182,20 @@ def _validate_port_spec(hosts, port): return port -def _parse_hostlist(hostlist, port, *, unquote=False): +def _parse_hostlist( + hostlist: str, + port: typing.Union[int, typing.List[int]], + *, + unquote: bool = False, +) -> typing.Tuple[typing.List[str], typing.List[int]]: if ',' in hostlist: # A comma-separated list of host addresses. hostspecs = hostlist.split(',') else: hostspecs = [hostlist] - hosts = [] - hostlist_ports = [] + hosts: typing.List[str] = [] + hostlist_ports: typing.List[int] = [] if not port: portspec = os.environ.get('PGPORT') @@ -267,10 +273,25 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: return (homedir / '.postgresql' / filename).resolve() -def _parse_connect_dsn_and_args(*, dsn, host, port, user, - password, passfile, database, ssl, - direct_tls, server_settings, - target_session_attrs, krbsrvname, gsslib): +def _parse_connect_dsn_and_args( + *, + dsn: str, + host: typing.Union[str, typing.List[str], typing.Tuple[str]], + port: typing.Union[int, typing.List[int]], + user: typing.Optional[str], + password: typing.Optional[str], + passfile: typing.Union[str, pathlib.Path, None], + database: typing.Optional[str], + ssl: typing.Union[bool, None, str, SSLMode], + direct_tls: typing.Optional[bool], + server_settings: typing.Optional[typing.Dict[str, str]], + target_session_attrs: typing.Optional[str], + krbsrvname: typing.Optional[str], + gsslib: typing.Optional[str], +) -> typing.Tuple[ + typing.List[typing.Union[str, typing.Tuple[str, int]]], + _ConnectionParameters, +]: # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. auth_hosts = None @@ -316,10 +337,12 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, password = urllib.parse.unquote(dsn_password) if parsed.query: - query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) - for key, val in query.items(): - if isinstance(val, list): - query[key] = val[-1] + query = { + key: val[-1] + for key, val in urllib.parse.parse_qs( + parsed.query, strict_parsing=True + ).items() + } if 'port' in query: val = query.pop('port') @@ -491,7 +514,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, database=database, user=user, passfile=passfile) - addrs = [] + addrs: typing.List[typing.Union[str, typing.Tuple[str, int]]] = [] have_tcp_addrs = False for h, p in zip(host, port): if h.startswith('/'):