This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6161a8d552 [BugFix][TVMScript]fix var capturing order error (#13640)
6161a8d552 is described below
commit 6161a8d55296fe02c6dbf403af2eeb6b6985ca4a
Author: lightzhan <[email protected]>
AuthorDate: Mon Dec 19 13:31:04 2022 +0800
[BugFix][TVMScript]fix var capturing order error (#13640)
This PR fixes the following bug:
```python
from tvm.script import tir as T


def test_var_capturing_order():
    b = 2

    @T.prim_func
    def test_case():
        k: T.int32 = b


if __name__ == "__main__":
    b = 1
    test_var_capturing_order()
```
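In the prim func `test_case`, the value of `b` should be 2, not 1. The
assignment `b = 1` inside the `if __name__ == "__main__":` block creates a
module-level global with the same name, and the parser wrongly lets that global
shadow the nonlocal `b`. The precedence should be the reverse, matching Python's
own scoping rules, where an enclosing-function binding wins over a global one.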
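The root cause is the precedence rule of dict unpacking: in a literal such as `{**a, **b}`, a key present in both mappings takes its value from the mapping unpacked last. The pure-Python sketch below does not depend on TVM; it only mimics the two orderings used by `inspect_function_capture` before and after this fix. The helper names `capture_old`, `capture_new` and `make_closure` are hypothetical and exist only for illustration.

```python
import inspect
from typing import Any, Callable, Dict


def capture_old(func: Callable) -> Dict[str, Any]:
    # Hypothetical helper mimicking the pre-fix order: globals are unpacked
    # last, so a global overrides a nonlocal with the same name.
    return {
        **inspect.getclosurevars(func).nonlocals,
        **func.__globals__,
    }


def capture_new(func: Callable) -> Dict[str, Any]:
    # Hypothetical helper mimicking the post-fix order: nonlocals are unpacked
    # last, so an enclosing-scope binding overrides a global one.
    return {
        **func.__globals__,
        **inspect.getclosurevars(func).nonlocals,
    }


def make_closure() -> Callable[[], int]:
    b = 2  # enclosing-scope (nonlocal) binding

    def inner() -> int:
        return b  # Python resolves this to the enclosing b, i.e. 2

    return inner


b = 1  # module-level (global) binding with the same name

if __name__ == "__main__":
    f = make_closure()
    print(f())                  # 2
    print(capture_old(f)["b"])  # 1
    print(capture_new(f)["b"])  # 2
```

Since `inner()` itself returns 2, only the post-fix ordering agrees with how Python resolves the name at run time.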
Co-authored-by: lightzhan-intellif <[email protected]>
---
python/tvm/script/parser/core/utils.py | 2 +-
tests/python/unittest/test_tvmscript_regression.py | 17 +++++++++++++++++
2 files changed, 18 insertions(+), 1 deletion(-)
diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py
index a304afddbe..453ac18b38 100644
--- a/python/tvm/script/parser/core/utils.py
+++ b/python/tvm/script/parser/core/utils.py
@@ -37,8 +37,8 @@ def inspect_function_capture(func: Callable) -> Dict[str, Any]:
         The function variables map with non-local or global variables.
     """
     captured = {
-        **inspect.getclosurevars(func).nonlocals,
         **func.__globals__,  # type: ignore
+        **inspect.getclosurevars(func).nonlocals,
     }
     return captured
diff --git a/tests/python/unittest/test_tvmscript_regression.py b/tests/python/unittest/test_tvmscript_regression.py
index 05c1665ea2..d063c0fcab 100644
--- a/tests/python/unittest/test_tvmscript_regression.py
+++ b/tests/python/unittest/test_tvmscript_regression.py
@@ -58,7 +58,24 @@ def test_different_dtype_assignment_to_var():
     tvm.ir.assert_structural_equal(test_case, func_ref)
 
 
+def test_var_capturing_order():
+    b = 2
+
+    @T.prim_func
+    def test_case():
+        k: T.int32 = b
+
+    @T.prim_func
+    def func_ref():
+        k: T.int32 = 2
+        T.evaluate(0)
+
+    tvm.ir.assert_structural_equal(test_case, func_ref)
+
+
 if __name__ == "__main__":
     a = numpy.zeros((10, 10), dtype="int8")
     test_multi_element_array_in_outmost_namespace()
     test_different_dtype_assignment_to_var()
+    b = 1
+    test_var_capturing_order()