https://github.com/python/cpython/commit/c8f233c53b4634b820f2aa0efd45f43d9999aa1e
commit: c8f233c53b4634b820f2aa0efd45f43d9999aa1e
branch: main
author: Jelle Zijlstra <jelle.zijls...@gmail.com>
committer: JelleZijlstra <jelle.zijls...@gmail.com>
date: 2025-05-04T08:49:13-07:00
summary:

gh-132805: annotationlib: Fix handling of non-constant values in FORWARDREF (#132812)

Co-authored-by: David C Ellis <ducks...@gmail.com>

files:
A Misc/NEWS.d/next/Library/2025-04-22-13-42-12.gh-issue-132805.r-dhmJ.rst
M Lib/annotationlib.py
M Lib/test/test_annotationlib.py

diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py
index 37f51e69f94127..cd24679f30abee 100644
--- a/Lib/annotationlib.py
+++ b/Lib/annotationlib.py
@@ -38,6 +38,7 @@ class Format(enum.IntEnum):
     "__weakref__",
     "__arg__",
     "__globals__",
+    "__extra_names__",
     "__code__",
     "__ast_node__",
     "__cell__",
@@ -82,6 +83,7 @@ def __init__(
         # is created through __class__ assignment on a _Stringifier object.
         self.__globals__ = None
         self.__cell__ = None
+        self.__extra_names__ = None
         # These are initially None but serve as a cache and may be set to a 
non-None
         # value later.
         self.__code__ = None
@@ -151,6 +153,8 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
                 if not self.__forward_is_class__ or param_name not in globals:
                     globals[param_name] = param
                     locals.pop(param_name, None)
+        if self.__extra_names__:
+            locals = {**locals, **self.__extra_names__}
 
         arg = self.__forward_arg__
         if arg.isidentifier() and not keyword.iskeyword(arg):
@@ -231,6 +235,10 @@ def __eq__(self, other):
             and self.__forward_is_class__ == other.__forward_is_class__
             and self.__cell__ == other.__cell__
             and self.__owner__ == other.__owner__
+            and (
+                (tuple(sorted(self.__extra_names__.items())) if 
self.__extra_names__ else None) ==
+                (tuple(sorted(other.__extra_names__.items())) if 
other.__extra_names__ else None)
+            )
         )
 
     def __hash__(self):
@@ -241,6 +249,7 @@ def __hash__(self):
             self.__forward_is_class__,
             self.__cell__,
             self.__owner__,
+            tuple(sorted(self.__extra_names__.items())) if 
self.__extra_names__ else None,
         ))
 
     def __or__(self, other):
@@ -274,6 +283,7 @@ def __init__(
         cell=None,
         *,
         stringifier_dict,
+        extra_names=None,
     ):
         # Either an AST node or a simple str (for the common case where a 
ForwardRef
         # represent a single name).
