This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4f8c03fad3 [TVMScript] Support `T.launch_thread` with i64 dtype (#16916)
4f8c03fad3 is described below

commit 4f8c03fad393c360008f1fb208f117c66c04090c
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Apr 24 20:44:46 2024 +0800

    [TVMScript] Support `T.launch_thread` with i64 dtype (#16916)
    
    This PR fixes the bug of mismatched dtype in `T.launch_thread` when the dtype is `i64`.
---
 include/tvm/script/ir_builder/tir/ir.h                    |  3 ++-
 python/tvm/script/ir_builder/tir/ir.py                    |  7 +++++--
 src/script/ir_builder/tir/ir.cc                           | 10 +++++-----
 .../test_tir_transform_inject_ptx_async_copy.py           |  4 ++--
 tests/python/tvmscript/test_tvmscript_parser_tir.py       | 15 +++++++++++++++
 5 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index c4ba44f673..5b44f79ad7 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -401,9 +401,10 @@ LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent);
 /*!
  * \brief Bind a var to thread env.
  * \param thread_tag The thread type tag.
+ * \param dtype The data type of the variable.
  * \return The result variable which gets bound to the thread env.
  */
-Var EnvThread(String thread_tag);
+Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32));
 
 /*!
  * \brief Store data in a buffer.
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 127d2a4356..c04ac780c9 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1241,7 +1241,7 @@ def launch_thread(
     return _ffi_api.LaunchThread(thread, extent)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
-def env_thread(thread_tag: str) -> IterVar:
+def env_thread(thread_tag: str, dtype: str = "int32") -> IterVar:
     """Bind a var to thread env
 
     Parameters
@@ -1249,13 +1249,16 @@ def env_thread(thread_tag: str) -> IterVar:
     thread_tag : str
         The thread type tag.
 
+    dtype : str
+        The data type of the thread env.
+
     Returns
     -------
     res : IterVar
         The result iteration variable gets bound to the thread env.
 
     """
-    return _ffi_api.EnvThread(thread_tag)  # type: ignore[attr-defined] # pylint: disable=no-member
+    return _ffi_api.EnvThread(thread_tag, dtype)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
 def buffer_store(
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index ccb5a8b57b..3ce5c15e6c 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -432,7 +432,8 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
   }
   ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>();
   if (!iter_var->dom.defined()) {
-    const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent);
+    const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom =
+        Range(tvm::tir::make_zero(extent.dtype()), extent);
   } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
     LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
                << iter_var->dom->extent << " vs " << extent;
@@ -444,7 +445,7 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
 }
 
 LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) {
-  return LaunchThread(EnvThread(thread_tag), extent);
+  return LaunchThread(EnvThread(thread_tag, extent.dtype()), extent);
 }
 
 RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
@@ -512,9 +513,8 @@ ElseFrame Else() {
   return ElseFrame(n);
 }
 
-Var EnvThread(String thread_tag) {
-  IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex,
-                   thread_tag);
+Var EnvThread(String thread_tag, DataType dtype) {
+  IterVar iter_var(Range{nullptr}, Var("", dtype), tvm::tir::IterVarType::kThreadIndex, thread_tag);
   Var var = iter_var->var;
   if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
     opt_frame.value()->env_threads.Set(var, iter_var);
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index 4c94dc04cc..c160e4a31d 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -969,9 +969,9 @@ class TestMultiplicationNodesAreInligned(tvm.testing.CompareBeforeAfter):
             T.ptx_cp_async(
                 "float16",
                 A_shared.data,
-                T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8),
+                tx * T.int64(128) + cse_var_1 * T.int64(8),
                 A.data,
-                T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8),
+                tx * T.int64(128) + cse_var_1 * T.int64(8),
                 16,
             )
         T.ptx_commit_group()
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index 530746a6fc..25a904a157 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -471,5 +471,20 @@ def test_reinterpret_nop():
     tvm.ir.assert_structural_equal(func, expected)
 
 
+def test_launch_thread_i64():
+    """Test launching thread with int64"""
+
+    @T.prim_func
+    def func() -> None:
+        blockIdx_x = T.launch_thread("blockIdx.x", T.int64(1))
+        if blockIdx_x == T.int64(0):
+            T.evaluate(T.int64(0))
+        else:
+            T.evaluate(T.int64(1))
+
+    assert func.body.node.dom.min.dtype == "int64"
+    assert func.body.node.dom.extent.dtype == "int64"
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to