Skip to content

Commit

Permalink
Feat(plugins): Add AVD String Formatter for later use in custom descr…
Browse files Browse the repository at this point in the history
…iptions (#4432)
  • Loading branch information
ClausHolbechArista authored Sep 11, 2024
1 parent 8a35e5c commit 626cdd1
Show file tree
Hide file tree
Showing 3 changed files with 271 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python-avd/pyavd/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .batch import batch
from .compare_dicts import compare_dicts
from .default import default
from .format_string import AvdStringFormatter
from .get import get, get_v2
from .get_all import get_all, get_all_with_path
from .get_indices_of_duplicate_items import get_indices_of_duplicate_items
Expand All @@ -22,6 +23,7 @@
from .unique import unique

__all__ = [
"AvdStringFormatter",
"append_if_not_duplicate",
"batch",
"compare_dicts",
Expand Down
160 changes: 160 additions & 0 deletions python-avd/pyavd/_utils/format_string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) 2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
from collections.abc import Iterable
from string import Formatter


class AvdStringFormatter(Formatter):
"""
Custom string formatter class to provide extra protection from malicious format strings and support for prefixes and suffixes per field.
The regular Python syntax is "{" [field_name] ["!" conversion] [":" format_spec] "}"
This class supports "{" [field_name] ["?"] ["<" prefix] [">" suffix] ["!" conversion] [":" format_spec] "}"
where
? ::= The literal ? signals that the field is optional and will not be printed if the value is missing or None.
prefix ::= string including spaces which will be inserted before the field value.
Most useful in combination with ?. Prefix should not contain "<", ">", "!" or ":".
suffix ::= string including spaces which will be inserted after the field value.
Most useful in combination with ?. Suffix should not contain "<", ">", "!" or ":".
conversion ::= "!u" for "upper()" (The regular Python conversions "!r", "!s", "!a" have been removed).
Note the order of syntax field matters!
"""

def _vformat(self, format_string: str, args: list, kwargs: dict, used_args: set, recursion_depth: int, auto_arg_index: int = 0) -> tuple[str, int]:
"""
Perform the actual formatting.
Mostly a copy from the base class, but adding support for using "optional", "prefix" and "suffix" from the .parse() method.
This should not be called directly. Instead call AvdStringFormatter().format(format_string, /, *args, **kwargs)
"""
if recursion_depth < 0:
msg = "Max string recursion exceeded"
raise ValueError(msg)
result = []
for literal_text, org_field_name, org_format_spec, conversion, optional, prefix, suffix in self.parse(format_string):
# Make ruff happy.
field_name = org_field_name
format_spec = org_format_spec

# output the literal text
if literal_text:
result.append(literal_text)

# if there's a field, output it
if field_name is not None:
# this is some markup, find the object and do the formatting

# handle arg indexing when empty field_names are given.
if field_name == "":
if auto_arg_index is False:
msg = "cannot switch from manual field specification to automatic field numbering"
raise ValueError(msg)
field_name = str(auto_arg_index)
auto_arg_index += 1
elif field_name.isdigit():
if auto_arg_index:
msg = "cannot switch from manual field specification to automatic field numbering"
raise ValueError(msg)
# disable auto arg incrementing, if it gets
# used later on, then an exception will be raised
auto_arg_index = False

# given the field_name, find the object it references
# and the argument it came from
if optional:
try:
obj, arg_used = self.get_field(field_name, args, kwargs)
except (IndexError, KeyError):
# Skip this field if it is optional and not existing.
continue
if obj is None:
# Skip this field if it is optional and None.
continue
else:
obj, arg_used = self.get_field(field_name, args, kwargs)

used_args.add(arg_used)

# do any conversion on the resulting object
obj = self.convert_field(obj, conversion)

# expand the format spec, if needed
format_spec, auto_arg_index = self._vformat(format_spec, args, kwargs, used_args, recursion_depth - 1, auto_arg_index=auto_arg_index)

# Append prefix if set
if prefix:
result.append(prefix)

# format the object and append to the result
result.append(self.format_field(obj, format_spec))

# Append suffix if set
if suffix:
result.append(suffix)

return "".join(result), auto_arg_index

def parse(self, format_string: str) -> Iterable[tuple[str, str | None, str | None, str | None, bool | None, str | None, str | None]]:
"""
Parse the format_string and yield elements back.
Mostly a copy from the base class, but also returning "optional", "prefix" and "suffix" for every field.
"""
for literal_text, field_name, format_spec, conversion in super().parse(format_string):
if not field_name or not ("?" in field_name or ">" in field_name or "<" in field_name):
yield (literal_text, field_name, format_spec, conversion, None, None, None)
continue

tmp_field_name = field_name
# Doing suffix first so the split will keep a potential prefix in the tmp_field_name
if ">" in tmp_field_name:
tmp_field_name, suffix = tmp_field_name.split(">", maxsplit=1)
else:
suffix = None

if "<" in tmp_field_name:
tmp_field_name, prefix = tmp_field_name.split("<", maxsplit=1)
else:
prefix = None

optional = tmp_field_name.endswith("?")
tmp_field_name = tmp_field_name.removesuffix("?")

yield (literal_text, tmp_field_name, format_spec, conversion, optional, prefix, suffix)

def convert_field(self, value: object, conversion: str | None) -> object:
"""
Convert the value according to the given conversion instruction.
Mostly a copy from the base class, but only supporting !u for upper().
"""
# do any conversion on the resulting object
if conversion is None:
return value
if conversion == "u":
return str(value).upper()
msg = f"Unknown conversion specifier {conversion!s}"
raise ValueError(msg)

def get_field(self, field_name: str, args: list, kwargs: dict) -> tuple[object, str]:
"""
Get field value including parsing attributes/keys.
Reusing base class after guarding against accessing attributes leading with underscore.
This protects against access to dunders etc.
"""
if not field_name or "_" not in field_name:
return super().get_field(field_name, args, kwargs)

if any(attr.startswith("_") for attr in field_name.split(".")):
msg = f"Unsupported field name '{field_name}'. Avoid attributes starting with underscore."
raise ValueError(msg)
if any(key_and_more.startswith("_") for key_and_more in field_name.split("[")):
msg = f"Unsupported field name '{field_name}'. Avoid keys starting with underscore."
raise ValueError(msg)

return super().get_field(field_name, args, kwargs)
109 changes: 109 additions & 0 deletions python-avd/tests/pyavd/utils/test_format_string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) 2023-2024 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.

