Rainy-Memory commented on code in PR #13966:
URL: https://github.com/apache/tvm/pull/13966#discussion_r1105283441


##########
tests/python/unittest/test_cp_async_in_if_then_else.py:
##########
@@ -0,0 +1,304 @@
+import tvm
+import numpy as np
+from tvm.script import tir as T
+
+
[email protected]_module
+class Module:
+    @T.prim_func
+    def main(
+        A: T.Buffer[(1012, 1014), "float32"],
+        B: T.Buffer[(1014, 1017), "float32"],
+        Y: T.Buffer[(1012, 1017), "float32"],
+    ):
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        Y_reindex_local = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="local")
+        A_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="shared")
+        B_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="shared")
+        A_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="local")
+        B_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="local")
+        for ax0_0_ax1_0_fused in T.thread_binding(
+            128,
+            thread="blockIdx.x",
+            annotations={"pragma_auto_unroll_max_step": 1024, 
"pragma_unroll_explicit": 1},
+        ):
+            for ax0_1_ax1_1_fused in T.thread_binding(4, thread="vthread.x"):
+                for ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused in 
T.thread_binding(
+                    64, thread="threadIdx.x"
+                ):
+                    for ax0_3_init, ax1_3_init, ax0_4_init, ax1_4_init in 
T.grid(4, 4, 2, 1):
+                        with T.block("Y_init"):
+                            v0 = T.axis.spatial(
+                                1024,
+                                ax0_0_ax1_0_fused // 8 * 64
+                                + ax0_1_ax1_1_fused // 2 * 32
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                + ax0_3_init * 2
+                                + ax0_4_init,
+                            )
+                            v1 = T.axis.spatial(
+                                1024,
+                                ax1_4_init
+                                + ax0_0_ax1_0_fused % 8 * 128
+                                + ax0_1_ax1_1_fused % 2 * 64
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+                                + ax1_3_init,
+                            )
+                            T.reads()
+                            T.writes(Y_reindex_local[v0, v1])
+                            T.block_attr(
+                                {
+                                    
"meta_schedule.thread_extent_high_inclusive": 1024,
+                                    
"meta_schedule.thread_extent_low_inclusive": 32,
+                                    "meta_schedule.tiling_structure": 
"SSSRRSRS",
+                                }
+                            )
+                            Y_reindex_local[v0, v1] = T.float32(0)
+                    for ax2_0_fused in T.serial(
+                        256,
+                        annotations={
+                            "software_pipeline_async_stages": [0, 1],
+                            "software_pipeline_order": [0, 1, 3, 2, 4],
+                            "software_pipeline_stage": [0, 0, 2, 3, 3],
+                        },
+                    ):
+                        for ax0_ax1_fused_0 in T.serial(4):
+                            for ax0_ax1_fused_1 in T.thread_binding(64, 
thread="threadIdx.x"):
+                                with T.block("A_reindex_shared"):
+                                    v0 = T.axis.spatial(
+                                        1024,
+                                        ax0_0_ax1_0_fused // 8 * 64
+                                        + (ax0_ax1_fused_0 * 64 + 
ax0_ax1_fused_1) // 4,
+                                    )
+                                    v1 = T.axis.spatial(
+                                        1024,
+                                        ax2_0_fused * 4
+                                        + (ax0_ax1_fused_0 * 64 + 
ax0_ax1_fused_1) % 4,
+                                    )
+                                    T.reads(A[v0, v1])
+                                    T.writes(
+                                        A_reindex_shared[
+                                            v1,
+                                            v0 // 32 * 32
+                                            + v0 % 8 // 4 * 16
+                                            + v0 % 32 // 8 * 4
+                                            + v0 % 4,
+                                        ]
+                                    )
+                                    A_reindex_shared[
+                                        v1,
+                                        v0 // 32 * 32
+                                        + v0 % 8 // 4 * 16
+                                        + v0 % 32 // 8 * 4
+                                        + v0 % 4,
+                                    ] = T.if_then_else(
+                                        v0 < 1012 and v1 < 1014,
+                                        A[v0, v1],
+                                        T.float32(0),
+                                        dtype="float32",
+                                    )
+                        for ax0_ax1_fused_0 in T.serial(8):
+                            for ax0_ax1_fused_1 in T.thread_binding(64, 
thread="threadIdx.x"):
+                                with T.block("B_reindex_shared"):
+                                    v0 = T.axis.spatial(
+                                        1024,
+                                        ax2_0_fused * 4
+                                        + (ax0_ax1_fused_0 * 64 + 
ax0_ax1_fused_1) // 128,
+                                    )
+                                    v1 = T.axis.spatial(
+                                        1024,
+                                        ax0_0_ax1_0_fused % 8 * 128
+                                        + (ax0_ax1_fused_0 * 64 + 
ax0_ax1_fused_1) % 128,
+                                    )
+                                    T.reads(B[v0, v1])
+                                    T.writes(
+                                        B_reindex_shared[
+                                            v0,
+                                            v1 // 64 * 64
+                                            + v1 % 8 // 4 * 32
+                                            + v1 % 64 // 8 * 4
+                                            + v1 % 4,
+                                        ]
+                                    )
+                                    B_reindex_shared[
+                                        v0,
+                                        v1 // 64 * 64
+                                        + v1 % 8 // 4 * 32
+                                        + v1 % 64 // 8 * 4
+                                        + v1 % 4,
+                                    ] = T.if_then_else(
+                                        v0 < 1014 and v1 < 1017,
+                                        B[v0, v1],
+                                        T.float32(0),
+                                        dtype="float32",
+                                    )
+                        for ax2_1_fused in T.unroll(
+                            4,
+                            annotations={
+                                "software_pipeline_order": [0, 1, 2],
+                                "software_pipeline_stage": [0, 0, 1],
+                            },
+                        ):
+                            for ax0_ax1_fused_0 in T.unroll(2):
+                                for ax0_ax1_fused_1 in T.vectorized(4):
+                                    with T.block("A_reindex_shared_local"):
+                                        v0 = T.axis.spatial(1024, ax2_0_fused 
* 4 + ax2_1_fused)
+                                        v1 = T.axis.spatial(
+                                            1024,
+                                            ax0_0_ax1_0_fused // 8 * 64
+                                            + ax0_1_ax1_1_fused // 2 * 32
+                                            + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+                                            // 32
+                                            * 16
+                                            + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                            + ax0_ax1_fused_0 * 4
+                                            + ax0_ax1_fused_1,
+                                        )
+                                        T.reads(
+                                            A_reindex_shared[
+                                                v0,
+                                                v1 // 32 * 32
+                                                + v1 % 8 // 4 * 16
+                                                + v1 % 32 // 8 * 4
+                                                + v1 % 4,
+                                            ]
+                                        )
+                                        T.writes(A_reindex_shared_local[v0, 
v1])
+                                        A_reindex_shared_local[v0, v1] = 
A_reindex_shared[
+                                            v0,
+                                            v1 // 32 * 32
+                                            + v1 % 8 // 4 * 16
+                                            + v1 % 32 // 8 * 4
+                                            + v1 % 4,
+                                        ]
+                            for ax0_ax1_fused_0 in T.unroll(2):
+                                for ax0_ax1_fused_1 in T.vectorized(2):
+                                    with T.block("B_reindex_shared_local"):
+                                        v0 = T.axis.spatial(1024, ax2_0_fused 
* 4 + ax2_1_fused)
+                                        v1 = T.axis.spatial(
+                                            1024,
+                                            ax0_0_ax1_0_fused % 8 * 128
+                                            + ax0_1_ax1_1_fused % 2 * 64
+                                            + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+                                            % 32
+                                            // 2
+                                            * 4
+                                            + ax0_ax1_fused_0 * 2
+                                            + ax0_ax1_fused_1,
+                                        )
+                                        T.reads(
+                                            B_reindex_shared[
+                                                v0,
+                                                v1 // 64 * 64
+                                                + v1 % 8 // 4 * 32
+                                                + v1 % 64 // 8 * 4
+                                                + v1 % 4,
+                                            ]
+                                        )
+                                        T.writes(B_reindex_shared_local[v0, 
v1])
+                                        B_reindex_shared_local[v0, v1] = 
B_reindex_shared[
+                                            v0,
+                                            v1 // 64 * 64
+                                            + v1 % 8 // 4 * 32
+                                            + v1 % 64 // 8 * 4
+                                            + v1 % 4,
+                                        ]
+                            for ax0_3, ax1_3, ax2_2, ax0_4, ax1_4 in T.grid(4, 
4, 1, 2, 1):
+                                with T.block("Y_update"):
+                                    v0 = T.axis.spatial(
+                                        1024,
+                                        ax0_0_ax1_0_fused // 8 * 64
+                                        + ax0_1_ax1_1_fused // 2 * 32
+                                        + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+                                        + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                        + ax0_3 * 2
+                                        + ax0_4,
+                                    )
+                                    v1 = T.axis.spatial(
+                                        1024,
+                                        ax1_4
+                                        + ax0_0_ax1_0_fused % 8 * 128
+                                        + ax0_1_ax1_1_fused % 2 * 64
+                                        + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+                                        % 32
+                                        // 2
+                                        * 4
+                                        + ax1_3,
+                                    )
+                                    v2 = T.axis.reduce(1024, ax2_0_fused * 4 + 
ax2_1_fused + ax2_2)
+                                    T.reads(
+                                        Y_reindex_local[v0, v1],
+                                        A_reindex_shared_local[v2, v0],
+                                        B_reindex_shared_local[v2, v1],
+                                    )
+                                    T.writes(Y_reindex_local[v0, v1])
+                                    T.block_attr(
+                                        {
+                                            
"meta_schedule.thread_extent_high_inclusive": 1024,
+                                            
"meta_schedule.thread_extent_low_inclusive": 32,
+                                            "meta_schedule.tiling_structure": 
"SSSRRSRS",
+                                        }
+                                    )
+                                    Y_reindex_local[v0, v1] = (
+                                        Y_reindex_local[v0, v1]
+                                        + A_reindex_shared_local[v2, v0]
+                                        * B_reindex_shared_local[v2, v1]
+                                    )
+                    for ax0, ax1 in T.grid(8, 4):
+                        with T.block("Y_reindex_local"):
+                            T.where(
+                                ax0_0_ax1_0_fused // 8 * 64
+                                + ax0_1_ax1_1_fused // 2 * 32
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                + ax0
+                                < 1012
+                                and ax0_0_ax1_0_fused % 8 * 128
+                                + ax0_1_ax1_1_fused % 2 * 64
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+                                + ax1
+                                < 1017
+                            )
+                            v0 = T.axis.spatial(
+                                1024,
+                                ax0_0_ax1_0_fused // 8 * 64
+                                + ax0_1_ax1_1_fused // 2 * 32
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                + ax0,
+                            )
+                            v1 = T.axis.spatial(
+                                1024,
+                                ax0_0_ax1_0_fused % 8 * 128
+                                + ax0_1_ax1_1_fused % 2 * 64
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+                                + ax1,
+                            )
+                            T.reads(Y_reindex_local[v0, v1])
+                            T.writes(Y[v0, v1])
+                            Y[v0, v1] = Y_reindex_local[v0, v1]
+
+
+def test_matmul():
+    with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
+        rt_mod = tvm.build(Module, target="cuda")

Review Comment:
   now compute version check is added to the unittest



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to