Skip to content

Commit

Permalink
Do not annotate the same variable multiple times in ApplyTypeAnnotati…
Browse files Browse the repository at this point in the history
…onsVisitor (#956)
  • Loading branch information
martindemello authored Jun 14, 2023
1 parent 3cacca1 commit 50d48c1
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 10 deletions.
26 changes: 16 additions & 10 deletions libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,10 @@ def __init__(
# quotations to avoid undefined forward references in type annotations.
self.global_names: Set[str] = set()

# We use this to avoid annotating multiple assignments to the same
# symbol in a given scope
self.already_annotated: Set[str] = set()

@staticmethod
def store_stub_in_context(
context: CodemodContext,
Expand Down Expand Up @@ -945,17 +949,19 @@ def _annotate_single_target(
name = get_full_name_for_node(only_target)
if name is not None:
self.qualifier.append(name)
if (
self._qualifier_name() in self.annotations.attributes
and not isinstance(only_target, (cst.Attribute, cst.Subscript))
qualifier_name = self._qualifier_name()
if qualifier_name in self.annotations.attributes and not isinstance(
only_target, (cst.Attribute, cst.Subscript)
):
annotation = self.annotations.attributes[self._qualifier_name()]
self.qualifier.pop()
return self._apply_annotation_to_attribute_or_global(
name=name,
annotation=annotation,
value=node.value,
)
if qualifier_name not in self.already_annotated:
self.already_annotated.add(qualifier_name)
annotation = self.annotations.attributes[qualifier_name]
self.qualifier.pop()
return self._apply_annotation_to_attribute_or_global(
name=name,
annotation=annotation,
value=node.value,
)
else:
self.qualifier.pop()
return updated_node
Expand Down
55 changes: 55 additions & 0 deletions libcst/codemod/visitors/tests/test_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1910,3 +1910,58 @@ class C:
)
def test_valid_assign_expressions(self, stub: str, before: str, after: str) -> None:
self.run_simple_test_case(stub=stub, before=before, after=after)

@data_provider(
{
"toplevel": (
"""
x: int
""",
"""
x = 1
x = 2
""",
"""
x: int = 1
x = 2
""",
),
"class": (
"""
class A:
x: int
""",
"""
class A:
x = 1
x = 2
""",
"""
class A:
x: int = 1
x = 2
""",
),
"mixed": (
"""
x: int
class A:
x: int
""",
"""
x = 1
class A:
x = 1
x = 2
""",
"""
x: int = 1
class A:
x: int = 1
x = 2
""",
),
}
)
def test_no_duplicate_annotations(self, stub: str, before: str, after: str) -> None:
self.run_simple_test_case(stub=stub, before=before, after=after)

0 comments on commit 50d48c1

Please sign in to comment.