This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c9fb87fd0e [TIR] Fix software pipeline with dynamic loop extent (#16027)
c9fb87fd0e is described below

commit c9fb87fd0e7498a8b29dec911ab0db7e5633601f
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Nov 3 00:26:34 2023 -0700

    [TIR] Fix software pipeline with dynamic loop extent (#16027)
---
 src/tir/transforms/inject_software_pipeline.cc     |  2 +-
 .../test_tir_transform_inject_software_pipeline.py | 72 ++++++++++++++++++++++
 2 files changed, 73 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc
index a75bfdcddc..21de2d8607 100644
--- a/src/tir/transforms/inject_software_pipeline.cc
+++ b/src/tir/transforms/inject_software_pipeline.cc
@@ -801,7 +801,7 @@ class PipelineRewriter : public StmtExprMutator {
 
     auto make_nop = []() { return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); };
 
-    if (!analyzer_.CanProve(extent > 0)) {
+    if (analyzer_.CanProve(extent <= 0)) {
       return make_nop();
     }
     bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index 2a1ce2be28..a013cf0f65 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -153,6 +153,74 @@ def transformed_simple_compute(
                 C[tx, 15] = B[1, tx, 0] + T.float32(1)
 
 
+@T.prim_func
+def dynamic_compute(a_handle: T.handle, c_handle: T.handle):
+    k = T.int32()
+    A = T.match_buffer(a_handle, (16, k), "float32")
+    C = T.match_buffer(c_handle, (16, k), "float32")
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            k,
+            annotations={
+                "software_pipeline_stage": [0, 1],
+                "software_pipeline_order": [0, 1],
+            },
+        ):
+            with T.block("compute"):
+                T.reads(A[tx, i])
+                T.writes(C[tx, i])
+                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
+                with T.block():
+                    T.reads(A[tx, i])
+                    T.writes(B[tx, 0])
+                    B[tx, 0] = A[tx, i] * T.float32(2)
+                with T.block():
+                    T.reads(B[tx, 0])
+                    T.writes(C[tx, i])
+                    C[tx, i] = B[tx, 0] + T.float32(1)
+
+
+@T.prim_func
+def transformed_dynamic_compute(a_handle: T.handle, c_handle: T.handle):
+    k = T.int32()
+    A = T.match_buffer(a_handle, (16, k), "float32")
+    C = T.match_buffer(c_handle, (16, k), "float32")
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        with T.block():
+            T.reads(A[tx, 0 : T.max(1, k)])
+            T.writes(C[tx, T.min(0, k - 1) : T.min(0, k - 1) + T.max(k, 1)])
+            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
+            with T.block(""):
+                T.reads(A[tx, 0])
+                T.writes(B[0, tx, 0])
+                with T.block(""):
+                    T.where(0 < k)
+                    T.reads(A[tx, 0])
+                    T.writes(B[0, tx, 0])
+                    B[0, tx, 0] = A[tx, 0] * T.float32(2)
+            with T.block(""):
+                T.reads(A[tx, 1 : 1 + (k - 1)], B[0:2, tx, 0])
+                T.writes(B[0:2, tx, 0], C[tx, 0 : k - 1])
+                for i in range(k - 1):
+                    with T.block(""):
+                        T.reads(A[tx, i + 1])
+                        T.writes(B[(i + 1) % 2, tx, 0])
+                        B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
+                    with T.block(""):
+                        T.reads(B[i % 2, tx, 0])
+                        T.writes(C[tx, i])
+                        C[tx, i] = B[i % 2, tx, 0] + T.float32(1)
+            with T.block(""):
+                T.reads(B[(k + 1) % 2, tx, 0])
+                T.writes(C[tx, k - 1])
+                with T.block(""):
+                    T.where(1 <= k)
+                    T.reads(B[(k + 1) % 2, tx, 0])
+                    T.writes(C[tx, k - 1])
+                    C[tx, k - 1] = B[(k + 1) % 2, tx, 0] + T.float32(1)
+
+
 @T.prim_func
 def simple_compute_with_other_annotation(
     A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")
@@ -1069,6 +1137,10 @@ def test_simple_compute_with_other_annotation():
     _check(simple_compute_with_other_annotation, transformed_simple_compute_with_other_annotation)
 
 
+def test_dynamic_compute():
+    _check(dynamic_compute, transformed_dynamic_compute)
+
+
 def test_trivial_pipeline():
     _check(trivial_pipeline, transformed_trivial_pipeline)
 

Reply via email to