This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 177d529  [TVMScript] Enable assignment statement without type annotation  (#10736)
177d529 is described below

commit 177d529f0edb2c30cc9e32c48b73e2ce942c4dc9
Author: Masahiro Masuda <[email protected]>
AuthorDate: Thu Mar 24 13:06:53 2022 +0900

    [TVMScript] Enable assignment statement without type annotation  (#10736)
    
    * Add test
    
    * workaround mypy
    
    * replace assert with condition
---
 python/tvm/script/parser.py                  |  8 ++++++-
 tests/python/unittest/test_tvmscript_type.py | 33 ++++++++++++++++++++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 17beb81..3291912 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -585,9 +585,15 @@ class TVMScriptParser(Transformer):
                         node.span,
                     )
                 ast_var = node.lhs[0]
+
+                if node.ty is None and hasattr(value, "dtype"):
+                    var_ty = value.dtype
+                else:
+                    var_ty = self.parse_type(node.ty, ast_var)
+
                 var = tvm.te.var(
                     ast_var.id.name,
-                    self.parse_type(node.ty, ast_var),
+                    var_ty,
                     span=tvm_span_from_synr(ast_var.span),
                 )
                 self.context.update_symbol(var.name, var, node)
diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py
index 12954e3..8f0682c 100644
--- a/tests/python/unittest/test_tvmscript_type.py
+++ b/tests/python/unittest/test_tvmscript_type.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement
 from tvm.script import tir as T
+from tvm.script.registry import register
 
 """
 This prim func include necessary buffer types that need to be checked
@@ -177,6 +178,38 @@ def different_access_indices(a: T.handle, b: T.handle) -> None:
                 B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
 
 
+@register
+def int32x16(imm, span):
+    return imm.astype("int32x16", span)
+
+
+# Test assignment statements work without type annotation
+@T.prim_func
+def dot_product_intrin_vnni(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+
+        A_u8x4 = A.vload([0], "uint8x4")  # type: ignore
+        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+        B_i8x64 = B.vload([0, 0], dtype="int8x64")  # type: ignore
+        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
+
+        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # type: ignore
+            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
+            T.uint32(0),
+            T.int32x16(0),  # type: ignore
+            T.broadcast(A_i32, 16),
+            B_i32x16,
+            dtype="int32x16",
+        )
+
+
 # Not running any test as we only want to type-check here
 if __name__ == "__main__":
     pass

Reply via email to