This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3c5365dc7b [TIR] Fix dtype mismatch in UnifyThreadBinding (#11843)
3c5365dc7b is described below
commit 3c5365dc7b51096f743ece82601f4442796de622
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Jun 22 22:17:28 2022 -0700
[TIR] Fix dtype mismatch in UnifyThreadBinding (#11843)
This PR fixes a dtype mismatch in UnifyThreadBinding that occurs when multiple
thread axes with the same thread tag have different dtypes.
---
src/tir/transforms/unify_thread_binding.cc | 9 ++---
.../test_tir_transform_unify_thread_binding.py | 41 ++++++++++++++++++++++
2 files changed, 46 insertions(+), 4 deletions(-)
diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc
index 8210079f75..da725f7f8e 100644
--- a/src/tir/transforms/unify_thread_binding.cc
+++ b/src/tir/transforms/unify_thread_binding.cc
@@ -109,8 +109,9 @@ class ThreadBindingUnifier : public StmtExprMutator {
}
// Step 4. We will substitute the occurrences of the old variable in the old IterVar with the
- // new variable in further mutation. Thus, we store the mapping entry.
- var_substitution_map_.Set(old_var, new_iter_var->var);
+ // new variable in further mutation. Thus, we store the mapping entry. Cast to old dtype if
+ // needed (we assume both old and new dtype are valid for the range of the thread extent).
+ var_substitution_map_.Set(old_var, cast(old_var.dtype(), new_iter_var->var));
// Step 5. Mutate recursively, update the body with the new IterVar, and restore the depth
// counter. Emit for-loops to launch threads if current statement is the outermost thread
@@ -149,7 +150,7 @@ class ThreadBindingUnifier : public StmtExprMutator {
PrimExpr VisitExpr_(const VarNode* var) final {
// If this variable appears as a key in `var_substitution_map_`, we substitute it with its
// corresponding value in the mapping.
- Map<Var, Var>::iterator it = var_substitution_map_.find(GetRef<Var>(var));
+ Map<Var, PrimExpr>::iterator it = var_substitution_map_.find(GetRef<Var>(var));
return it != var_substitution_map_.end() ? (*it).second : GetRef<Var>(var);
}
@@ -164,7 +165,7 @@ class ThreadBindingUnifier : public StmtExprMutator {
*/
Array<IterVar> launch_threads_;
/*! \brief A mapping from old variables to new variables, which is used for substitution */
- Map<Var, Var> var_substitution_map_;
+ Map<Var, PrimExpr> var_substitution_map_;
/*! \brief A integer counter storing the depth of thread bindings of "blockIdx.x/y/z" */
int thread_block_depth_ = 0;
/*! \brief An analyzer used for equality proof */
diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
index 457c43a763..90fce22bc1 100644
--- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py
+++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
@@ -72,6 +72,43 @@ def unified_element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None
)
[email protected]_func
+def element_wise_thread_x_different_dtype(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"],
+ C: T.Buffer[(128, 128), "float32"],
+) -> None:
+ for i in T.thread_binding(128, "blockIdx.x"):
+ for j0_0 in T.thread_binding(4, "threadIdx.x"):
+ for j0_1 in T.serial(0, 32):
+ with T.block(""):
+ B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0
+ for j1_0 in T.thread_binding(T.int64(4), "threadIdx.x"):
+ for j1_1 in T.serial(T.int64(32)):
+ with T.block(""):
+ C[i, j1_0 * T.int64(32) + j1_1] = B[i, j1_0 * T.int64(32) + j1_1] + 1.0
+
+
[email protected]_func
+def unified_element_wise_thread_x_different_dtype(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"],
+ C: T.Buffer[(128, 128), "float32"],
+) -> None:
+ for blockIdx_x in T.thread_binding(128, "blockIdx.x"):
+ for threadIdx_x in T.thread_binding(4, "threadIdx.x"):
+ for j0_1 in T.serial(0, 32):
+ with T.block(""):
+ B[blockIdx_x, threadIdx_x * 32 + j0_1] = (
+ A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0
+ )
+ for j1_1 in T.serial(T.int64(32)):
+ with T.block(""):
+ C[blockIdx_x, T.cast(threadIdx_x, "int64") * T.int64(32) + j1_1] = (
+ B[blockIdx_x, T.cast(threadIdx_x, "int64") * T.int64(32) + j1_1] + 1.0
+ )
+
+
@T.prim_func
def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
j1_0 = T.env_thread("threadIdx.x")
@@ -223,6 +260,10 @@ def test_thread_x():
_check(element_wise_thread_x, unified_element_wise_thread_x)
+def test_thread_x_different_dtype():
+ _check(element_wise_thread_x_different_dtype, unified_element_wise_thread_x_different_dtype)
+
+
def test_env_thread_x():
_check(element_wise_env_thread_x, unified_element_wise_env_thread_x)