This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4f8c03fad3 [TVMScript] Support `T.launch_thread` with i64 dtype (#16916)
4f8c03fad3 is described below
commit 4f8c03fad393c360008f1fb208f117c66c04090c
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Apr 24 20:44:46 2024 +0800
[TVMScript] Support `T.launch_thread` with i64 dtype (#16916)
This PR fixes the bug of mismatched dtype in `T.launch_thread` when the
dtype is `i64`.
---
include/tvm/script/ir_builder/tir/ir.h | 3 ++-
python/tvm/script/ir_builder/tir/ir.py | 7 +++++--
src/script/ir_builder/tir/ir.cc | 10 +++++-----
.../test_tir_transform_inject_ptx_async_copy.py | 4 ++--
tests/python/tvmscript/test_tvmscript_parser_tir.py | 15 +++++++++++++++
5 files changed, 29 insertions(+), 10 deletions(-)
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index c4ba44f673..5b44f79ad7 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -401,9 +401,10 @@ LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent);
/*!
* \brief Bind a var to thread env.
* \param thread_tag The thread type tag.
+ * \param dtype The data type of the variable.
* \return The result variable which gets bound to the thread env.
*/
-Var EnvThread(String thread_tag);
+Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32));
/*!
* \brief Store data in a buffer.
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 127d2a4356..c04ac780c9 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1241,7 +1241,7 @@ def launch_thread(
return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member
-def env_thread(thread_tag: str) -> IterVar:
+def env_thread(thread_tag: str, dtype: str = "int32") -> IterVar:
"""Bind a var to thread env
Parameters
@@ -1249,13 +1249,16 @@ def env_thread(thread_tag: str) -> IterVar:
thread_tag : str
The thread type tag.
+ dtype : str
+ The data type of the thread env.
+
Returns
-------
res : IterVar
The result iteration variable gets bound to the thread env.
"""
- return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member
+ return _ffi_api.EnvThread(thread_tag, dtype) # type: ignore[attr-defined] # pylint: disable=no-member
def buffer_store(
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index ccb5a8b57b..3ce5c15e6c 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -432,7 +432,8 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
}
ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>();
if (!iter_var->dom.defined()) {
- const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent);
+ const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom =
+ Range(tvm::tir::make_zero(extent.dtype()), extent);
} else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
<< iter_var->dom->extent << " vs " << extent;
@@ -444,7 +445,7 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
}
LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) {
- return LaunchThread(EnvThread(thread_tag), extent);
+ return LaunchThread(EnvThread(thread_tag, extent.dtype()), extent);
}
RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
@@ -512,9 +513,8 @@ ElseFrame Else() {
return ElseFrame(n);
}
-Var EnvThread(String thread_tag) {
- IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex,
- thread_tag);
+Var EnvThread(String thread_tag, DataType dtype) {
+ IterVar iter_var(Range{nullptr}, Var("", dtype), tvm::tir::IterVarType::kThreadIndex, thread_tag);
Var var = iter_var->var;
if (Optional<PrimFuncFrame> opt_frame =
IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
opt_frame.value()->env_threads.Set(var, iter_var);
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index 4c94dc04cc..c160e4a31d 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -969,9 +969,9 @@ class TestMultiplicationNodesAreInligned(tvm.testing.CompareBeforeAfter):
T.ptx_cp_async(
"float16",
A_shared.data,
- T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8),
+ tx * T.int64(128) + cse_var_1 * T.int64(8),
A.data,
- T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8),
+ tx * T.int64(128) + cse_var_1 * T.int64(8),
16,
)
T.ptx_commit_group()
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index 530746a6fc..25a904a157 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -471,5 +471,20 @@ def test_reinterpret_nop():
tvm.ir.assert_structural_equal(func, expected)
+def test_launch_thread_i64():
+ """Test launching thread with int64"""
+
+ @T.prim_func
+ def func() -> None:
+ blockIdx_x = T.launch_thread("blockIdx.x", T.int64(1))
+ if blockIdx_x == T.int64(0):
+ T.evaluate(T.int64(0))
+ else:
+ T.evaluate(T.int64(1))
+
+ assert func.body.node.dom.min.dtype == "int64"
+ assert func.body.node.dom.extent.dtype == "int64"
+
+
if __name__ == "__main__":
tvm.testing.main()