This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 177d529 [TVMScript] Enable assignment statement without type annotation (#10736)
177d529 is described below
commit 177d529f0edb2c30cc9e32c48b73e2ce942c4dc9
Author: Masahiro Masuda <[email protected]>
AuthorDate: Thu Mar 24 13:06:53 2022 +0900
[TVMScript] Enable assignment statement without type annotation (#10736)
* Add test
* workaround mypy
* replace assert with condition
---
python/tvm/script/parser.py | 8 ++++++-
tests/python/unittest/test_tvmscript_type.py | 33 ++++++++++++++++++++++++++++
2 files changed, 40 insertions(+), 1 deletion(-)
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 17beb81..3291912 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -585,9 +585,15 @@ class TVMScriptParser(Transformer):
node.span,
)
ast_var = node.lhs[0]
+
+ if node.ty is None and hasattr(value, "dtype"):
+ var_ty = value.dtype
+ else:
+ var_ty = self.parse_type(node.ty, ast_var)
+
var = tvm.te.var(
ast_var.id.name,
- self.parse_type(node.ty, ast_var),
+ var_ty,
span=tvm_span_from_synr(ast_var.span),
)
self.context.update_symbol(var.name, var, node)
diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py
index 12954e3..8f0682c 100644
--- a/tests/python/unittest/test_tvmscript_type.py
+++ b/tests/python/unittest/test_tvmscript_type.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement
from tvm.script import tir as T
+from tvm.script.registry import register
"""
This prim func include necessary buffer types that need to be checked
@@ -177,6 +178,38 @@ def different_access_indices(a: T.handle, b: T.handle) -> None:
B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
+@register
+def int32x16(imm, span):  # registered TVMScript helper: casts imm to the "int32x16" dtype via astype
+ return imm.astype("int32x16", span)
+
+
+# Test assignment statements work without type annotation
+@T.prim_func
+def dot_product_intrin_vnni(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+ B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+ C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+ with T.block("root"):
+ T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+ T.writes(C[0:16])
+
+ A_u8x4 = A.vload([0], "uint8x4") # type: ignore
+ A_i32 = T.reinterpret(A_u8x4, dtype="int32")  # un-annotated: var dtype is taken from the value's dtype
+
+ B_i8x64 = B.vload([0, 0], dtype="int8x64") # type: ignore
+ B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")  # un-annotated: var dtype is taken from the value's dtype
+
+ C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # type: ignore
+ T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
+ T.uint32(0),
+ T.int32x16(0), # type: ignore
+ T.broadcast(A_i32, 16),
+ B_i32x16,
+ dtype="int32x16",
+ )
+
+
# Not running any test as we only want to type-check here
if __name__ == "__main__":
pass