This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 40af75b61f [Fix][TIR] UnifyThreadBinding creating unit loop with
annotation (#14588)
40af75b61f is described below
commit 40af75b61ff7111b479e447714db63225609cbb5
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Apr 11 20:18:35 2023 -0400
[Fix][TIR] UnifyThreadBinding creating unit loop with annotation (#14588)
This PR fixes the UnifyThreadBinding pass, which at one place assumes
that a return value is always a ForNode. That assumption does not hold.
To be more specific, when a thread-binding loop has an annotation,
the current behavior assumes that the post-recursive-mutation value
is also a ForNode, and applies the previous annotation directly to the
new loop. However, the post-recursive-mutation value may not be a
ForNode. In this case, the current behavior is incorrect.
This PR creates a new unit-length loop in this case to preserve the
annotation.
Thanks to Bohan for catching this issue.
Co-authored-by: Bohan Hou <[email protected]>
---
src/tir/transforms/unify_thread_binding.cc | 17 ++++++++++++---
.../test_tir_transform_unify_thread_binding.py | 25 ++++++++++++++++++++++
2 files changed, 39 insertions(+), 3 deletions(-)
diff --git a/src/tir/transforms/unify_thread_binding.cc
b/src/tir/transforms/unify_thread_binding.cc
index da725f7f8e..09b0970dd3 100644
--- a/src/tir/transforms/unify_thread_binding.cc
+++ b/src/tir/transforms/unify_thread_binding.cc
@@ -64,9 +64,20 @@ class ThreadBindingUnifier : public StmtExprMutator {
if (annotations.empty()) {
return stmt;
}
- For new_loop = Downcast<For>(stmt);
- new_loop.CopyOnWrite()->annotations = std::move(annotations);
- return std::move(new_loop);
+ if (const auto* loop = stmt.as<ForNode>()) {
+ For new_loop = GetRef<For>(loop);
+ new_loop.CopyOnWrite()->annotations = std::move(annotations);
+ return std::move(new_loop);
+ } else {
+ // Create a new unit loop with the annotation.
+ DataType dtype = op->loop_var->dtype;
+ return For(/*loop_var=*/Var("var", dtype), //
+ /*min=*/IntImm(dtype, 0), //
+ /*extent=*/IntImm(dtype, 1), //
+ /*kind=*/ForKind::kSerial, stmt, //
+ /*thread_binding=*/NullOpt, //
+ /*annotation=*/std::move(annotations));
+ }
}
template <typename Node>
diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py
b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
index e489298741..c49ea5e60f 100644
--- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py
+++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
@@ -286,6 +286,31 @@ def test_implicit_block():
_check(element_wise_implicit_block, unified_element_wise_implicit_block)
+def test_inner_binding_with_annotation():
+ @T.prim_func
+ def inner_binding_with_annotation(A: T.Buffer((64,), "float32"), B:
T.Buffer((64,), "float32")):
+ for bx in T.thread_binding(32, "blockIdx.x"):
+ for tx in T.thread_binding(2, "threadIdx.x",
annotations={"my_annotation": 1}):
+ with T.block("block"):
+ v = T.axis.spatial(64, bx * 2 + tx)
+ B[v] = A[v]
+
+ @T.prim_func
+ def unified_inner_binding_with_annotation(
+ A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")
+ ):
+ for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"):
+ for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"):
+ for var in T.serial(1, annotations={"my_annotation": 1}):
+ with T.block("block"):
+ v = T.axis.spatial(64, blockIdx_x * 2 + threadIdx_x)
+ T.reads(A[v])
+ T.writes(B[v])
+ B[v] = A[v]
+
+ _check(inner_binding_with_annotation,
unified_inner_binding_with_annotation)
+
+
def test_lower_te():
a = te.placeholder((32, 2, 2))
b = te.compute((32, 2, 2), lambda i, j, k: a[i, j, k] * 2.0)