This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0a3c736591 [Unity][FIX] fix thread dtype mismatch (#16443)
0a3c736591 is described below

commit 0a3c7365915908837a41f99fc45585fdf06d9a65
Author: Siyuan Feng <[email protected]>
AuthorDate: Sun Jan 21 22:53:26 2024 +0800

    [Unity][FIX] fix thread dtype mismatch (#16443)
---
 python/tvm/topi/cuda/scatter_elements.py |  2 +-
 src/te/operation/op_utils.cc             | 10 ++++++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py
index 2f345b9d67..27567ea23e 100644
--- a/python/tvm/topi/cuda/scatter_elements.py
+++ b/python/tvm/topi/cuda/scatter_elements.py
@@ -168,7 +168,7 @@ def gen_ir(data, indices, updates, out, axis, reduce_func):
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     # Copy initial input data to output
     with ib.new_scope():
-        num_blocks = ceil_div(full_range, max_threads)
+        num_blocks = cast(ceil_div(full_range, max_threads), "int32")
         bx = te.thread_axis("blockIdx.x")
         tx = te.thread_axis("threadIdx.x")
         ib.scope_attr(bx, "thread_extent", num_blocks)
diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc
index 1f386bc2dd..7168933a32 100644
--- a/src/te/operation/op_utils.cc
+++ b/src/te/operation/op_utils.cc
@@ -155,22 +155,24 @@ std::vector<std::vector<Stmt>> MakeLoopNest(const Stage& stage,
       ICHECK(is_zero(dom->min));
       ICHECK(is_positive_const(dom->extent));
       // annotate the extent of the IterVar
-      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, dom->extent, no_op));
+      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread,
+                                        cast(bind_iv->var.dtype(), dom->extent), no_op));
       value_map[iv] = promote_to_iv_dtype(var);
     } else if (bind_iv->thread_tag == "pipeline") {
       // pipeline marker.
       ICHECK(is_zero(dom->min));
       ICHECK(is_one(dom->extent));
       // annotate the extent of the IterVar
-      nest[i + 1].emplace_back(
-          AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op));
+      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::pipeline_exec_scope,
+                                        cast(bind_iv->var.dtype(), dom->extent), no_op));
       value_map[iv] = dom->min;
     } else {
       // Always restrict threaded IterVar to starts from 0.
       ICHECK(is_zero(dom->min)) << "Itervar " << iv << " must start at zero, but it starts at "
                                 << dom->min;
       // annotate the extent of the IterVar
-      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op));
+      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent,
+                                        cast(bind_iv->var.dtype(), dom->extent), no_op));
       if (!debug_keep_trivial_loop && is_one(dom->extent)) {
         value_map[iv] = dom->min;
       } else if (stage->scope == "") {

Reply via email to