This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4096548d13 [BugFix][TVMScript] Parser crash (#13630)
4096548d13 is described below

commit 4096548d13cc8add8fe1f89d54f0968f89570461
Author: lightzhan <[email protected]>
AuthorDate: Sun Dec 18 09:44:49 2022 +0800

    [BugFix][TVMScript] Parser crash (#13630)
    
    This PR fixes a parser crash that occurs when the old value of a variable
    is an array but the new value is not. For example:
    
    ```python
    import numpy as np
    from tvm.script import tir as T
    
    
    def func_wrapper(shape, dtype):
        @T.prim_func
        def test_case():
            # Inside the prim_func, `a` is bound to a Buffer.
            a = T.alloc_buffer(shape, dtype=dtype)
    
        return test_case
    
    
    if __name__ == "__main__":
        # In the global scope, `a` is bound to a NumPy array.
        a = np.zeros((10, 10), dtype="int8")
        print(func_wrapper((256, 256), dtype="int8").script())
    ```
    
    In the code above, the variable 'a' is assigned twice: in the global scope
    its value is a NumPy array, but inside the prim function it is a Buffer.
    The parser tracks the values of variables such as 'a' in a table named
    'name2value'. When it updates a variable's value, it first compares the new
    value with the old one and skips the update if they are equal. This is
    where the problem arises: comparing an array with another value using '=='
    yields an array, which cannot be used directly as the condition of an if
    statement. As a result, the code above emits the following error:
    
    ```shell
    error: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
     --> /workspace/code_newest/tvm/private_test/test_meta_programming.py:16:9
        |
     16 |          a = T.alloc_buffer(shape, dtype=dtype)
        |          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ```
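The failure comes from NumPy's comparison semantics rather than from TVM itself. The following minimal sketch uses plain NumPy (no TVM) to reproduce the ambiguous truth value error shown above and the `.all()` reduction that avoids it; `BufferStandIn` is a hypothetical placeholder for the Buffer, not a TVM class.

```python
import numpy as np


class BufferStandIn:
    """Hypothetical placeholder for the Buffer bound to `a` in the prim_func."""


# The global `a` from the example above.
old_value = np.zeros((10, 10), dtype="int8")
new_value = BufferStandIn()

# With an ndarray on the left, `==` is evaluated element-wise, so the
# result is a (10, 10) boolean array rather than a single bool.
result = old_value == new_value

# Using that array directly as an `if` condition raises ValueError.
raised = False
try:
    if result:
        pass
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Reducing the element-wise result with .all() yields one truth value.
print("all equal:", bool(result.all()))
```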
    
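    This PR fixes this by comparing the old and new values only when they have
    the same type, and by reducing the element-wise comparison with .all() when
    the new value is a NumPy array.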
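The equality check can be restated as a standalone predicate. This is a minimal sketch for illustration, not the code in the diff: `same_as_last` and `BufferStandIn` are hypothetical names, `name2value` is modelled as a plain `defaultdict(list)` instead of the real `VarTable`, and arrays are compared with `np.array_equal`, which returns False for mismatched shapes instead of broadcasting.

```python
from collections import defaultdict

import numpy as np


class BufferStandIn:
    """Hypothetical stand-in for the Buffer bound to `a` in the prim_func."""


def same_as_last(name2value, var, value):
    """Return True if `value` equals the most recent value recorded for `var`."""
    if not name2value[var]:
        # Nothing recorded yet, so there is nothing to compare against.
        return False
    old = name2value[var][-1]
    if not isinstance(old, type(value)):
        # Different types (e.g. ndarray vs. Buffer): treat as unequal and
        # never evaluate `==`, so no element-wise array is produced.
        return False
    if isinstance(value, np.ndarray):
        # Collapse the array comparison into a single bool.
        return bool(np.array_equal(old, value))
    return bool(old == value)


name2value = defaultdict(list)
name2value["a"].append(np.zeros((10, 10), dtype="int8"))

print(same_as_last(name2value, "a", BufferStandIn()))                    # False
print(same_as_last(name2value, "a", np.zeros((10, 10), dtype="int8")))  # True
print(same_as_last(name2value, "a", np.ones((10, 10), dtype="int8")))   # False
```

Because the sketch returns the array comparison result directly, two same-shaped arrays with different contents yield False. In the diff as committed, that case falls through to the `elif` branch, where a bare `==` on two multi-element arrays would raise the same ValueError.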
    
    Co-authored-by: lightzhan-intellif <[email protected]>
---
 python/tvm/script/parser/core/parser.py            |  8 ++++++--
 tests/python/unittest/test_tvmscript_regression.py | 15 +++++++++++++++
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
index c6d43f11cb..7c699c42ae 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -19,6 +19,7 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Set, Union
+import numpy as np
 from tvm._ffi.base import TVMError
 
 from tvm.error import DiagnosticError
@@ -150,8 +151,11 @@ class VarTable:
             The options of whether variable shadowing allwed for this variable.
         """
         # Skip if the key and value are equal to those in the var_table
-        if self.name2value[var] and self.name2value[var][-1] == value:
-            return
+        if self.name2value[var] and isinstance(self.name2value[var][-1], type(value)):
+            if isinstance(value, np.ndarray) and (self.name2value[var][-1] == value).all():
+                return
+            elif self.name2value[var][-1] == value:
+                return
         if allow_shadowing and var in self.frames[-1].vars:
             # Shadowing
             self.name2value[var][-1] = value
diff --git a/tests/python/unittest/test_tvmscript_regression.py b/tests/python/unittest/test_tvmscript_regression.py
index 3ad8090893..05c1665ea2 100644
--- a/tests/python/unittest/test_tvmscript_regression.py
+++ b/tests/python/unittest/test_tvmscript_regression.py
@@ -45,5 +45,20 @@ def test_multi_element_array_in_outmost_namespace():
     tvm.ir.assert_structural_equal(func, rt_func)
 
 
+def test_different_dtype_assignment_to_var():
+    @T.prim_func
+    def test_case():
+        a = T.alloc_buffer((10, 10), dtype="int8")
+
+    @T.prim_func
+    def func_ref():
+        a = T.alloc_buffer([10, 10], dtype="int8")
+        T.evaluate(0)
+
+    tvm.ir.assert_structural_equal(test_case, func_ref)
+
+
 if __name__ == "__main__":
+    a = numpy.zeros((10, 10), dtype="int8")
     test_multi_element_array_in_outmost_namespace()
+    test_different_dtype_assignment_to_var()