diff --git a/python-avd/pyavd/_utils/__init__.py b/python-avd/pyavd/_utils/__init__.py index 10a9fcec931..285dae80421 100644 --- a/python-avd/pyavd/_utils/__init__.py +++ b/python-avd/pyavd/_utils/__init__.py @@ -5,6 +5,7 @@ from .batch import batch from .compare_dicts import compare_dicts from .default import default +from .format_string import AvdStringFormatter from .get import get, get_v2 from .get_all import get_all, get_all_with_path from .get_indices_of_duplicate_items import get_indices_of_duplicate_items @@ -22,6 +23,7 @@ from .unique import unique __all__ = [ + "AvdStringFormatter", "append_if_not_duplicate", "batch", "compare_dicts", diff --git a/python-avd/pyavd/_utils/format_string.py b/python-avd/pyavd/_utils/format_string.py new file mode 100644 index 00000000000..0a68cb79a14 --- /dev/null +++ b/python-avd/pyavd/_utils/format_string.py @@ -0,0 +1,160 @@ +# Copyright (c) 2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +from collections.abc import Iterable +from string import Formatter + + +class AvdStringFormatter(Formatter): + """ + Custom string formatter class to provide extra protection from malicious format strings and support for prefixes and suffixes per field. + + The regular Python syntax is "{" [field_name] ["!" conversion] [":" format_spec] "}" + This class supports "{" [field_name] ["?"] ["<" prefix] [">" suffix] ["!" conversion] [":" format_spec] "}" + + where + ? ::= The literal ? signals that the field is optional and will not be printed if the value is missing or None. + prefix ::= string including spaces which will be inserted before the field value. + Most useful in combination with ?. Prefix should not contain "<", ">", "!" or ":". + suffix ::= string including spaces which will be inserted after the field value. + Most useful in combination with ?. Suffix should not contain "<", ">", "!" or ":". + conversion ::= "!u" for "upper()" (The regular Python conversions "!r", "!s", "!a" have been removed). + + Note the order of syntax field matters! + """ + + def _vformat(self, format_string: str, args: list, kwargs: dict, used_args: set, recursion_depth: int, auto_arg_index: int = 0) -> tuple[str, int]: + """ + Perform the actual formatting. + + Mostly a copy from the base class, but adding support for using "optional", "prefix" and "suffix" from the .parse() method. + + This should not be called directly. Instead call AvdStringFormatter().format(format_string, /, *args, **kwargs) + """ + if recursion_depth < 0: + msg = "Max string recursion exceeded" + raise ValueError(msg) + result = [] + for literal_text, org_field_name, org_format_spec, conversion, optional, prefix, suffix in self.parse(format_string): + # Make ruff happy. + field_name = org_field_name + format_spec = org_format_spec + + # output the literal text + if literal_text: + result.append(literal_text) + + # if there's a field, output it + if field_name is not None: + # this is some markup, find the object and do the formatting + + # handle arg indexing when empty field_names are given. + if field_name == "": + if auto_arg_index is False: + msg = "cannot switch from manual field specification to automatic field numbering" + raise ValueError(msg) + field_name = str(auto_arg_index) + auto_arg_index += 1 + elif field_name.isdigit(): + if auto_arg_index: + msg = "cannot switch from manual field specification to automatic field numbering" + raise ValueError(msg) + # disable auto arg incrementing, if it gets + # used later on, then an exception will be raised + auto_arg_index = False + + # given the field_name, find the object it references + # and the argument it came from + if optional: + try: + obj, arg_used = self.get_field(field_name, args, kwargs) + except (IndexError, KeyError): + # Skip this field if it is optional and not existing. + continue + if obj is None: + # Skip this field if it is optional and None. + continue + else: + obj, arg_used = self.get_field(field_name, args, kwargs) + + used_args.add(arg_used) + + # do any conversion on the resulting object + obj = self.convert_field(obj, conversion) + + # expand the format spec, if needed + format_spec, auto_arg_index = self._vformat(format_spec, args, kwargs, used_args, recursion_depth - 1, auto_arg_index=auto_arg_index) + + # Append prefix if set + if prefix: + result.append(prefix) + + # format the object and append to the result + result.append(self.format_field(obj, format_spec)) + + # Append suffix if set + if suffix: + result.append(suffix) + + return "".join(result), auto_arg_index + + def parse(self, format_string: str) -> Iterable[tuple[str, str | None, str | None, str | None, bool | None, str | None, str | None]]: + """ + Parse the format_string and yield elements back. + + Mostly a copy from the base class, but also returning "optional", "prefix" and "suffix" for every field. + """ + for literal_text, field_name, format_spec, conversion in super().parse(format_string): + if not field_name or not ("?" in field_name or ">" in field_name or "<" in field_name): + yield (literal_text, field_name, format_spec, conversion, None, None, None) + continue + + tmp_field_name = field_name + # Doing suffix first so the split will keep a potential prefix in the tmp_field_name + if ">" in tmp_field_name: + tmp_field_name, suffix = tmp_field_name.split(">", maxsplit=1) + else: + suffix = None + + if "<" in tmp_field_name: + tmp_field_name, prefix = tmp_field_name.split("<", maxsplit=1) + else: + prefix = None + + optional = tmp_field_name.endswith("?") + tmp_field_name = tmp_field_name.removesuffix("?") + + yield (literal_text, tmp_field_name, format_spec, conversion, optional, prefix, suffix) + + def convert_field(self, value: object, conversion: str | None) -> object: + """ + Convert the value according to the given conversion instruction. + + Mostly a copy from the base class, but only supporting !u for upper(). + """ + # do any conversion on the resulting object + if conversion is None: + return value + if conversion == "u": + return str(value).upper() + msg = f"Unknown conversion specifier {conversion!s}" + raise ValueError(msg) + + def get_field(self, field_name: str, args: list, kwargs: dict) -> tuple[object, str]: + """ + Get field value including parsing attributes/keys. + + Reusing base class after guarding against accessing attributes leading with underscore. + This protects against access to dunders etc. + """ + if not field_name or "_" not in field_name: + return super().get_field(field_name, args, kwargs) + + if any(attr.startswith("_") for attr in field_name.split(".")): + msg = f"Unsupported field name '{field_name}'. Avoid attributes starting with underscore." + raise ValueError(msg) + if any(key_and_more.startswith("_") for key_and_more in field_name.split("[")): + msg = f"Unsupported field name '{field_name}'. Avoid keys starting with underscore." + raise ValueError(msg) + + return super().get_field(field_name, args, kwargs) diff --git a/python-avd/tests/pyavd/utils/test_format_string.py b/python-avd/tests/pyavd/utils/test_format_string.py new file mode 100644 index 00000000000..4cd2cf52678 --- /dev/null +++ b/python-avd/tests/pyavd/utils/test_format_string.py @@ -0,0 +1,109 @@ +# Copyright (c) 2023-2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. + +from __future__ import annotations + +import pytest + +from pyavd._utils import AvdStringFormatter + + +class DummyClass: + _private = "private" + public = "public" + + +FORMAT_STRING_TESTS = [ + # (, , , ) + # no fields + pytest.param("Ethernet1", (), {}, "Ethernet1", id="no_fields"), + pytest.param("Ethernet1", (), {"foo": "bar"}, "Ethernet1", id="no_fields_with_args"), + pytest.param("{{Ethernet1}}", (), {}, "{Ethernet1}", id="escaped_curly_brace"), + # named fields with upper + pytest.param("{interface!u}", (), {"interface": "Ethernet1"}, "ETHERNET1", id="field_with_existing_arg_and_upper"), + pytest.param("{interface?!u}", (), {}, "", id="optional_field_with_missing_arg_and_upper"), + pytest.param("{interface?!u}", (), {"interface": None}, "", id="optional_field_with_none_arg_and_upper"), + pytest.param("{interface?!u}", (), {"interface": "Ethernet1"}, "ETHERNET1", id="optional_field_with_existing_arg_and_upper"), + pytest.param("{interface.public?!u}", (), {"interface": DummyClass()}, "PUBLIC", id="optional_field_with_attribute_and_upper"), + # positional fields with upper + pytest.param("{!u}", ("Ethernet1",), {}, "ETHERNET1", id="positional_field_with_existing_arg_and_upper"), + pytest.param("{?!u}", (), {}, "", id="positional_optional_field_with_missing_arg_and_upper"), + pytest.param("{?!u}", (None,), {}, "", id="positional_optional_field_with_none_arg_and_upper"), + pytest.param("{?!u}", ("Ethernet1",), {}, "ETHERNET1", id="positional_optional_field_with_existing_arg_and_upper"), + pytest.param("{0?!u}{1?!u}{0?!u}", ("foo", "bar"), {}, "FOOBARFOO", id="positional_optional_repeated_fields_with_existing_args_and_upper"), + pytest.param("{0.public?!u}", (DummyClass(),), {}, "PUBLIC", id="positional_optional_field_with_attribute_and_upper"), + # named fields with prefix + pytest.param("{interfacefoo }", (), {"interface": "Ethernet1"}, "Ethernet1foo ", id="field_with_suffix_existing_arg"), + pytest.param("{interface?>foo}", (), {}, "", id="optional_field_with_suffix_missing_arg"), + pytest.param("{interface?> f o o }", (), {"interface": None}, "", id="optional_field_with_suffix_none_arg"), + pytest.param("{interface?> f o o }", (), {"interface": "Ethernet1"}, "Ethernet1 f o o ", id="optional_field_with_suffix_existing_arg"), + pytest.param("{interface.public?>foo}", (), {"interface": DummyClass()}, "publicfoo", id="optional_field_with_prefix_attribute"), + # positional fields with suffix + pytest.param("{>foo }", ("Ethernet1",), {}, "Ethernet1foo ", id="positional_field_with_suffix_existing_arg"), + pytest.param("{?>foo}", (), {}, "", id="positional_optional_field_with_suffix_missing_arg"), + pytest.param("{?> f o o }", (None,), {}, "", id="positional_optional_field_with_suffix_none_arg"), + pytest.param("{?> f o o }", ("Ethernet1",), {}, "Ethernet1 f o o ", id="positional_optional_field_with_suffix_existing_arg"), + pytest.param("{0>one}{1>two}{0>three}", ("foo", "bar"), {}, "fooonebartwofoothree", id="positional_repeated_fields_with_suffix_existing_args"), + # named fields with prefix and suffix + pytest.param("{interfacebar }", (), {"interface": "Ethernet1"}, "foo Ethernet1bar ", id="field_with_prefix_and_suffix_existing_arg"), + pytest.param("{interface?bar}", (), {}, "", id="optional_field_with_prefix_and_suffix_missing_arg"), + pytest.param("{interface?< f o o > b a r }", (), {"interface": None}, "", id="optional_field_with_prefix_and_suffix_none_arg"), + pytest.param( + "{interface?< f o o > b a r }", (), {"interface": "Ethernet1"}, " f o o Ethernet1 b a r ", id="optional_field_with_prefix_and_suffix_existing_arg" + ), + pytest.param("{interface.publicbar}", (), {"interface": DummyClass()}, "foopublicbar", id="field_with_prefix_attribute"), + # positional fields with prefix and suffix + pytest.param("{bar }", ("Ethernet1",), {}, "foo Ethernet1bar ", id="positional_field_with_prefix_and_suffix_existing_arg"), + pytest.param("{?bar}", (), {}, "", id="positional_optional_field_with_prefix_and_suffix_missing_arg"), + pytest.param("{?< f o o > b a r }", (None,), {}, "", id="positional_optional_field_with_prefix_and_suffix_none_arg"), + pytest.param("{?< f o o > b a r }", ("Ethernet1",), {}, " f o o Ethernet1 b a r ", id="positional_optional_field_with_prefix_and_suffix_existing_arg"), + pytest.param( + "{0one}_{1two}_{0three}", + ("foo", "bar"), + {}, + "aaafooone_bbbbartwo_cccfoothree", + id="positional_repeated_fields_with_prefix_and_suffix_existing_args", + ), + # positional fields with prefix and suffix and upper + pytest.param("{bar !u}", ("Ethernet1",), {}, "foo ETHERNET1bar ", id="positional_field_with_prefix_and_suffix_existing_arg_and_upper"), + pytest.param("{?bar!u}", (), {}, "", id="positional_optional_field_with_prefix_and_suffix_missing_arg_and_upper"), + pytest.param("{?< f o o > b a r !u}", (None,), {}, "", id="positional_optional_field_with_prefix_and_suffix_none_arg_and_upper"), + pytest.param( + "{?< f o o > b a r !u}", ("Ethernet1",), {}, " f o o ETHERNET1 b a r ", id="positional_optional_field_with_prefix_and_suffix_existing_arg_and_upper" + ), +] + + +SAFETY_TESTS = [ + # (, , ) + pytest.param("{foo.__class__.__name__}", (), {"foo": "bar"}, id="kwarg_dunder"), + pytest.param("{_foo}", (), {"_foo": "bar"}, id="kwarg_private"), + pytest.param("{foo._private}", (), {"foo": DummyClass()}, id="kwarg_private_attribute"), + pytest.param("{0.__class__.__name__}", ("foo",), {}, id="arg_dunder"), + pytest.param("{0._private}", (DummyClass(),), {}, id="arg_private_attribute"), +] + + +class TestAvdStringFormatter: + @pytest.mark.parametrize(("format_string", "args", "kwargs", "expected_output"), FORMAT_STRING_TESTS) + def test_avd_formatter(self, format_string: str, args: tuple, kwargs: dict, expected_output: list) -> None: + resp = AvdStringFormatter().format(format_string, *args, **kwargs) + assert resp == expected_output + + @pytest.mark.parametrize(("format_string", "args", "kwargs"), SAFETY_TESTS) + def test_avd_formatter_safety(self, format_string: str, args: tuple, kwargs: dict) -> None: + with pytest.raises(ValueError, match=r"Unsupported field name '.+'. Avoid (attributes|keys) starting with underscore."): + AvdStringFormatter().format(format_string, *args, **kwargs)