JosephTheOctonaut commented on a change in pull request #10066:
URL: https://github.com/apache/tvm/pull/10066#discussion_r792984666



##########
File path: tests/python/unittest/test_tir_transform_inject_software_pipeline.py
##########
@@ -0,0 +1,824 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import sys
+
+import tvm
+from tvm import tir, te, TVMError
+from tvm.script import tir as T
+
+
+def _check(original, transformed):
+    func = original
+    mod = tvm.IRModule.from_expr(func)
+    mod = tvm.tir.transform.InjectSoftwarePipeline()(mod)
+    mod = tvm.tir.transform.Simplify()(mod)
+    tvm.ir.assert_structural_equal(mod["main"], transformed, True)
+
+
+def _check_error(func):
+    mod = tvm.IRModule.from_expr(func)
+    with pytest.raises(ValueError):
+        tvm.tir.transform.InjectSoftwarePipeline()(mod)
+
+
+@T.prim_func
+def trivial_pipeline(A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}
+        ):
+            with T.block():
+                T.reads(A[tx, i])
+                T.writes(C[tx, i])
+                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
+                with T.block():
+                    T.reads(A[tx, i])
+                    T.writes(B[tx, 0])
+                    B[tx, 0] = A[tx, i] * T.float32(2)
+                with T.block():
+                    T.reads(B[tx, 0])
+                    T.writes(C[tx, i])
+                    C[tx, i] = B[tx, 0] + T.float32(1)
+
+
+@T.prim_func
+def transformed_trivial_pipeline(
+    A: T.Buffer[(16, 1), "float32"], C: T.Buffer[(16, 1), "float32"]
+) -> None:
+    for tx in T.thread_binding(16, thread="threadIdx.x"):
+        with T.block():
+            T.reads(A[tx, 0])
+            T.writes(C[tx, 0])
+            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
+            with T.block():
+                T.reads(A[tx, 0])
+                T.writes(B[0, tx, 0])
+                B[0, tx, 0] = A[tx, 0] * T.float32(2)
+            with T.block():
+                T.reads()
+                T.writes()
+                T.evaluate(0)
+            with T.block():
+                T.reads(B[0, tx, 0])
+                T.writes(C[tx, 0])
+                C[tx, 0] = B[0, tx, 0] + T.float32(1)
+
+
+@T.prim_func
+def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            16,
+            annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]},
+        ):
+            with T.block():
+                T.reads(A[tx, i])
+                T.writes(C[tx, i])
+                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
+                with T.block():
+                    T.reads(A[tx, i])
+                    T.writes(B[tx, 0])
+                    B[tx, 0] = A[tx, i] * T.float32(2)
+                with T.block():
+                    T.reads(B[tx, 0])
+                    T.writes(C[tx, i])
+                    C[tx, i] = B[tx, 0] + T.float32(1)
+
+
+@T.prim_func
+def transformed_simple_compute(
+    A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]
+) -> None:
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        with T.block():
+            T.reads([A[tx, 0:16]])
+            T.writes([C[tx, 0:16]])
+            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
+            with T.block():
+                T.reads([A[tx, 0]])
+                T.writes([B[0, tx, 0]])
+                B[0, tx, 0] = A[tx, 0] * T.float32(2)
+            with T.block():
+                T.reads([A[tx, 1:16], B[0:2, tx, 0]])
+                T.writes([B[0:2, tx, 0], C[tx, 0:15]])
+                for i in T.serial(0, 15):
+                    with T.block():
+                        T.reads([A[tx, i + 1]])
+                        T.writes([B[(i + 1) % 2, tx, 0]])
+                        B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
+                    with T.block():
+                        T.reads([B[i % 2, tx, 0]])
+                        T.writes([C[tx, i]])
+                        C[tx, i] = B[i % 2, tx, 0] + T.float32(1)
+            with T.block():
+                T.reads([B[1, tx, 0]])
+                T.writes([C[tx, 15]])
+                C[tx, 15] = B[1, tx, 0] + T.float32(1)
+
+
+@T.prim_func
+def nested_pipeline_simple(
+    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
+):
+    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+        for i in T.serial(
+            0,
+            16,
+            annotations={
+                "software_pipeline_stage": [0, 1, 1, 1],
+                "software_pipeline_order": [0, 1, 2, 3],
+            },
+        ):
+            with T.block():
+                T.reads(A[tx, i, 0:16])
+                T.writes(C[tx, i, 0:16])
+                A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared")
+                for j in T.serial(0, 16):

Review comment:
       In this example, as well as in the last two (`nested_pipeline_interleaving`
and `nested_pipeline_double_buffer`), the same index variable is used in
multiple loops. While not incorrect, this makes it hard to compare the pre-
and post-transformation TIR, because a variable (`j` in these examples) could
belong to any of several source loops. Using unique index variables
throughout would make the mapping clearer.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to