diff --git a/pglookout/cluster_monitor.py b/pglookout/cluster_monitor.py
index 287a79a..a642972 100644
--- a/pglookout/cluster_monitor.py
+++ b/pglookout/cluster_monitor.py
@@ -7,17 +7,23 @@
 This file is under the Apache License, Version 2.0.
 See the file `LICENSE` for details.
 """
+from __future__ import annotations
 
-from . import logutil
-from .common import get_iso_timestamp, parse_iso_datetime
-from .pgutil import mask_connection_info
 from concurrent.futures import as_completed, ThreadPoolExecutor
 from dataclasses import asdict, dataclass
 from email.utils import parsedate
-from psycopg2.extras import RealDictCursor
-from queue import Empty
+from logging.handlers import SysLogHandler
+from pglookout import logutil
+from pglookout.common import get_iso_timestamp, parse_iso_datetime
+from pglookout.common_types import MemberState, ObserverState, ReplicationSlotAsDict
+from pglookout.config import Config
+from pglookout.pgutil import mask_connection_info
+from pglookout.statsd import StatsClient
+from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE
+from psycopg2.extras import RealDictCursor, RealDictRow
+from queue import Empty, Queue
 from threading import Thread
-from typing import List
+from typing import Callable, cast, Final
 
 import datetime
 import errno
@@ -27,6 +33,9 @@
 import select
 import time
 
+# https://www.psycopg.org/docs/connection.html#connection.server_version
+PG_VERSION_10: Final[int] = 10_00_00  # 10.0.0
+
 
 class PglookoutTimeout(Exception):
     pass
@@ -44,17 +53,17 @@ class ReplicationSlot:
     state_data: str
 
 
-def wait_select(conn, timeout=5.0):
+def wait_select(conn: psycopg2.connection, timeout: float = 5.0) -> None:
     end_time = time.monotonic() + timeout
     while time.monotonic() < end_time:
         time_left = end_time - time.monotonic()
         state = conn.poll()
         try:
-            if state == psycopg2.extensions.POLL_OK:
+            if state == POLL_OK:
                 return
-            if state == psycopg2.extensions.POLL_READ:
+            if state == POLL_READ:
                 select.select([conn.fileno()], [], [], min(timeout, time_left))
-            elif state == psycopg2.extensions.POLL_WRITE:
+            elif state == POLL_WRITE:
                 select.select([], [conn.fileno()], [], min(timeout, time_left))
             else:
                 raise psycopg2.OperationalError(f"bad state from poll: {state}")
@@ -67,14 +76,14 @@ class ClusterMonitor(Thread):
     def __init__(
         self,
-        config,
-        cluster_state,
-        observer_state,
-        create_alert_file,
-        cluster_monitor_check_queue,
-        failover_decision_queue,
-        is_replication_lag_over_warning_limit,
-        stats,
+        config: Config,
+        cluster_state: dict[str, MemberState],
+        observer_state: dict[str, ObserverState],
+        create_alert_file: Callable[[str], None],
+        cluster_monitor_check_queue: Queue[str],
+        failover_decision_queue: Queue[str],
+        is_replication_lag_over_warning_limit: Callable[[], bool],
+        stats: StatsClient,
     ):
         """Thread which collects cluster state.
@@ -83,36 +92,42 @@ def __init__(
         in the cluster_state/observer_state dictionaries, which are shared
         with the main thread.
         """
         Thread.__init__(self)
-        self.log = logging.getLogger("ClusterMonitor")
-        self.stats = stats
-        self.running = True
-        self.cluster_state = cluster_state
-        self.observer_state = observer_state
-        self.config = config
-        self.create_alert_file = create_alert_file
-        self.db_conns = {}
-        self.cluster_monitor_check_queue = cluster_monitor_check_queue
-        self.failover_decision_queue = failover_decision_queue
-        self.is_replication_lag_over_warning_limit = is_replication_lag_over_warning_limit
-        self.session = requests.Session()
+        self.log: logging.Logger = logging.getLogger("ClusterMonitor")
+        self.stats: StatsClient = stats
+        self.running: bool = True
+        self.cluster_state: dict[str, MemberState] = cluster_state
+        self.observer_state: dict[str, ObserverState] = observer_state
+        self.config: Config = config
+        self.create_alert_file: Callable[[str], None] = create_alert_file
+        self.db_conns: dict[str, psycopg2.connection | None] = {}
+        self.cluster_monitor_check_queue: Queue[str] = cluster_monitor_check_queue
+        self.failover_decision_queue: Queue[str] = failover_decision_queue
+        self.is_replication_lag_over_warning_limit: Callable[[], bool] = is_replication_lag_over_warning_limit
+        self.session: requests.Session = requests.Session()
         if self.config.get("syslog"):
-            self.syslog_handler = logutil.set_syslog_handler(
+            # Function `set_syslog_handler` already adds the handler to the provided logger.
+            # We just keep a reference to it here.
+            self.syslog_handler: SysLogHandler = logutil.set_syslog_handler(
                 address=self.config.get("syslog_address", "/dev/log"),
                 facility=self.config.get("syslog_facility", "local2"),
                 logger=self.log,
             )
-        self.last_monitoring_success_time = None
+        self.last_monitoring_success_time: float | None = None
         self.log.debug("Initialized ClusterMonitor with: %r", cluster_state)
 
-    def _connect_to_db(self, instance, dsn):
+    def _connect_to_db(self, instance: str, dsn: str | None) -> psycopg2.connection | None:
         conn = self.db_conns.get(instance)
+
         if conn:
             return conn
+
         if not dsn:
             self.log.warning("Can't connect to %s, dsn is %r", instance, dsn)
             return None
+
         masked_connection_info = mask_connection_info(dsn)
         inst_info_str = f"{instance!r} ({masked_connection_info})"
+
         try:
             self.log.info("Connecting to %s", inst_info_str)
             conn = psycopg2.connect(dsn=dsn, async_=True)
@@ -133,19 +148,29 @@
             self.log.exception("Failed to connect to %s (%s)", instance, inst_info_str)
             self.stats.unexpected_exception(ex, where="_connect_to_db")
             conn = None
+
         self.db_conns[instance] = conn
         return conn
 
-    def _fetch_observer_state(self, instance, uri):
-        result = {"fetch_time": get_iso_timestamp(), "connection": True}
+    def _fetch_observer_state(self, instance: str, uri: str) -> ObserverState | None:
+        now_iso = get_iso_timestamp()
+        result = {"fetch_time": now_iso, "connection": True}
         fetch_uri = uri + "/state.json"
+
         try:
             response = self.session.get(fetch_uri, timeout=5.0)
 
             # check time difference for large skews
-            remote_server_time = parsedate(response.headers["date"])
-            remote_server_time = datetime.datetime.fromtimestamp(time.mktime(remote_server_time))
-            time_diff = parse_iso_datetime(result["fetch_time"]) - remote_server_time
+            remote_server_ptime = parsedate(response.headers["date"])
+            if remote_server_ptime is None:
+                self.log.error(
+                    "Failed to parse date from observer node %r, response: %r, ignoring response",
+                    instance,
+                    response.json(),
+                )
+                return None
+            remote_server_time = datetime.datetime.fromtimestamp(time.mktime(remote_server_ptime))
+            time_diff = parse_iso_datetime(now_iso) - remote_server_time
             if time_diff > datetime.timedelta(seconds=5):
                 self.log.error(
                     "Time difference between us and observer node %r is %r, response: %r, ignoring response",
@@ -168,16 +193,19 @@
             self.log.exception("Problem in fetching state from observer: %r, %r", instance, fetch_uri)
             self.stats.unexpected_exception(ex, where="_fetch_observer_state")
             result["connection"] = False
+
         return result
 
-    def fetch_observer_state(self, instance, uri):
+    def fetch_observer_state(self, instance: str, uri: str) -> None:
         start_time = time.monotonic()
         result = self._fetch_observer_state(instance, uri)
+
         if result:
             if instance in self.observer_state:
                 self.observer_state[instance].update(result)
             else:
                 self.observer_state[instance] = result
+
         self.log.debug(
             "Observer: %r state was: %r, took: %.4fs to fetch",
             instance,
@@ -185,18 +213,20 @@
             result,
             time.monotonic() - start_time,
         )
 
-    def connect_to_cluster_nodes_and_cleanup_old_nodes(self):
+    def connect_to_cluster_nodes_and_cleanup_old_nodes(self) -> None:
         leftover_conns = set(self.db_conns) - set(self.config.get("remote_conns", {}))
+
         for leftover_instance in leftover_conns:
             self.log.debug("Removing leftover state for: %r", leftover_instance)
             self.db_conns.pop(leftover_instance)
-            self.cluster_state.pop(leftover_instance, "")
-            self.observer_state.pop(leftover_instance, "")
+            self.cluster_state.pop(leftover_instance, None)
+            self.observer_state.pop(leftover_instance, None)
+
         # Making sure we have a connection to all currently configured db hosts
         for instance, connect_string in self.config.get("remote_conns", {}).items():
             self._connect_to_db(instance, dsn=connect_string)
 
-    def _fetch_replication_slot_info(self, instance: str, cursor: RealDictCursor) -> List[ReplicationSlot]:
+    def _fetch_replication_slot_info(self, instance: str, cursor: RealDictCursor) -> list[ReplicationSlot]:
         """Fetch logical replication slot definitions"""
         self.log.debug("reading replication slot state from %r", instance)
@@ -217,59 +247,49 @@ def _fetch_replication_slot_info(self, instance: str, cursor: RealDictCursor) ->
             """
         )
         wait_select(cursor.connection)
-        replication_slots = [ReplicationSlot(**slot) for slot in cursor.fetchall()]
+        replication_slots = [
+            ReplicationSlot(**cast(RealDictRow, slot)) for slot in cursor.fetchall()  # type: ignore[redundant-cast]
+        ]
         self.log.debug("found %d replication slot(s)", len(replication_slots))
         return replication_slots
 
-    def _query_cluster_member_state(self, instance, db_conn):
+    def _query_cluster_member_state(self, instance: str, db_conn: psycopg2.connection | None) -> MemberState:
         """Query a single cluster member for its state"""
-        f_result = None
-        result = {"fetch_time": get_iso_timestamp(), "connection": False}
+        f_result: MemberState | None = None
+        result: MemberState = {"fetch_time": get_iso_timestamp(), "connection": False}
+
         if not db_conn:
-            db_conn = self._connect_to_db(instance, self.config["remote_conns"].get(instance))
+            dsn: str | None = self.config["remote_conns"].get(instance)
+            db_conn = self._connect_to_db(instance, dsn)
             if not db_conn:
                 return result
+
         phase = "querying status from"
         try:
             self.log.debug("%s %r", phase, instance)
+
             c = db_conn.cursor(cursor_factory=RealDictCursor)
-            if db_conn.server_version >= 100000:
-                fields = [
-                    "now() AS db_time",
-                    "pg_is_in_recovery()",
-                    "pg_last_xact_replay_timestamp()",
-                    "pg_last_wal_receive_lsn() AS pg_last_xlog_receive_location",
-                    "pg_last_wal_replay_lsn() AS pg_last_xlog_replay_location",
-                ]
-            else:
-                fields = [
-                    "now() AS db_time",
-                    "pg_is_in_recovery()",
-                    "pg_last_xact_replay_timestamp()",
-                    "pg_last_xlog_receive_location()",
-                    "pg_last_xlog_replay_location()",
-                ]
-            joined_fields = ", ".join(fields)
-            c.execute(f"SELECT {joined_fields}")
+
+            c.execute(self._get_statement_query_status(db_conn.server_version))
             wait_select(c.connection)
-            maybe_standby_result = c.fetchone()
+            maybe_standby_result: MemberState = cast(MemberState, c.fetchone())
+
             if maybe_standby_result["pg_is_in_recovery"]:
                 f_result = maybe_standby_result
             else:
                 # First try reading current WAL LSN separately as txid_current may fail in some cases
                 phase = "getting master LSN position"
-                if db_conn.server_version >= 100000:
-                    wal_lsn_column = "pg_current_wal_lsn() AS pg_last_xlog_replay_location"
-                else:
-                    wal_lsn_column = "pg_current_xlog_location() AS pg_last_xlog_replay_location"
-                c.execute(f"SELECT {wal_lsn_column}")
+
+                c.execute(self._get_statement_query_master_lsn_position(db_conn.server_version))
                 wait_select(c.connection)
-                master_position = c.fetchone()
+                master_position: RealDictRow = cast(RealDictRow, c.fetchone())
                 maybe_standby_result["pg_last_xlog_replay_location"] = master_position["pg_last_xlog_replay_location"]
                 f_result = maybe_standby_result
 
-                if db_conn.server_version >= 100000:
-                    f_result["replication_slots"] = [asdict(slot) for slot in self._fetch_replication_slot_info(instance, c)]
+                if db_conn.server_version >= PG_VERSION_10:
+                    f_result["replication_slots"] = [
+                        cast(ReplicationSlotAsDict, asdict(slot)) for slot in self._fetch_replication_slot_info(instance, c)
+                    ]
 
                 # This is only run on masters to create txid traffic every db_poll_interval
                 phase = "updating transaction on"
- c.execute(f"SELECT txid_current(), {wal_lsn_column}") + c.execute(self._get_statement_query_updating_transaction(db_conn.server_version)) wait_select(c.connection) - master_result = c.fetchone() + master_result: RealDictRow = cast(RealDictRow, c.fetchone()) + f_result["pg_last_xlog_replay_location"] = master_result["pg_last_xlog_replay_location"] except ( PglookoutTimeout, @@ -296,12 +317,45 @@ def _query_cluster_member_state(self, instance, db_conn): return result @staticmethod - def _parse_status_query_result(result): + def _get_statement_query_status(server_version: int) -> str: + if server_version >= PG_VERSION_10: + return ( + "SELECT now() AS db_time, " + "pg_is_in_recovery(), " + "pg_last_xact_replay_timestamp(), " + "pg_last_wal_receive_lsn() AS pg_last_xlog_receive_location, " + "pg_last_wal_replay_lsn() AS pg_last_xlog_replay_location" + ) + return ( + "SELECT now() AS db_time, " + "pg_is_in_recovery(), " + "pg_last_xact_replay_timestamp(), " + "pg_last_xlog_receive_location(), " + "pg_last_xlog_replay_location()" + ) + + @staticmethod + def _get_statement_query_master_lsn_position(server_version: int) -> str: + if server_version >= PG_VERSION_10: + return "SELECT pg_current_wal_lsn() AS pg_last_xlog_replay_location" + return "SELECT pg_current_xlog_location() AS pg_last_xlog_replay_location" + + @staticmethod + def _get_statement_query_updating_transaction(server_version: int) -> str: + if server_version >= PG_VERSION_10: + return "SELECT txid_current(), pg_current_wal_lsn() AS pg_last_xlog_replay_location" + return "SELECT txid_current(), pg_current_xlog_location() AS pg_last_xlog_replay_location" + + # FIXME: Find a tighter input + return type + @staticmethod + def _parse_status_query_result(result: MemberState) -> MemberState: if not result: return {} + + db_time = cast(datetime.datetime, result["db_time"]) # abs is for catching time travel (as in going from the future to the past - if result["pg_last_xact_replay_timestamp"]: - replication_time_lag = abs(result["db_time"] - result["pg_last_xact_replay_timestamp"]) + if isinstance(result["pg_last_xact_replay_timestamp"], datetime.datetime): + replication_time_lag: datetime.timedelta = abs(db_time - result["pg_last_xact_replay_timestamp"]) result["replication_time_lag"] = replication_time_lag.total_seconds() result["pg_last_xact_replay_timestamp"] = get_iso_timestamp(result["pg_last_xact_replay_timestamp"]) @@ -317,10 +371,10 @@ def _parse_status_query_result(result): "replication_time_lag": None, # differentiate from actual lag=0.0 } ) - result.update({"db_time": get_iso_timestamp(result["db_time"]), "connection": True}) + result.update({"db_time": get_iso_timestamp(db_time), "connection": True}) return result - def update_cluster_member_state(self, instance, db_conn): + def update_cluster_member_state(self, instance: str, db_conn: psycopg2.connection | None) -> None: """Update the cluster state entry for a single cluster member""" start_time = time.monotonic() result = self._query_cluster_member_state(instance, db_conn) @@ -348,7 +402,7 @@ def update_cluster_member_state(self, instance, db_conn): else: self.cluster_state[instance]["min_replication_time_lag"] = min(min_lag, now_lag) - def main_monitoring_loop(self, requested_check=False): + def main_monitoring_loop(self, requested_check: bool = False) -> None: self.connect_to_cluster_nodes_and_cleanup_old_nodes() thread_count = len(self.db_conns) + len(self.config.get("observers", {})) futures = [] @@ -367,12 +421,14 @@ def main_monitoring_loop(self, requested_check=False): 
diff --git a/pglookout/config.py b/pglookout/config.py
index 97f8f69..761263d 100644
--- a/pglookout/config.py
+++ b/pglookout/config.py
@@ -34,7 +34,7 @@ class Config(TypedDict, total=False):
     pg_stop_command: str
     poll_observers_on_warning_only: bool
     primary_conninfo_template: str
-    remote_conns: dict[str, str]
+    remote_conns: dict[str, str]  # instance name -> dsn
     replication_catchup_timeout: float
     replication_state_check_interval: float
     statsd: Statsd
diff --git a/pyproject.toml b/pyproject.toml
index 8cae98b..e01c7e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,18 +31,20 @@ files = [
 exclude = [
     # Implementation.
     'pglookout/__main__.py',
-    'pglookout/cluster_monitor.py',
     'pglookout/pglookout.py',
     'pglookout/version.py',
     # Tests.
     'test/conftest.py',
-    'test/test_cluster_monitor.py',
     'test/test_lookout.py',
     # Other.
     'setup.py',
     'version.py',
 ]
 
+[[tool.mypy.overrides]]
+module = "test.test_cluster_monitor"
+# Ignore no-untyped-call because conftest can only be type-hinted at the end; remove this override then.
+disallow_untyped_calls = false
+
 [tool.pylint.'MESSAGES CONTROL']
 disable = [
diff --git a/requirements.dev.txt b/requirements.dev.txt
index dca3637..1dd3f89 100644
--- a/requirements.dev.txt
+++ b/requirements.dev.txt
@@ -2,7 +2,6 @@ black==22.10.0
 flake8
 flake8-pyproject==1.2.2
 isort==5.10.1
-mock
 mypy==1.1.1
 types-psycopg2==2.9.21.8
 types-requests==2.26.1
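Dropping the third-party `mock` requirement works because `unittest.mock` has shipped in the standard library since Python 3.3 with the same API; only the import in the test module changes, as the next diff shows. A minimal illustration of the replacement:

```python
import time
from unittest.mock import patch  # stdlib replacement for "from mock import patch"

with patch("time.monotonic", return_value=42.0):
    assert time.monotonic() == 42.0
```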
db.connection_string("testuser"), "test2db": db.connection_string("otheruser"), @@ -49,14 +55,14 @@ def test_main_loop(db): "observers": {"local": "URL"}, "poll_observers_on_warning_only": True, } - cluster_state = {} - observer_state = {} + cluster_state: dict[str, MemberState] = {} + observer_state: dict[str, ObserverState] = {} - def create_alert_file(arg): + def create_alert_file(arg: str) -> NoReturn: raise Exception(arg) - cluster_monitor_check_queue = Queue() - failover_decision_queue = Queue() + cluster_monitor_check_queue: Queue[str] = Queue() + failover_decision_queue: Queue[str] = Queue() cm = ClusterMonitor( config=config, @@ -103,7 +109,7 @@ def test_fetch_replication_slot_info(db: TestPG) -> None: if version.parse(db.pgver) < version.parse("10"): pytest.skip(f"unsupported pg version: {db.pgver}") - config = { + config: Config = { "remote_conns": { "test1db": db.connection_string("testuser"), "test2db": db.connection_string("otheruser"), @@ -111,14 +117,14 @@ def test_fetch_replication_slot_info(db: TestPG) -> None: "observers": {"local": "URL"}, "poll_observers_on_warning_only": True, } - cluster_state = {} - observer_state = {} + cluster_state: dict[str, MemberState] = {} + observer_state: dict[str, ObserverState] = {} - def create_alert_file(arg): + def create_alert_file(arg: str) -> NoReturn: raise Exception(arg) - cluster_monitor_check_queue = Queue() - failover_decision_queue = Queue() + cluster_monitor_check_queue: Queue[str] = Queue() + failover_decision_queue: Queue[str] = Queue() cm = ClusterMonitor( config=config,