This is an automated email from the ASF dual-hosted git repository. syfeng pushed a commit to branch test_all_cases_on_unity in repository https://gitbox.apache.org/repos/asf/tvm.git
commit aa29f52bf8679a067eb187c698ea22f15e268f76 Author: Siyuan Feng <[email protected]> AuthorDate: Sat Jan 20 22:44:48 2024 +0800 fix thread dtype mismatch fix thread dtype mismatch --- python/tvm/topi/cuda/scatter_elements.py | 2 +- src/te/operation/op_utils.cc | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 2f345b9d67..27567ea23e 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -168,7 +168,7 @@ def gen_ir(data, indices, updates, out, axis, reduce_func): max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) # Copy initial input data to output with ib.new_scope(): - num_blocks = ceil_div(full_range, max_threads) + num_blocks = cast(ceil_div(full_range, max_threads), "int32") bx = te.thread_axis("blockIdx.x") tx = te.thread_axis("threadIdx.x") ib.scope_attr(bx, "thread_extent", num_blocks) diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc index 1f386bc2dd..7168933a32 100644 --- a/src/te/operation/op_utils.cc +++ b/src/te/operation/op_utils.cc @@ -155,22 +155,24 @@ std::vector<std::vector<Stmt>> MakeLoopNest(const Stage& stage, ICHECK(is_zero(dom->min)); ICHECK(is_positive_const(dom->extent)); // annotate the extent of the IterVar - nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, dom->extent, no_op)); + nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, + cast(bind_iv->var.dtype(), dom->extent), no_op)); value_map[iv] = promote_to_iv_dtype(var); } else if (bind_iv->thread_tag == "pipeline") { // pipeline marker. ICHECK(is_zero(dom->min)); ICHECK(is_one(dom->extent)); // annotate the extent of the IterVar - nest[i + 1].emplace_back( - AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op)); + nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, + cast(bind_iv->var.dtype(), dom->extent), no_op)); value_map[iv] = dom->min; } else { // Always restrict threaded IterVar to starts from 0. ICHECK(is_zero(dom->min)) << "Itervar " << iv << " must start at zero, but it starts at " << dom->min; // annotate the extent of the IterVar - nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op)); + nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, + cast(bind_iv->var.dtype(), dom->extent), no_op)); if (!debug_keep_trivial_loop && is_one(dom->extent)) { value_map[iv] = dom->min; } else if (stage->scope == "") {