@@ -285,6 +295,7 @@ def __init__(
         self.__code__ = None
         self.__ast_node__ = node
         self.__globals__ = globals
+        self.__extra_names__ = extra_names
         self.__cell__ = cell
         self.__owner__ = owner
         self.__stringifier_dict__ = stringifier_dict
@@ -292,28 +303,63 @@ def __init__(
     def __convert_to_ast(self, other):
         if isinstance(other, _Stringifier):
             if isinstance(other.__ast_node__, str):
-                return ast.Name(id=other.__ast_node__)
-            return other.__ast_node__
-        elif isinstance(other, slice):
+                return ast.Name(id=other.__ast_node__), other.__extra_names__
+            return other.__ast_node__, other.__extra_names__
+        elif (
+            # In STRING format we don't bother with the create_unique_name() 
dance;
+            # it's better to emit the repr() of the object instead of an 
opaque name.
+            self.__stringifier_dict__.format == Format.STRING
+            or other is None
+            or type(other) in (str, int, float, bool, complex)
+        ):
+            return ast.Constant(value=other), None
+        elif type(other) is dict:
+            extra_names = {}
+            keys = []
+            values = []
+            for key, value in other.items():
+                new_key, new_extra_names = self.__convert_to_ast(key)
+                if new_extra_names is not None:
+                    extra_names.update(new_extra_names)
+                keys.append(new_key)
+                new_value, new_extra_names = self.__convert_to_ast(value)
+                if new_extra_names is not None:
+                    extra_names.update(new_extra_names)
+                values.append(new_value)
+            return ast.Dict(keys, values), extra_names
+        elif type(other) in (list, tuple, set):
+            extra_names = {}
+            elts = []
+            for elt in other:
+                new_elt, new_extra_names = self.__convert_to_ast(elt)
+                if new_extra_names is not None:
+                    extra_names.update(new_extra_names)
+                elts.append(new_elt)
+            ast_class = {list: ast.List, tuple: ast.Tuple, set: 
ast.Set}[type(other)]
+            return ast_class(elts), extra_names
+        else:
+            name = self.__stringifier_dict__.create_unique_name()
+            return ast.Name(id=name), {name: other}
+
+    def __convert_to_ast_getitem(self, other):
+        if isinstance(other, slice):
+            extra_names = {}
+
+            def conv(obj):
+                if obj is None:
+                    return None
+                new_obj, new_extra_names = self.__convert_to_ast(obj)
+                if new_extra_names is not None:
+                    extra_names.update(new_extra_names)
+                return new_obj
+
             return ast.Slice(
-                lower=(
-                    self.__convert_to_ast(other.start)
-                    if other.start is not None
-                    else None
-                ),
-                upper=(
-                    self.__convert_to_ast(other.stop)
-                    if other.stop is not None
-                    else None
-                ),
-                step=(
-                    self.__convert_to_ast(other.step)
-                    if other.step is not None
-                    else None
-                ),
-            )
+                lower=conv(other.start),
+                upper=conv(other.stop),
+                step=conv(other.step),
+            ), extra_names
         else:
-            return ast.Constant(value=other)
+            return self.__convert_to_ast(other)
 
     def __get_ast(self):
         node = self.__ast_node__
@@ -321,13 +367,19 @@ def __get_ast(self):
             return ast.Name(id=node)
         return node
 
-    def __make_new(self, node):
+    def __make_new(self, node, extra_names=None):
+        new_extra_names = {}
+        if self.__extra_names__ is not None:
+            new_extra_names.update(self.__extra_names__)
+        if extra_names is not None:
+            new_extra_names.update(extra_names)
         stringifier = _Stringifier(
             node,
             self.__globals__,
             self.__owner__,
             self.__forward_is_class__,
             stringifier_dict=self.__stringifier_dict__,
+            extra_names=new_extra_names or None,
         )
         self.__stringifier_dict__.stringifiers.append(stringifier)
         return stringifier
@@ -343,27 +395,37 @@ def __getitem__(self, other):
         if self.__ast_node__ == "__classdict__":
             raise KeyError
         if isinstance(other, tuple):
-            elts = [self.__convert_to_ast(elt) for elt in other]
+            extra_names = {}
+            elts = []
+            for elt in other:
+                new_elt, new_extra_names = self.__convert_to_ast_getitem(elt)
+                if new_extra_names is not None:
+                    extra_names.update(new_extra_names)
+                elts.append(new_elt)
             other = ast.Tuple(elts)
         else:
-            other = self.__convert_to_ast(other)
+            other, extra_names = self.__convert_to_ast_getitem(other)
         assert isinstance(other, ast.AST), repr(other)
-        return self.__make_new(ast.Subscript(self.__get_ast(), other))
+        return self.__make_new(ast.Subscript(self.__get_ast(), other), 
extra_names)
 
     def __getattr__(self, attr):
         return self.__make_new(ast.Attribute(self.__get_ast(), attr))
 
     def __call__(self, *args, **kwargs):
-        return self.__make_new(
-            ast.Call(
-                self.__get_ast(),
-                [self.__convert_to_ast(arg) for arg in args],
-                [
-                    ast.keyword(key, self.__convert_to_ast(value))
-                    for key, value in kwargs.items()
-                ],
-            )
-        )
+        extra_names = {}
+        ast_args = []
+        for arg in args:
+            new_arg, new_extra_names = self.__convert_to_ast(arg)
+            if new_extra_names is not None:
+                extra_names.update(new_extra_names)
+            ast_args.append(new_arg)
+        ast_kwargs = []
+        for key, value in kwargs.items():
+            new_value, new_extra_names = self.__convert_to_ast(value)
+            if new_extra_names is not None:
+                extra_names.update(new_extra_names)
+            ast_kwargs.append(ast.keyword(key, new_value))
+        return self.__make_new(ast.Call(self.__get_ast(), ast_args, 
ast_kwargs), extra_names)
 
     def __iter__(self):
         yield self.__make_new(ast.Starred(self.__get_ast()))
@@ -378,8 +440,9 @@ def __format__(self, format_spec):
 
     def _make_binop(op: ast.AST):
         def binop(self, other):
+            rhs, extra_names = self.__convert_to_ast(other)
             return self.__make_new(
-                ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other))
+                ast.BinOp(self.__get_ast(), op, rhs), extra_names
             )
 
         return binop