from __future__ import annotations

import pytest

from pyavd._utils import AvdStringFormatter


class DummyClass:
_private = "private"
public = "public"


FORMAT_STRING_TESTS = [
# (<format_string>, <args ()>, <kwargs {}>, <expected_output>)
# no fields
pytest.param("Ethernet1", (), {}, "Ethernet1", id="no_fields"),
pytest.param("Ethernet1", (), {"foo": "bar"}, "Ethernet1", id="no_fields_with_args"),
pytest.param("{{Ethernet1}}", (), {}, "{Ethernet1}", id="escaped_curly_brace"),
# named fields with upper
pytest.param("{interface!u}", (), {"interface": "Ethernet1"}, "ETHERNET1", id="field_with_existing_arg_and_upper"),
pytest.param("{interface?!u}", (), {}, "", id="optional_field_with_missing_arg_and_upper"),
pytest.param("{interface?!u}", (), {"interface": None}, "", id="optional_field_with_none_arg_and_upper"),
pytest.param("{interface?!u}", (), {"interface": "Ethernet1"}, "ETHERNET1", id="optional_field_with_existing_arg_and_upper"),
pytest.param("{interface.public?!u}", (), {"interface": DummyClass()}, "PUBLIC", id="optional_field_with_attribute_and_upper"),
# positional fields with upper
pytest.param("{!u}", ("Ethernet1",), {}, "ETHERNET1", id="positional_field_with_existing_arg_and_upper"),
pytest.param("{?!u}", (), {}, "", id="positional_optional_field_with_missing_arg_and_upper"),
pytest.param("{?!u}", (None,), {}, "", id="positional_optional_field_with_none_arg_and_upper"),
pytest.param("{?!u}", ("Ethernet1",), {}, "ETHERNET1", id="positional_optional_field_with_existing_arg_and_upper"),
pytest.param("{0?!u}{1?!u}{0?!u}", ("foo", "bar"), {}, "FOOBARFOO", id="positional_optional_repeated_fields_with_existing_args_and_upper"),
pytest.param("{0.public?!u}", (DummyClass(),), {}, "PUBLIC", id="positional_optional_field_with_attribute_and_upper"),
# named fields with prefix
pytest.param("{interface<foo }", (), {"interface": "Ethernet1"}, "foo Ethernet1", id="field_with_prefix_existing_arg"),
pytest.param("{interface?<foo}", (), {}, "", id="optional_field_with_prefix_missing_arg"),
pytest.param("{interface?< f o o }", (), {"interface": None}, "", id="optional_field_with_prefix_none_arg"),
pytest.param("{interface?< f o o }", (), {"interface": "Ethernet1"}, " f o o Ethernet1", id="optional_field_with_prefix_existing_arg"),
pytest.param("{interface.public?<foo}", (), {"interface": DummyClass()}, "foopublic", id="optional_field_with_prefix_attribute"),
# positional fields with prefix
pytest.param("{<foo }", ("Ethernet1",), {}, "foo Ethernet1", id="positional_field_with_prefix_existing_arg"),
pytest.param("{?<foo}", (), {}, "", id="positional_optional_field_with_prefix_missing_arg"),
pytest.param("{?< f o o }", (None,), {}, "", id="positional_optional_field_with_prefix_none_arg"),
pytest.param("{?< f o o }", ("Ethernet1",), {}, " f o o Ethernet1", id="positional_optional_field_with_prefix_existing_arg"),
pytest.param("{0<one}{1<two}{0<three}", ("foo", "bar"), {}, "onefootwobarthreefoo", id="positional_repeated_fields_with_prefix_existing_args"),
# named fields with suffix
pytest.param("{interface>foo }", (), {"interface": "Ethernet1"}, "Ethernet1foo ", id="field_with_suffix_existing_arg"),
pytest.param("{interface?>foo}", (), {}, "", id="optional_field_with_suffix_missing_arg"),
pytest.param("{interface?> f o o }", (), {"interface": None}, "", id="optional_field_with_suffix_none_arg"),
pytest.param("{interface?> f o o }", (), {"interface": "Ethernet1"}, "Ethernet1 f o o ", id="optional_field_with_suffix_existing_arg"),
pytest.param("{interface.public?>foo}", (), {"interface": DummyClass()}, "publicfoo", id="optional_field_with_prefix_attribute"),
# positional fields with suffix
pytest.param("{>foo }", ("Ethernet1",), {}, "Ethernet1foo ", id="positional_field_with_suffix_existing_arg"),
pytest.param("{?>foo}", (), {}, "", id="positional_optional_field_with_suffix_missing_arg"),
pytest.param("{?> f o o }", (None,), {}, "", id="positional_optional_field_with_suffix_none_arg"),
pytest.param("{?> f o o }", ("Ethernet1",), {}, "Ethernet1 f o o ", id="positional_optional_field_with_suffix_existing_arg"),
pytest.param("{0>one}{1>two}{0>three}", ("foo", "bar"), {}, "fooonebartwofoothree", id="positional_repeated_fields_with_suffix_existing_args"),
# named fields with prefix and suffix
pytest.param("{interface<foo >bar }", (), {"interface": "Ethernet1"}, "foo Ethernet1bar ", id="field_with_prefix_and_suffix_existing_arg"),
pytest.param("{interface?<foo>bar}", (), {}, "", id="optional_field_with_prefix_and_suffix_missing_arg"),
pytest.param("{interface?< f o o > b a r }", (), {"interface": None}, "", id="optional_field_with_prefix_and_suffix_none_arg"),
pytest.param(
"{interface?< f o o > b a r }", (), {"interface": "Ethernet1"}, " f o o Ethernet1 b a r ", id="optional_field_with_prefix_and_suffix_existing_arg"
),
pytest.param("{interface.public<foo>bar}", (), {"interface": DummyClass()}, "foopublicbar", id="field_with_prefix_attribute"),
# positional fields with prefix and suffix
pytest.param("{<foo >bar }", ("Ethernet1",), {}, "foo Ethernet1bar ", id="positional_field_with_prefix_and_suffix_existing_arg"),
pytest.param("{?<foo>bar}", (), {}, "", id="positional_optional_field_with_prefix_and_suffix_missing_arg"),
pytest.param("{?< f o o > b a r }", (None,), {}, "", id="positional_optional_field_with_prefix_and_suffix_none_arg"),
pytest.param("{?< f o o > b a r }", ("Ethernet1",), {}, " f o o Ethernet1 b a r ", id="positional_optional_field_with_prefix_and_suffix_existing_arg"),
pytest.param(
"{0<aaa>one}_{1<bbb>two}_{0<ccc>three}",
("foo", "bar"),
{},
"aaafooone_bbbbartwo_cccfoothree",
id="positional_repeated_fields_with_prefix_and_suffix_existing_args",
),
# positional fields with prefix and suffix and upper
pytest.param("{<foo >bar !u}", ("Ethernet1",), {}, "foo ETHERNET1bar ", id="positional_field_with_prefix_and_suffix_existing_arg_and_upper"),
pytest.param("{?<foo>bar!u}", (), {}, "", id="positional_optional_field_with_prefix_and_suffix_missing_arg_and_upper"),
pytest.param("{?< f o o > b a r !u}", (None,), {}, "", id="positional_optional_field_with_prefix_and_suffix_none_arg_and_upper"),
pytest.param(
"{?< f o o > b a r !u}", ("Ethernet1",), {}, " f o o ETHERNET1 b a r ", id="positional_optional_field_with_prefix_and_suffix_existing_arg_and_upper"
),
]


