This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6161a8d552 [BugFix][TVMScript]fix var capturing order error (#13640)
6161a8d552 is described below

commit 6161a8d55296fe02c6dbf403af2eeb6b6985ca4a
Author: lightzhan <[email protected]>
AuthorDate: Mon Dec 19 13:31:04 2022 +0800

    [BugFix][TVMScript]fix var capturing order error (#13640)
    
    This PR fixes the following bug:
    
    ```python
    from tvm.script import tir as T
    
    
    def test_var_capturing_order():
        b = 2
    
        @T.prim_func
        def test_case():
            k: T.int32 = b
    
    
    if __name__ == "__main__":
        b = 1
        test_var_capturing_order()
    ```
    
    In the prim func `test_case`, the value of `b` should be 2, not 1.
    The parser wrongly let global variables shadow nonlocal (enclosing-scope)
    variables; the precedence should be reversed.
    
    Co-authored-by: lightzhan-intellif <[email protected]>
---
 python/tvm/script/parser/core/utils.py             |  2 +-
 tests/python/unittest/test_tvmscript_regression.py | 17 +++++++++++++++++
 2 files changed, 18 insertions(+), 1 deletion(-)
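The whole fix is the order of two `**` unpackings inside a dict display: when the same key appears in more than one unpacked mapping, the value from the last one wins. A minimal sketch of that rule, using plain dicts as stand-ins for `func.__globals__` and `inspect.getclosurevars(func).nonlocals` (the names `globals_like` and `nonlocals_like` are illustrative, not from TVM):

```python
# Hypothetical stand-ins for func.__globals__ and
# inspect.getclosurevars(func).nonlocals, mirroring the regression test.
globals_like = {"b": 1}
nonlocals_like = {"b": 2}

# Pre-fix order: globals are unpacked last, so they override nonlocals.
old_captured = {**nonlocals_like, **globals_like}

# Post-fix order: nonlocals are unpacked last, so they take precedence.
new_captured = {**globals_like, **nonlocals_like}

print(old_captured["b"])  # 1
print(new_captured["b"])  # 2
```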

diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py
index a304afddbe..453ac18b38 100644
--- a/python/tvm/script/parser/core/utils.py
+++ b/python/tvm/script/parser/core/utils.py
@@ -37,8 +37,8 @@ def inspect_function_capture(func: Callable) -> Dict[str, Any]:
         The function variables map with non-local or global variables.
     """
     captured = {
-        **inspect.getclosurevars(func).nonlocals,
         **func.__globals__,  # type: ignore
+        **inspect.getclosurevars(func).nonlocals,
     }
     return captured
 
diff --git a/tests/python/unittest/test_tvmscript_regression.py b/tests/python/unittest/test_tvmscript_regression.py
index 05c1665ea2..d063c0fcab 100644
--- a/tests/python/unittest/test_tvmscript_regression.py
+++ b/tests/python/unittest/test_tvmscript_regression.py
@@ -58,7 +58,24 @@ def test_different_dtype_assignment_to_var():
     tvm.ir.assert_structural_equal(test_case, func_ref)
 
 
+def test_var_capturing_order():
+    b = 2
+
+    @T.prim_func
+    def test_case():
+        k: T.int32 = b
+
+    @T.prim_func
+    def func_ref():
+        k: T.int32 = 2
+        T.evaluate(0)
+
+    tvm.ir.assert_structural_equal(test_case, func_ref)
+
+
 if __name__ == "__main__":
     a = numpy.zeros((10, 10), dtype="int8")
     test_multi_element_array_in_outmost_namespace()
     test_different_dtype_assignment_to_var()
+    b = 1
+    test_var_capturing_order()

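The same effect can be reproduced without TVM. The sketch below applies the pre- and post-fix bodies of `inspect_function_capture` to an ordinary Python closure; `capture_old`, `capture_new` and `make_inner` are hypothetical helpers written for illustration, not part of TVM's API. As in the `__main__` block of the regression test, `b = 1` is a module-level global, while the enclosing function binds `b = 2`.

```python
import inspect
from typing import Any, Callable, Dict

b = 1  # module-level (global) binding, like `b = 1` under __main__ in the test


def capture_old(func: Callable) -> Dict[str, Any]:
    """Hypothetical copy of the pre-fix logic: globals shadow nonlocals."""
    return {**inspect.getclosurevars(func).nonlocals, **func.__globals__}


def capture_new(func: Callable) -> Dict[str, Any]:
    """Hypothetical copy of the post-fix logic: nonlocals shadow globals."""
    return {**func.__globals__, **inspect.getclosurevars(func).nonlocals}


def make_inner() -> Callable[[], int]:
    b = 2  # enclosing-function binding captured by inner()

    def inner() -> int:
        return b  # b is a free (nonlocal) variable of inner

    return inner


inner = make_inner()
print(inner())                  # 2: Python resolves b to the enclosing binding
print(capture_old(inner)["b"])  # 1: old merge order picks the global
print(capture_new(inner)["b"])  # 2: new merge order agrees with Python
```

The post-fix order matches Python's own name resolution (local, enclosing, global, built-in), in which an enclosing-function binding shadows a module-level one.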