This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c5365dc7b [TIR] Fix dtype mismatch in UnifyThreadBinding (#11843)
3c5365dc7b is described below

commit 3c5365dc7b51096f743ece82601f4442796de622
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Jun 22 22:17:28 2022 -0700

    [TIR] Fix dtype mismatch in UnifyThreadBinding (#11843)
    
    This PR fixed a dtype mismatch in UnifyThreadBinding when multiple thread axes with the same thread tag have different dtypes.
---
 src/tir/transforms/unify_thread_binding.cc         |  9 ++---
 .../test_tir_transform_unify_thread_binding.py     | 41 ++++++++++++++++++++++
 2 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc
index 8210079f75..da725f7f8e 100644
--- a/src/tir/transforms/unify_thread_binding.cc
+++ b/src/tir/transforms/unify_thread_binding.cc
@@ -109,8 +109,9 @@ class ThreadBindingUnifier : public StmtExprMutator {
     }
 
     // Step 4. We will substitute the occurrences of the old variable in the old IterVar with the
-    // new variable in further mutation. Thus, we store the mapping entry.
-    var_substitution_map_.Set(old_var, new_iter_var->var);
+    // new variable in further mutation. Thus, we store the mapping entry. Cast to old dtype if
+    // needed (we assume both old and new dtype are valid for the range of the thread extent).
+    var_substitution_map_.Set(old_var, cast(old_var.dtype(), new_iter_var->var));
 
     // Step 5. Mutate recursively, update the body with the new IterVar, and restore the depth
     // counter. Emit for-loops to launch threads if current statement is the outermost thread
@@ -149,7 +150,7 @@ class ThreadBindingUnifier : public StmtExprMutator {
   PrimExpr VisitExpr_(const VarNode* var) final {
     // If this variable appears as a key in `var_substitution_map_`, we substitute it with its
     // corresponding value in the mapping.
-    Map<Var, Var>::iterator it = var_substitution_map_.find(GetRef<Var>(var));
+    Map<Var, PrimExpr>::iterator it = var_substitution_map_.find(GetRef<Var>(var));
     return it != var_substitution_map_.end() ? (*it).second : GetRef<Var>(var);
   }
 
@@ -164,7 +165,7 @@ class ThreadBindingUnifier : public StmtExprMutator {
    */
   Array<IterVar> launch_threads_;
   /*! \brief A mapping from old variables to new variables, which is used for substitution */
-  Map<Var, Var> var_substitution_map_;
+  Map<Var, PrimExpr> var_substitution_map_;
   /*! \brief A integer counter storing the depth of thread bindings of "blockIdx.x/y/z" */
   int thread_block_depth_ = 0;
   /*! \brief An analyzer used for equality proof */
diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
index 457c43a763..90fce22bc1 100644
--- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py
+++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
@@ -72,6 +72,43 @@ def unified_element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None
                     )
 
 
+@T.prim_func
+def element_wise_thread_x_different_dtype(
+    A: T.Buffer[(128, 128), "float32"],
+    B: T.Buffer[(128, 128), "float32"],
+    C: T.Buffer[(128, 128), "float32"],
+) -> None:
+    for i in T.thread_binding(128, "blockIdx.x"):
+        for j0_0 in T.thread_binding(4, "threadIdx.x"):
+            for j0_1 in T.serial(0, 32):
+                with T.block(""):
+                    B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0
+        for j1_0 in T.thread_binding(T.int64(4), "threadIdx.x"):
+            for j1_1 in T.serial(T.int64(32)):
+                with T.block(""):
+                    C[i, j1_0 * T.int64(32) + j1_1] = B[i, j1_0 * T.int64(32) + j1_1] + 1.0
+
+
+@T.prim_func
+def unified_element_wise_thread_x_different_dtype(
+    A: T.Buffer[(128, 128), "float32"],
+    B: T.Buffer[(128, 128), "float32"],
+    C: T.Buffer[(128, 128), "float32"],
+) -> None:
+    for blockIdx_x in T.thread_binding(128, "blockIdx.x"):
+        for threadIdx_x in T.thread_binding(4, "threadIdx.x"):
+            for j0_1 in T.serial(0, 32):
+                with T.block(""):
+                    B[blockIdx_x, threadIdx_x * 32 + j0_1] = (
+                        A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0
+                    )
+            for j1_1 in T.serial(T.int64(32)):
+                with T.block(""):
+                    C[blockIdx_x, T.cast(threadIdx_x, "int64") * T.int64(32) + j1_1] = (
+                        B[blockIdx_x, T.cast(threadIdx_x, "int64") * T.int64(32) + j1_1] + 1.0
+                    )
+
+
 @T.prim_func
 def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
     j1_0 = T.env_thread("threadIdx.x")
@@ -223,6 +260,10 @@ def test_thread_x():
     _check(element_wise_thread_x, unified_element_wise_thread_x)
 
 
+def test_thread_x_different_dtype():
+    _check(element_wise_thread_x_different_dtype, unified_element_wise_thread_x_different_dtype)
+
+
 def test_env_thread_x():
     _check(element_wise_env_thread_x, unified_element_wise_env_thread_x)
 

Reply via email to