SAFETY_TESTS = [
# (<format_string>, <args ()>, <kwargs {}>)
pytest.param("{foo.__class__.__name__}", (), {"foo": "bar"}, id="kwarg_dunder"),
pytest.param("{_foo}", (), {"_foo": "bar"}, id="kwarg_private"),
pytest.param("{foo._private}", (), {"foo": DummyClass()}, id="kwarg_private_attribute"),
pytest.param("{0.__class__.__name__}", ("foo",), {}, id="arg_dunder"),
pytest.param("{0._private}", (DummyClass(),), {}, id="arg_private_attribute"),
]


class TestAvdStringFormatter:
@pytest.mark.parametrize(("format_string", "args", "kwargs", "expected_output"), FORMAT_STRING_TESTS)
def test_avd_formatter(self, format_string: str, args: tuple, kwargs: dict, expected_output: list) -> None:
resp = AvdStringFormatter().format(format_string, *args, **kwargs)
assert resp == expected_output

@pytest.mark.parametrize(("format_string", "args", "kwargs"), SAFETY_TESTS)
def test_avd_formatter_safety(self, format_string: str, args: tuple, kwargs: dict) -> None:
with pytest.raises(ValueError, match=r"Unsupported field name '.+'. Avoid (attributes|keys) starting with underscore."):
AvdStringFormatter().format(format_string, *args, **kwargs)

0 comments on commit 626cdd1

Please sign in to comment.