This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 39f2482580 [Fix] Fix SSA conversion for SizeVar retention (#16924)
39f2482580 is described below
commit 39f2482580b57fa5b1f6c1a1dc0e6f5e823ee4c0
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Apr 25 08:11:46 2024 -0400
[Fix] Fix SSA conversion for SizeVar retention (#16924)
This PR fixes the var construction in IRConvertSSA, which previously
always cast SizeVar to Var. That behavior prevents expressions from being
simplified in the LowerIntrin pass later on. Specifically, when SizeVar
is not retained, the LowerIntrin pass loses the information that the var
is non-negative, and cannot simplify a number of FloorDiv/FloorMod
expressions.
One regression test for SplitHostDevice is added to ensure the retention
of SizeVar. The test is added to SplitHostDevice because this is where
the SSA conversion is used.
---
src/tir/transforms/ir_utils.cc | 13 +++++++++--
.../test_tir_transform_split_host_device.py | 25 ++++++++++++++++++++--
2 files changed, 34 insertions(+), 4 deletions(-)
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 584b3cbf58..c52027acba 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -435,10 +435,19 @@ class IRConvertSSA final : public StmtExprMutator {
private:
struct ScopedRedefine {
ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent),
old_var(old_var) {
+ bool is_size_var = old_var->IsInstance<SizeVarNode>();
if (old_var->type_annotation.defined()) {
- new_var = Var(old_var->name_hint, old_var->type_annotation);
+ if (is_size_var) {
+ new_var = SizeVar(old_var->name_hint, old_var->type_annotation);
+ } else {
+ new_var = Var(old_var->name_hint, old_var->type_annotation);
+ }
} else {
- new_var = Var(old_var->name_hint, old_var->dtype);
+ if (is_size_var) {
+ new_var = SizeVar(old_var->name_hint, old_var->dtype);
+ } else {
+ new_var = Var(old_var->name_hint, old_var->dtype);
+ }
}
parent->scope_[old_var.get()].push_back(new_var);
}
diff --git a/tests/python/tir-transform/test_tir_transform_split_host_device.py
b/tests/python/tir-transform/test_tir_transform_split_host_device.py
index 6adfbeb81d..2d0d8a68d8 100644
--- a/tests/python/tir-transform/test_tir_transform_split_host_device.py
+++ b/tests/python/tir-transform/test_tir_transform_split_host_device.py
@@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import te
import tvm.testing
-from tvm.script import tir as T, ir as I
+from tvm import te
+from tvm.script import ir as I
+from tvm.script import tir as T
@tvm.testing.requires_cuda
@@ -345,5 +346,25 @@ def test_dynamic_launch_thread():
tvm.ir.assert_structural_equal(expected, after)
+def test_size_var():
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def main(var_A: T.handle, var_B: T.handle):
+ T.func_attr({"target": T.target("cuda")})
+ m = T.int64(is_size_var=True)
+ A = T.match_buffer(var_A, (m,))
+ B = T.match_buffer(var_B, (m,))
+ T.attr(T.target("cuda"), "target", 0)
+ blockIdx_x = T.launch_thread("blockIdx.x", m)
+ B_1 = T.Buffer((m,), data=B.data)
+ A_1 = T.Buffer((m,), data=A.data)
+ B_1[blockIdx_x] = A_1[blockIdx_x]
+
+ after = tvm.tir.transform.SplitHostDevice()(Module)
+ assert len(after["main_kernel"].params) == 3
+ assert isinstance(after["main_kernel"].params[2], tvm.tir.SizeVar)
+
+
if __name__ == "__main__":
tvm.testing.main()