@@ -402,8 +465,9 @@ def binop(self, other):
 
     def _make_rbinop(op: ast.AST):
         def rbinop(self, other):
+            new_other, extra_names = self.__convert_to_ast(other)
             return self.__make_new(
-                ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast())
+                ast.BinOp(new_other, op, self.__get_ast()), extra_names
             )
 
         return rbinop
@@ -426,12 +490,14 @@ def rbinop(self, other):
 
     def _make_compare(op):
         def compare(self, other):
+            rhs, extra_names = self.__convert_to_ast(other)
             return self.__make_new(
                 ast.Compare(
                     left=self.__get_ast(),
                     ops=[op],
-                    comparators=[self.__convert_to_ast(other)],
-                )
+                    comparators=[rhs],
+                ),
+                extra_names,
             )
 
         return compare
@@ -459,13 +525,15 @@ def unary_op(self):
 
 
 class _StringifierDict(dict):
-    def __init__(self, namespace, globals=None, owner=None, is_class=False):
+    def __init__(self, namespace, *, globals=None, owner=None, is_class=False, 
format):
         super().__init__(namespace)
         self.namespace = namespace
         self.globals = globals
         self.owner = owner
         self.is_class = is_class
         self.stringifiers = []
+        self.next_id = 1
+        self.format = format
 
     def __missing__(self, key):
         fwdref = _Stringifier(
@@ -478,6 +546,11 @@ def __missing__(self, key):
         self.stringifiers.append(fwdref)
         return fwdref
 
+    def create_unique_name(self):
+        name = f"__annotationlib_name_{self.next_id}__"
+        self.next_id += 1
+        return name
+
 
 def call_evaluate_function(evaluate, format, *, owner=None):
     """Call an evaluate function. Evaluate functions are normally generated for
@@ -521,7 +594,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
         # possibly constants if the annotate function uses them directly). We 
then
         # convert each of those into a string to get an approximation of the
         # original source.
-        globals = _StringifierDict({})
+        globals = _StringifierDict({}, format=format)
         if annotate.__closure__:
             freevars = annotate.__code__.co_freevars
             new_closure = []
@@ -544,9 +617,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
         )
         annos = func(Format.VALUE_WITH_FAKE_GLOBALS)
         if _is_evaluate:
-            return annos if isinstance(annos, str) else repr(annos)
+            return _stringify_single(annos)
         return {
-            key: val if isinstance(val, str) else repr(val)
+            key: _stringify_single(val)
             for key, val in annos.items()
         }
     elif format == Format.FORWARDREF:
@@ -569,7 +642,13 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
         # that returns a bool and an defined set of attributes.
         namespace = {**annotate.__builtins__, **annotate.__globals__}
         is_class = isinstance(owner, type)
-        globals = _StringifierDict(namespace, annotate.__globals__, owner, 
is_class)
+        globals = _StringifierDict(
+            namespace,
+            globals=annotate.__globals__,
+            owner=owner,
+            is_class=is_class,
+            format=format,
+        )
         if annotate.__closure__:
             freevars = annotate.__code__.co_freevars
             new_closure = []
@@ -619,6 +698,16 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
         raise ValueError(f"Invalid format: {format!r}")
 
 
+def _stringify_single(anno):
+    if anno is ...:
+        return "..."
+    # We have to handle str specially to support PEP 563 stringified 
annotations.
+    elif isinstance(anno, str):
+        return anno
+    else:
+        return repr(anno)
+
+
 def get_annotate_from_class_namespace(obj):
     """Retrieve the annotate function from a class namespace dictionary.
 
diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py
index 404a8ccc9d3741..d9000b6392277e 100644
--- a/Lib/test/test_annotationlib.py
+++ b/Lib/test/test_annotationlib.py
@@ -121,6 +121,28 @@ def f(
         self.assertIsInstance(gamma_anno, ForwardRef)
         self.assertEqual(gamma_anno, support.EqualToForwardRef("some < obj", 
owner=f))
 
+    def test_partially_nonexistent_union(self):
+        # Test unions with '|' syntax equal unions with typing.Union[] with 
some forwardrefs
+        class UnionForwardrefs:
+            pipe: str | undefined
+            union: Union[str, undefined]
+
+        annos = get_annotations(UnionForwardrefs, format=Format.FORWARDREF)
+
+        pipe = annos["pipe"]
+        self.assertIsInstance(pipe, ForwardRef)
+        self.assertEqual(
+            pipe.evaluate(globals={"undefined": int}),
+            str | int,
+        )
+        union = annos["union"]
+        self.assertIsInstance(union, Union)
+        arg1, arg2 = typing.get_args(union)
+        self.assertIs(arg1, str)
+        self.assertEqual(
+            arg2, support.EqualToForwardRef("undefined", is_class=True, 
owner=UnionForwardrefs)
+        )
+
 
 class TestStringFormat(unittest.TestCase):
     def test_closure(self):
@@ -251,6 +273,89 @@ def f(
             },
         )
 
+    def test_getitem(self):
+        def f(x: undef1[str, undef2]):
+            pass
+        anno = annotationlib.get_annotations(f, format=Format.STRING)
+        self.assertEqual(anno, {"x": "undef1[str, undef2]"})
+
+        anno = annotationlib.get_annotations(f, format=Format.FORWARDREF)
+        fwdref = anno["x"]
+        self.assertIsInstance(fwdref, ForwardRef)
+        self.assertEqual(
+            fwdref.evaluate(globals={"undef1": dict, "undef2": float}), 
dict[str, float]
+        )
+
+    def test_slice(self):
+        def f(x: a[b:c]):
+            pass
+        anno = annotationlib.get_annotations(f, format=Format.STRING)
+        self.assertEqual(anno, {"x": "a[b:c]"})
+
+        def f(x: a[b:c, d:e]):
+            pass
+        anno = annotationlib.get_annotations(f, format=Format.STRING)
+        self.assertEqual(anno, {"x": "a[b:c, d:e]"})
+
+        obj = slice(1, 1, 1)
+        def f(x: obj):
+            pass
+        anno = annotationlib.get_annotations(f, format=Format.STRING)
+        self.assertEqual(anno, {"x": "obj"})
+
+    def test_literals(self):
+        def f(
+            a: 1,
+            b: 1.0,
+            c: "hello",
+            d: b"hello",
+            e: True,
+            f: None,
+            g: ...,
+            h: 1j,
+        ):
+            pass
+
+        anno = annotationlib.get_annotations(f, format=Format.STRING)
+        self.assertEqual(
+            anno,
+            {
+                "a": "1",
+                "b": "1.0",
+                "c": 'hello',
+                "d": "b'hello'",
+                "e": "True",
+                "f": "None",
+                "g": "...",
+                "h": "1j",
+            },
+        )
+
+    def test_displays(self):
+        # Simple case first
+        def f(x: a[[int, str], float]):
+            pass
+        anno = annotationlib.get_annotations(f, format=Format.STRING)
+        self.assertEqual(anno, {"x": "a[[int, str], float]"})
+
+        def g(
+            w: a[[int, str], float],
+            x: a[{int, str}, 3],
+            y: a[{int: str}, 4],
+            z: a[(int, str), 5],
+        ):
+            pass
+        anno = annotationlib.get_annotations(g, format=Format.STRING)
+        self.assertEqual(
+            anno,
+            {
+                "w": "a[[int, str], float]",
+                "x": "a[{int, str}, 3]",
+                "y": "a[{int: str}, 4]",
+                "z": "a[(int, str), 5]",
+            },
+        )
+
     def test_nested_expressions(self):
         def f(
             nested: list[Annotated[set[int], "set of ints", 4j]],
@@ -296,6 +401,17 @@ def f(fstring_format: f"{a:02d}"):
         with self.assertRaisesRegex(TypeError, format_msg):
             get_annotations(f, format=Format.STRING)
 
+    def test_shenanigans(self):
+        # In cases like this we can't reconstruct the source; test that we do 
something
+        # halfway reasonable.
+        def f(x: x | (1).__class__, y: (1).__class__):
+            pass
+
+        self.assertEqual(
+            get_annotations(f, format=Format.STRING),
+            {"x": "x | <class 'int'>", "y": "<class 'int'>"},
+        )
+
 
 class TestGetAnnotations(unittest.TestCase):
     def test_builtin_type(self):
diff --git a/Misc/NEWS.d/next/Library/2025-04-22-13-42-12.gh-issue-132805.r-dhmJ.rst b/Misc/NEWS.d/next/Library/2025-04-22-13-42-12.gh-issue-132805.r-dhmJ.rst
new file mode 100644
index 00000000000000..d62b95775a67c2
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2025-04-22-13-42-12.gh-issue-132805.r-dhmJ.rst
@@ -0,0 +1,2 @@
+Fix incorrect handling of nested non-constant values in the FORWARDREF
+format in :mod:`annotationlib`.

_______________________________________________
Python-checkins mailing list -- python-checkins@python.org
To unsubscribe send an email to python-checkins-le...@python.org
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: arch...@mail-archive.com

Reply via email to