From a56116dd6471e8546a76132592a2d5bfbe38b6a0 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 13 Sep 2023 16:45:10 -0700 Subject: [PATCH] gh-109219: propagate used free vars through type param scopes [v2] --- Lib/test/test_type_params.py | 14 ++++++++++++++ Python/symtable.c | 17 +++++++++++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_type_params.py b/Lib/test/test_type_params.py index b1848aee4753a1..0a61be9773ac4a 100644 --- a/Lib/test/test_type_params.py +++ b/Lib/test/test_type_params.py @@ -694,6 +694,20 @@ class Cls: cls = ns["outer"]() self.assertEqual(cls.Alias.__value__, "class") + def test_nested_free(self): + ns = run_code(""" + def f(): + T = str + class C: + T = int + class D[U](T): + x = T + return C + """) + C = ns["f"]() + self.assertIn(int, C.D.__bases__) + self.assertIs(C.D.x, str) + class TypeParamsManglingTest(unittest.TestCase): def test_mangling(self): diff --git a/Python/symtable.c b/Python/symtable.c index d737c09203d31b..b50d837a20cae3 100644 --- a/Python/symtable.c +++ b/Python/symtable.c @@ -784,7 +784,8 @@ drop_class_free(PySTEntryObject *ste, PyObject *free) static int update_symbols(PyObject *symbols, PyObject *scopes, PyObject *bound, PyObject *free, - PyObject *inlined_cells, int classflag) + PyObject *inlined_cells, int classflag, + PySTEntryObject *class_entry) { PyObject *name = NULL, *itr = NULL; PyObject *v = NULL, *v_scope = NULL, *v_new = NULL, *v_free = NULL; @@ -836,8 +837,16 @@ update_symbols(PyObject *symbols, PyObject *scopes, the class that has the same name as a local or global in the class scope. */ - if (classflag && - PyLong_AS_LONG(v) & (DEF_BOUND | DEF_GLOBAL)) { + PyObject *class_v = NULL; + if (class_entry) { + class_v = PyDict_GetItemWithError(class_entry->ste_symbols, name); + if (!class_v && PyErr_Occurred()) { + goto error; + } + } + if ((classflag && + PyLong_AS_LONG(v) & (DEF_BOUND | DEF_GLOBAL)) || + (class_v && PyLong_AS_LONG(class_v) & (DEF_BOUND | DEF_GLOBAL))) { long flags = PyLong_AS_LONG(v) | DEF_FREE_CLASS; v_new = PyLong_FromLong(flags); if (!v_new) { @@ -1078,7 +1087,7 @@ analyze_block(PySTEntryObject *ste, PyObject *bound, PyObject *free, goto error; /* Records the results of the analysis in the symbol table entry */ if (!update_symbols(ste->ste_symbols, scopes, bound, newfree, inlined_cells, - ste->ste_type == ClassBlock)) + ste->ste_type == ClassBlock, class_entry)) goto error; temp = PyNumber_InPlaceOr(free, newfree);