From 74f990a1d612ee05d6a72082d9fcfa2df7ab1d1f Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Tue, 15 Oct 2024 08:10:47 -0700 Subject: [PATCH] gh-125507: Call annotate(FORWARDREF) before trying __annotations__ Fixes #125507 --- Lib/annotationlib.py | 16 +++++++++------- Lib/test/test_annotationlib.py | 20 +++++++++++++++++++- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index d5166170c071c4..33e05460726399 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -536,6 +536,13 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): for key, val in annos.items() } elif format == Format.FORWARDREF: + # In FORWARDREF format, try returning the owner's __annotations__ first, + # if they exist. + if owner is not None: + try: + return _get_dunder_annotations(owner) + except NameError: + pass # FORWARDREF is implemented similarly to STRING, but there are two changes, # at the beginning and the end of the process. # First, while STRING uses an empty dictionary as the namespace, so that all @@ -683,13 +690,8 @@ def get_annotations( # For VALUE, we only look at __annotations__ ann = _get_dunder_annotations(obj) case Format.FORWARDREF: - # For FORWARDREF, we use __annotations__ if it exists - try: - return dict(_get_dunder_annotations(obj)) - except NameError: - pass - - # But if __annotations__ threw a NameError, we try calling __annotate__ + # First we use call_annotate_function(), which will internally also + # try __annotations__ if the FORWARDREF format is passed. ann = _get_and_call_annotate(obj, format) if ann is not None: return ann diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index eedf2506a14912..ff4912b381cedb 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -80,6 +80,20 @@ def f(x: int, y: doesntexist): fwdref.evaluate() self.assertEqual(fwdref.evaluate(globals={"doesntexist": 1}), 1) + def test_custom_annotate(self): + def __annotate__(format): + return {"a": Format(format).name} + + class C: + pass + + C.__annotate__ = __annotate__ + + for format in Format: + with self.subTest(format=format): + anno = annotationlib.get_annotations(C, format=format) + self.assertEqual(anno, {"a": format.name}) + class TestSourceFormat(unittest.TestCase): def test_closure(self): @@ -809,7 +823,11 @@ def __annotations__(self): @property def __annotate__(self): - return lambda format: {"x": str} + def anno(format): + if format == Format.FORWARDREF: + raise NotImplementedError + return {"x": str} + return anno hb = HasBoth() self.assertEqual(