This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c9fb87fd0e [TIR] Fix software pipeline with dynamic loop extent
(#16027)
c9fb87fd0e is described below
commit c9fb87fd0e7498a8b29dec911ab0db7e5633601f
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Nov 3 00:26:34 2023 -0700
[TIR] Fix software pipeline with dynamic loop extent (#16027)
---
src/tir/transforms/inject_software_pipeline.cc | 2 +-
.../test_tir_transform_inject_software_pipeline.py | 72 ++++++++++++++++++++++
2 files changed, 73 insertions(+), 1 deletion(-)
diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc
index a75bfdcddc..21de2d8607 100644
--- a/src/tir/transforms/inject_software_pipeline.cc
+++ b/src/tir/transforms/inject_software_pipeline.cc
@@ -801,7 +801,7 @@ class PipelineRewriter : public StmtExprMutator {
auto make_nop = []() { return BlockRealize({}, Bool(true),
MakeBlock(Evaluate(0), {})); };
- if (!analyzer_.CanProve(extent > 0)) {
+ if (analyzer_.CanProve(extent <= 0)) {
return make_nop();
}
bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index 2a1ce2be28..a013cf0f65 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -153,6 +153,74 @@ def transformed_simple_compute(
C[tx, 15] = B[1, tx, 0] + T.float32(1)
+@T.prim_func
+def dynamic_compute(a_handle: T.handle, c_handle: T.handle):
+ k = T.int32()
+ A = T.match_buffer(a_handle, (16, k), "float32")
+ C = T.match_buffer(c_handle, (16, k), "float32")
+ for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+ for i in T.serial(
+ 0,
+ k,
+ annotations={
+ "software_pipeline_stage": [0, 1],
+ "software_pipeline_order": [0, 1],
+ },
+ ):
+ with T.block("compute"):
+ T.reads(A[tx, i])
+ T.writes(C[tx, i])
+ B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
+ with T.block():
+ T.reads(A[tx, i])
+ T.writes(B[tx, 0])
+ B[tx, 0] = A[tx, i] * T.float32(2)
+ with T.block():
+ T.reads(B[tx, 0])
+ T.writes(C[tx, i])
+ C[tx, i] = B[tx, 0] + T.float32(1)
+
+
+@T.prim_func
+def transformed_dynamic_compute(a_handle: T.handle, c_handle: T.handle):
+ k = T.int32()
+ A = T.match_buffer(a_handle, (16, k), "float32")
+ C = T.match_buffer(c_handle, (16, k), "float32")
+ for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+ with T.block():
+ T.reads(A[tx, 0 : T.max(1, k)])
+ T.writes(C[tx, T.min(0, k - 1) : T.min(0, k - 1) + T.max(k, 1)])
+ B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
+ with T.block(""):
+ T.reads(A[tx, 0])
+ T.writes(B[0, tx, 0])
+ with T.block(""):
+ T.where(0 < k)
+ T.reads(A[tx, 0])
+ T.writes(B[0, tx, 0])
+ B[0, tx, 0] = A[tx, 0] * T.float32(2)
+ with T.block(""):
+ T.reads(A[tx, 1 : 1 + (k - 1)], B[0:2, tx, 0])
+ T.writes(B[0:2, tx, 0], C[tx, 0 : k - 1])
+ for i in range(k - 1):
+ with T.block(""):
+ T.reads(A[tx, i + 1])
+ T.writes(B[(i + 1) % 2, tx, 0])
+ B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
+ with T.block(""):
+ T.reads(B[i % 2, tx, 0])
+ T.writes(C[tx, i])
+ C[tx, i] = B[i % 2, tx, 0] + T.float32(1)
+ with T.block(""):
+ T.reads(B[(k + 1) % 2, tx, 0])
+ T.writes(C[tx, k - 1])
+ with T.block(""):
+ T.where(1 <= k)
+ T.reads(B[(k + 1) % 2, tx, 0])
+ T.writes(C[tx, k - 1])
+ C[tx, k - 1] = B[(k + 1) % 2, tx, 0] + T.float32(1)
+
+
@T.prim_func
def simple_compute_with_other_annotation(
A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")
@@ -1069,6 +1137,10 @@ def test_simple_compute_with_other_annotation():
_check(simple_compute_with_other_annotation,
transformed_simple_compute_with_other_annotation)
+def test_dynamic_compute():
+ _check(dynamic_compute, transformed_dynamic_compute)
+
+
def test_trivial_pipeline():
_check(trivial_pipeline, transformed_trivial_pipeline)