diff --git a/python/py_helpers.py b/python/py_helpers.py index 2f8c56d..247f228 100644 --- a/python/py_helpers.py +++ b/python/py_helpers.py @@ -118,6 +118,8 @@ def has_returns(self, returns_str): return returns_str == self.tree.returns.id elif isinstance(self.tree.returns, ast.Constant): return returns_str == self.tree.returns.value + elif isinstance((ann := self.tree.returns), ast.Subscript): + return Node(ann).is_equivalent(returns_str) return False def find_body(self): @@ -251,6 +253,25 @@ def find_variable(self, name): return Node(node) return Node() + def find_variables(self, name): + assignments = self._find_all((ast.Assign, ast.AnnAssign)) + var_list = [] + for node in assignments: + if isinstance(node.tree, ast.Assign): + for target in node.tree.targets: + if isinstance(target, ast.Name): + if target.id == name: + var_list.append(node) + if isinstance(target, ast.Attribute): + names = name.split(".") + if target.value.id == names[0] and target.attr == names[1]: + var_list.append(node) + elif isinstance(node.tree, ast.AnnAssign): + if isinstance(node.tree.target, ast.Name): + if node.tree.target.id == name: + var_list.append(node) + return var_list + # find variable incremented or decremented using += or -= def find_aug_variable(self, name): if not self._has_body(): diff --git a/python/py_helpers.test.py b/python/py_helpers.test.py index 6901674..41fd8e6 100644 --- a/python/py_helpers.test.py +++ b/python/py_helpers.test.py @@ -161,6 +161,23 @@ def foo(): ) self.assertEqual(node.find_function("foo").find_aug_variable("x"), Node()) + def test_find_variables(self): + code_str = """ +x: int = 0 +a.b = 0 +x = 5 +a.b = 2 +x = 10 +""" + node = Node(code_str) + self.assertEqual(len(node.find_variables("x")), 3) + self.assertTrue(node.find_variables("x")[0].is_equivalent("x: int = 0")) + self.assertTrue(node.find_variables("x")[1].is_equivalent("x = 5")) + self.assertTrue(node.find_variables("x")[2].is_equivalent("x = 10")) + self.assertEqual(len(node.find_variables("a.b")), 2) + self.assertTrue(node.find_variables("a.b")[0].is_equivalent("a.b = 0")) + self.assertTrue(node.find_variables("a.b")[1].is_equivalent("a.b = 2")) + class TestFunctionAndClassHelpers(unittest.TestCase): def test_find_function_returns_node(self): @@ -295,12 +312,16 @@ def foo(a: int, b: int) -> int: def test_has_returns(self): code_str = """ def foo() -> int: - pass + pass + +def spam() -> Dict[str, int]: + pass """ node = Node(code_str) self.assertTrue(node.find_function("foo").has_returns("int")) self.assertFalse(node.find_function("foo").has_returns("str")) + self.assertTrue(node.find_function("spam").has_returns("Dict[str, int]")) def test_has_returns_without_returns(self): code_str = """