This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 61f80814e6 [TVMScript] Fix PEP 563 closure variable resolution (#18856)
61f80814e6 is described below

commit 61f80814e655c97be48e47cd19b09ea5a8636f4b
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Feb 28 22:38:50 2026 -0500

    [TVMScript] Fix PEP 563 closure variable resolution (#18856)
    
    With `from __future__ import annotations`, Python stores annotations as
    strings and does not capture annotation-only variables in `__closure__`.
    This broke TVMScript when buffer shapes/dtypes referenced closure variables.
    
    Fix: add `utils.resolve_closure_vars`. It runs only when the decorated
    object's annotations are strings (PEP 563 active). It parses the source
    AST for names used in function annotations. Each such name that is
    missing from `extra_vars` is then looked up in the locals of live caller
    frames (from `inspect.stack()`) that belong to lexically enclosing scopes,
    as identified via `__qualname__`. It is called from both `tir/entry.py`
    (`prim_func`) and `ir/entry.py` (`ir_module`). The `ir_module` function
    now also captures `outer_stack = inspect.stack()` at its entry point,
    mirroring the existing pattern in `prim_func`. Existing behavior for
    non-PEP-563 code is preserved.
    
    Add `tests/python/tvmscript/test_tvmscript_pep563_closure.py` (requires
    `from __future__ import annotations` at the top) covering closure variables
    in buffer shapes, dtypes, nested scopes, ir_module, and mixed
    annotation+body use.
---
 python/tvm/script/parser/core/utils.py             |  82 +++++++++++
 python/tvm/script/parser/ir/entry.py               |   8 +-
 python/tvm/script/parser/tir/entry.py              |   4 +-
 .../tvmscript/test_tvmscript_pep563_closure.py     | 158 +++++++++++++++++++++
 4 files changed, 250 insertions(+), 2 deletions(-)

diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py
index 85190b96d9..fc8a928e05 100644
--- a/python/tvm/script/parser/core/utils.py
+++ b/python/tvm/script/parser/core/utils.py
@@ -89,6 +89,88 @@ def inspect_class_capture(cls: type) -> dict[str, Any]:
     return result
 
 
+def _collect_annotation_names(source_obj: type | Callable) -> set[str]:
+    """Parse source AST to find names used in function annotations.
+
+    Returns the set of ``ast.Name`` identifiers found inside the argument
+    annotations (positional-only, regular, ``*args``, keyword-only and
+    ``**kwargs``) and the return annotations of every function definition
+    in *source_obj*, including nested definitions.  Attribute chains such as
+    ``T.Buffer`` contribute only their root name (``T``).
+
+    This is best-effort: if the source cannot be retrieved or parsed, an
+    empty set is returned so callers keep their pre-existing behavior.
+    """
+    import ast
+    import textwrap
+
+    try:
+        source = textwrap.dedent(inspect.getsource(source_obj))
+        tree = ast.parse(source)
+    except (OSError, TypeError, SyntaxError):
+        # OSError/TypeError: no retrievable source (REPL, exec, builtins).
+        # SyntaxError (incl. IndentationError): dedent could not normalize the
+        # snippet, e.g. a multi-line string literal indented less than the def.
+        return set()
+
+    names: set[str] = set()
+    for node in ast.walk(tree):
+        if not isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
+            continue
+        args = node.args
+        params = [*args.posonlyargs, *args.args, *args.kwonlyargs]
+        # ``*args`` / ``**kwargs`` are stored separately and may be absent.
+        params += [p for p in (args.vararg, args.kwarg) if p is not None]
+        annotations = [p.annotation for p in params]
+        annotations.append(node.returns)
+        for annotation in annotations:
+            if annotation is None:
+                continue
+            names.update(n.id for n in ast.walk(annotation) if isinstance(n, ast.Name))
+    return names
+
+
+def _has_string_annotations(source_obj: type | Callable) -> bool:
+    """Return True if any annotation reachable from *source_obj* is a string.
+
+    For a class, the annotations of the plain functions stored directly in
+    its ``__dict__`` are inspected; for any other object, its own
+    ``__annotations__`` (when present) are inspected.  String annotations are
+    what ``from __future__ import annotations`` (PEP 563) produces.
+
+    NOTE(review): on Python 3.14+ without the ``__future__`` import,
+    annotations are evaluated lazily (PEP 649) rather than stringified, so
+    this returns False there even though annotation-only names may still be
+    absent from ``__closure__`` -- confirm whether that case needs handling.
+    """
+    if inspect.isclass(source_obj):
+        # Generator keeps evaluation lazy: stop at the first string found.
+        annotation_dicts = (
+            member.__annotations__
+            for member in vars(source_obj).values()
+            if inspect.isfunction(member)
+        )
+    else:
+        annotation_dicts = [getattr(source_obj, "__annotations__", {})]
+    for annotations in annotation_dicts:
+        for value in annotations.values():
+            if isinstance(value, str):
+                return True
+    return False
+
+
+def _get_enclosing_scope_names(qualname: str) -> set[str]:
+    """Return the names of the scopes that lexically enclose *qualname*.
+
+    Every dotted component except the last (the object's own name) is kept,
+    with the ``<locals>`` markers removed; class names in the path are kept
+    too.  Example: ``outer.<locals>.inner.<locals>.func`` gives
+    ``{"outer", "inner"}``.
+    """
+    *scopes, _own_name = qualname.split(".")
+    return set(scopes) - {"<locals>"}
+
+
+def resolve_closure_vars(
+    source_obj: type | Callable, extra_vars: dict[str, Any], outer_stack: list
+) -> None:
+    """Resolve closure variables hidden by PEP 563.
+
+    With ``from __future__ import annotations``, variables used only in
+    annotations are not captured in ``__closure__``.  This function parses
+    the source AST to find names used in function annotations, then looks
+    them up in the locals of live frames belonging to lexically enclosing
+    scopes, identified via ``__qualname__``.
+
+    Only triggered when annotations are actually strings (PEP 563 active).
+    Only annotation-referenced names not already in *extra_vars* are added,
+    and only from enclosing scopes -- not from arbitrary caller frames.
+    Frames are searched innermost-first, so the nearest enclosing definition
+    of a name wins.
+
+    Works for both classes (``@I.ir_module``) and functions (``@T.prim_func``).
+
+    Parameters
+    ----------
+    source_obj : type | Callable
+        The decorated class or function.
+    extra_vars : dict[str, Any]
+        Variables already captured for the parser; updated in place.
+    outer_stack : list
+        The ``inspect.stack()`` captured by the decorator entry point.  Its
+        first element (the entry point's own frame) is skipped.
+    """
+    if not _has_string_annotations(source_obj):
+        return
+    missing = {name for name in _collect_annotation_names(source_obj) if name not in extra_vars}
+    if not missing:
+        return
+
+    qualname = source_obj.__qualname__
+    enclosing_names = _get_enclosing_scope_names(qualname)
+    # Qualified names of the enclosing *functions*: every prefix followed by a
+    # ``<locals>`` marker, e.g. ``a.<locals>.b.<locals>.f`` -> {"a", "a.<locals>.b"}.
+    # Compared against ``co_qualname`` (Python 3.11+) so that unrelated frames
+    # which merely share a function name are not consulted.
+    parts = qualname.split(".")
+    enclosing_qualnames = {
+        ".".join(parts[:i]) for i in range(1, len(parts)) if parts[i] == "<locals>"
+    }
+
+    for frame_info in outer_stack[1:]:
+        frame = frame_info.frame
+        code = frame.f_code
+        co_qualname = getattr(code, "co_qualname", None)
+        if co_qualname is not None:
+            if co_qualname not in enclosing_qualnames:
+                continue
+        elif code.co_name not in enclosing_names:
+            # Python < 3.11: fall back to matching the bare function name.
+            continue
+        # Reading ``f_locals`` re-synchronizes fast locals (or builds a new
+        # proxy on 3.13+), so fetch it once per frame.
+        frame_locals = frame.f_locals
+        for name in [n for n in missing if n in frame_locals]:
+            extra_vars[name] = frame_locals[name]
+            missing.discard(name)
+        if not missing:
+            break
+
+
 def is_defined_in_class(frames: list[FrameType], obj: Any) -> bool:
     """Check whether a object is defined in a class scope.
 
diff --git a/python/tvm/script/parser/ir/entry.py b/python/tvm/script/parser/ir/entry.py
index 8f7a5be663..b0685e3db0 100644
--- a/python/tvm/script/parser/ir/entry.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -46,6 +46,9 @@ def ir_module(mod: type | None = None, check_well_formed: bool = True) -> IRModu
         The parsed ir module.
     """
 
+    # Capture stack outside wrapper (wrapper adds to the stack)
+    outer_stack = inspect.stack()
+
     def decorator_wrapper(mod):
         if not inspect.isclass(mod):
             raise TypeError(f"Expect a class, but got: {mod}")
@@ -53,7 +56,10 @@ def ir_module(mod: type | None = None, check_well_formed: bool = True) -> IRModu
         # Check BasePyModule inheritance
         base_py_module_inherited = any(base.__name__ == "BasePyModule" for 
base in mod.__bases__)
 
-        m = parse(mod, utils.inspect_class_capture(mod), 
check_well_formed=check_well_formed)
+        extra_vars = utils.inspect_class_capture(mod)
+        # Resolve closure variables hidden by PEP 563 (annotation-only names)
+        utils.resolve_closure_vars(mod, extra_vars, outer_stack)
+        m = parse(mod, extra_vars, check_well_formed=check_well_formed)
 
         if base_py_module_inherited:
             # Lazy import: tvm.relax cannot be imported at module level in 
tvm.script.parser
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
index da09851e67..d0486b0d9f 100644
--- a/python/tvm/script/parser/tir/entry.py
+++ b/python/tvm/script/parser/tir/entry.py
@@ -63,7 +63,9 @@ def prim_func(
             raise TypeError(f"Expect a function, but got: {func}")
         if utils.is_defined_in_class(outer_stack, func):
             return func
-        f = parse(func, utils.inspect_function_capture(func), 
check_well_formed=check_well_formed)
+        extra_vars = utils.inspect_function_capture(func)
+        utils.resolve_closure_vars(func, extra_vars, outer_stack)
+        f = parse(func, extra_vars, check_well_formed=check_well_formed)
         setattr(f, "__name__", func.__name__)
         return f
 
diff --git a/tests/python/tvmscript/test_tvmscript_pep563_closure.py b/tests/python/tvmscript/test_tvmscript_pep563_closure.py
new file mode 100644
index 0000000000..a5d26d7f16
--- /dev/null
+++ b/tests/python/tvmscript/test_tvmscript_pep563_closure.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test TVMScript with PEP 563 (from __future__ import annotations).
+
+IMPORTANT: The `from __future__ import annotations` import below is the
+test condition itself, because we need to test compatibility with it.
+"""
+
+from __future__ import annotations
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I
+from tvm.script import tir as T
+
+
+def _normalize(func):
+    """Return *func* with its ``global_symbol`` attribute set to ``""``.
+
+    The functions under test and the expected functions deliberately have
+    different names, so the symbol is blanked before structural comparison.
+    """
+    blank_symbol = ""
+    return func.with_attr("global_symbol", blank_symbol)
+
+
+def test_prim_func_closure_shape():
+    """A closure variable referenced only in a Buffer shape annotation is resolved."""
+
+    def make(M=16):
+        @T.prim_func
+        def func(A: T.Buffer((M,), "float32")):
+            T.evaluate(0)
+
+        return func
+
+    @T.prim_func
+    def expected_16(A: T.Buffer((16,), "float32")):
+        T.evaluate(0)
+
+    @T.prim_func
+    def expected_32(A: T.Buffer((32,), "float32")):
+        T.evaluate(0)
+
+    for size, expected in ((16, expected_16), (32, expected_32)):
+        tvm.ir.assert_structural_equal(_normalize(make(size)), _normalize(expected))
+
+
+def test_prim_func_closure_dtype():
+    """A closure variable used as the Buffer dtype annotation is resolved."""
+
+    def make(dtype="float32"):
+        @T.prim_func
+        def func(A: T.Buffer((16,), dtype)):
+            T.evaluate(0)
+
+        return func
+
+    @T.prim_func
+    def expected_f32(A: T.Buffer((16,), "float32")):
+        T.evaluate(0)
+
+    @T.prim_func
+    def expected_f16(A: T.Buffer((16,), "float16")):
+        T.evaluate(0)
+
+    # The loop variable is deliberately not named ``dtype`` so this frame
+    # cannot accidentally supply the value being resolved.
+    cases = {"float32": expected_f32, "float16": expected_f16}
+    for requested, expected in cases.items():
+        tvm.ir.assert_structural_equal(_normalize(make(requested)), _normalize(expected))
+
+
+def test_prim_func_nested_closure():
+    """Variables from enclosing scopes that are still active on the call stack.
+
+    With PEP 563, variables referenced only in annotations are missing from
+    __closure__, so they are looked up in the locals of live frames whose
+    function lexically encloses the decorated one (per __qualname__).  Here
+    ``N`` comes from the parent frame (``middle``) and ``M`` from the
+    grandparent frame (``outer``).  ``outer`` calls ``middle``, which applies
+    the decorator, so both frames are still on the stack at that point.
+    """
+
+    def outer(M=16):
+        def middle(N=8):
+            @T.prim_func
+            def func(A: T.Buffer((M, N), "float32")):
+                T.evaluate(0)
+
+            return func
+
+        # Calling middle() here keeps outer's frame (and M) on the stack
+        # while the decorator runs.
+        return middle()
+
+    @T.prim_func
+    def expected_16_8(A: T.Buffer((16, 8), "float32")):
+        T.evaluate(0)
+
+    @T.prim_func
+    def expected_32_8(A: T.Buffer((32, 8), "float32")):
+        T.evaluate(0)
+
+    tvm.ir.assert_structural_equal(_normalize(outer(16)), _normalize(expected_16_8))
+    tvm.ir.assert_structural_equal(_normalize(outer(32)), _normalize(expected_32_8))
+
+
+def test_ir_module_closure():
+    """A closure variable used in a method annotation of an @I.ir_module class."""
+
+    def make(M=16):
+        @I.ir_module
+        class Mod:
+            @T.prim_func
+            def main(A: T.Buffer((M,), "float32")):
+                T.evaluate(0)
+
+        return Mod
+
+    @T.prim_func
+    def expected_16(A: T.Buffer((16,), "float32")):
+        T.evaluate(0)
+
+    @T.prim_func
+    def expected_32(A: T.Buffer((32,), "float32")):
+        T.evaluate(0)
+
+    for size, expected in ((16, expected_16), (32, expected_32)):
+        main_func = make(size)["main"]
+        tvm.ir.assert_structural_equal(_normalize(main_func), _normalize(expected))
+
+
+def test_mixed_closure_usage():
+    """Regression check: a closure variable used in both the annotation and the body."""
+
+    def make(M=16):
+        @T.prim_func
+        def func(A: T.Buffer((M,), "float32")):
+            T.evaluate(M)
+
+        return func
+
+    @T.prim_func
+    def expected_16(A: T.Buffer((16,), "float32")):
+        T.evaluate(16)
+
+    @T.prim_func
+    def expected_32(A: T.Buffer((32,), "float32")):
+        T.evaluate(32)
+
+    for size, expected in ((16, expected_16), (32, expected_32)):
+        tvm.ir.assert_structural_equal(_normalize(make(size)), _normalize(expected))
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()

Reply via email to