This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0a3c736591 [Unity][FIX] fix thread dtype mismatch (#16443)
0a3c736591 is described below
commit 0a3c7365915908837a41f99fc45585fdf06d9a65
Author: Siyuan Feng <[email protected]>
AuthorDate: Sun Jan 21 22:53:26 2024 +0800
[Unity][FIX] fix thread dtype mismatch (#16443)
---
python/tvm/topi/cuda/scatter_elements.py | 2 +-
src/te/operation/op_utils.cc | 10 ++++++----
2 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py
index 2f345b9d67..27567ea23e 100644
--- a/python/tvm/topi/cuda/scatter_elements.py
+++ b/python/tvm/topi/cuda/scatter_elements.py
@@ -168,7 +168,7 @@ def gen_ir(data, indices, updates, out, axis, reduce_func):
max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
# Copy initial input data to output
with ib.new_scope():
- num_blocks = ceil_div(full_range, max_threads)
+ num_blocks = cast(ceil_div(full_range, max_threads), "int32")
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", num_blocks)
diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc
index 1f386bc2dd..7168933a32 100644
--- a/src/te/operation/op_utils.cc
+++ b/src/te/operation/op_utils.cc
@@ -155,22 +155,24 @@ std::vector<std::vector<Stmt>> MakeLoopNest(const Stage& stage,
ICHECK(is_zero(dom->min));
ICHECK(is_positive_const(dom->extent));
// annotate the extent of the IterVar
- nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread,
dom->extent, no_op));
+ nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread,
+ cast(bind_iv->var.dtype(),
dom->extent), no_op));
value_map[iv] = promote_to_iv_dtype(var);
} else if (bind_iv->thread_tag == "pipeline") {
// pipeline marker.
ICHECK(is_zero(dom->min));
ICHECK(is_one(dom->extent));
// annotate the extent of the IterVar
- nest[i + 1].emplace_back(
- AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent,
no_op));
+ nest[i + 1].emplace_back(AttrStmt(bind_iv,
tir::attr::pipeline_exec_scope,
+ cast(bind_iv->var.dtype(),
dom->extent), no_op));
value_map[iv] = dom->min;
} else {
// Always restrict threaded IterVar to starts from 0.
ICHECK(is_zero(dom->min)) << "Itervar " << iv << " must start at zero,
but it starts at "
<< dom->min;
// annotate the extent of the IterVar
- nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent,
dom->extent, no_op));
+ nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent,
+ cast(bind_iv->var.dtype(),
dom->extent), no_op));
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
value_map[iv] = dom->min;
} else if (stage->scope == "") {