From 20fd3e12bc30fee543ca7e9a331ea6c6cb9e5863 Mon Sep 17 00:00:00 2001 From: Jelmer Vernooij Date: Sat, 14 Sep 2024 20:43:33 +0000 Subject: [PATCH] Fix argument parsing and add tests. See #2 --- prometheus_xmpp/__main__.py | 76 ++++++++++++++++++++++++------------- tests/__init__.py | 3 +- tests/test_main.py | 34 +++++++++++++++++ 3 files changed, 86 insertions(+), 27 deletions(-) create mode 100644 tests/test_main.py diff --git a/prometheus_xmpp/__main__.py b/prometheus_xmpp/__main__.py index 73e739a..b6efeac 100755 --- a/prometheus_xmpp/__main__.py +++ b/prometheus_xmpp/__main__.py @@ -314,14 +314,14 @@ async def serve_root(request): body=INDEX) -def main(): +def parse_args(argv=None, env=os.environ): parser = argparse.ArgumentParser() parser.add_argument('--config', dest='config_path', type=str, default=None, help='Path to configuration file.') parser.add_argument('--optional-config', dest='optional_config_path', type=str, default=DEFAULT_CONF_PATH, - help=argparse.HIDDEN) + help=argparse.SUPPRESS) parser.add_argument("-q", "--quiet", help="set logging to ERROR", action="store_const", dest="loglevel", const=logging.ERROR, default=logging.INFO) @@ -329,7 +329,7 @@ def main(): action="store_const", dest="loglevel", const=logging.DEBUG, default=logging.INFO) - args = parser.parse_args() + args = parser.parse_args(argv) # Setup logging. logging.basicConfig(level=args.loglevel, format="%(levelname)-8s %(message)s") @@ -338,26 +338,29 @@ def main(): if os.path.isfile(args.optional_config_path): args.config_path = args.optional_config_path - with open(args.config_path) as f: - if getattr(yaml, "FullLoader", None): - config = yaml.load(f, Loader=yaml.FullLoader) # type: ignore - else: - # Backwards compatibility with older versions of Python - config = yaml.load(f) # type: ignore + if args.config_path: + with open(args.config_path) as f: + if getattr(yaml, "FullLoader", None): + config = yaml.load(f, Loader=yaml.FullLoader) # type: ignore + else: + # Backwards compatibility with older versions of Python + config = yaml.load(f) # type: ignore + else: + config = {} - if 'XMPP_ID' in os.environ: - jid = os.environ['XMPP_ID'] + if 'XMPP_ID' in env: + jid = env['XMPP_ID'] elif 'jid' in config: jid = config['jid'] else: - parser.error('no jid set in configuration or environment') + parser.error('no jid set in configuration (`jid`) or environment (`XMPP_ID`)') hostname = socket.gethostname() jid = "{}/{}".format(jid, hostname) - if 'XMPP_PASS' in os.environ: + if 'XMPP_PASS' in env: def password_cb(): - return os.environ['XMPP_PASS'] + return env['XMPP_PASS'] elif config.get('password'): def password_cb(): @@ -371,27 +374,48 @@ def password_cb(): def password_cb(): return None - if 'XMPP_RECIPIENTS' in os.environ: - recipients = os.environ['XMPP_RECIPIENTS'].split(',') + if 'XMPP_RECIPIENTS' in env: + recipients = env['XMPP_RECIPIENTS'].split(',') elif 'recipients' in config: recipients = config['recipients'] + if not isinstance(recipients, list): + recipients = [recipients] elif 'to_jid' in config: - recipients = config['to_jid'] + recipients = [config['to_jid']] else: parser.error( - 'no recipients specified in configuration or environment') + 'no recipients specified in configuration (`recipients` or `to_jid`) or environment (`XMPP_RECIPIENTS`)') - if 'XMPP_AMTOOL_ALLOWED' in os.environ: - amtool_allowed = os.environ['XMPP_AMTOOL_ALLOWED'].split(',') + if 'XMPP_AMTOOL_ALLOWED' in env: + amtool_allowed = env['XMPP_AMTOOL_ALLOWED'].split(',') + config['amtool_allowed'] = amtool_allowed elif 'amtool_allowed' in config: - amtool_allowed = config['amtool_allowed'] + if not isinstance(config['amtool_allowed'], list): + config['amtool_allowed'] = [config['amtool_allowed']] else: - amtool_allowed = list(recipients) + config['amtool_allowed'] = list(recipients) - if 'ALERTMANAGER_URL' in os.environ: - alertmanager_url = os.environ['ALERTMANAGER_URL'] - else: - alertmanager_url = config.get('alertmanager_url') + if 'ALERTMANAGER_URL' in env: + config['alertmanager_url'] = env['ALERTMANAGER_URL'] + + return ( + jid, + password_cb, + recipients, + config, + ) + + +def main(): + ( + jid, + password_cb, + recipients, + config, + ) = parse_args() + + amtool_allowed = config.get('amtool_allowed') + alertmanager_url = config.get('alertmanager_url') xmpp_app = XmppApp( jid, password_cb, diff --git a/tests/__init__.py b/tests/__init__.py index 107935c..207bfb7 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -84,7 +84,8 @@ def test_parse_with_timezone(self): ) + def test_suite(): - module_names = ["tests"] + module_names = ["tests", "tests.test_main"] loader = unittest.TestLoader() return loader.loadTestsFromNames(module_names) diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..af42721 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,34 @@ +# __init__.py -- The tests for prometheus_xmpp +# Copyright (C) 2018 Jelmer Vernooij +# + +import tempfile +import unittest +from prometheus_xmpp.__main__ import parse_args + + +class TestParseArgs(unittest.TestCase): + + def test_parse_args_env(self): + (jid, password_cb, recipients, config) = parse_args([], env={'XMPP_ID': 'foo@bar', 'XMPP_PASS': 'baz', 'XMPP_AMTOOL_ALLOWED': 'jelmer@jelmer.uk', 'XMPP_RECIPIENTS': 'foo@bar.com'}) + + self.assertTrue(jid.startswith('foo@bar/')) + self.assertEqual(password_cb(), 'baz') + self.assertEqual(recipients, ['foo@bar.com']) + self.assertEqual(config['amtool_allowed'], ['jelmer@jelmer.uk']) + + def test_parse_args_config(self): + with tempfile.NamedTemporaryFile() as f: + f.write(b"""\ +jid: foo@bar +password: baz +to_jid: jelmer@jelmer.uk +amtool_allowed: foo@example.com +""") + f.flush() + (jid, password_cb, recipients, config) = parse_args(['--config', f.name], env={}) + + self.assertTrue(jid.startswith('foo@bar/')) + self.assertEqual(password_cb(), 'baz') + self.assertEqual(recipients, ['jelmer@jelmer.uk']) + self.assertEqual(config['amtool_allowed'], ['foo@example.com'])