junrushao commented on code in PR #13966:
URL: https://github.com/apache/tvm/pull/13966#discussion_r1105276027


##########
src/target/source/codegen_cuda.cc:
##########
@@ -914,7 +914,12 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, 
std::ostream& os) {
     std::string src = this->PrintExpr(op->args[2]);
     std::string src_offset = this->PrintExpr(op->args[3]);
     std::string size = this->PrintExpr(op->args[4]);
-    this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, 
size);
+    // use size of argument list to indicate whether or not to use predicated 
cp.async
+    if (op->args.size() == 5)
+      this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, 
size);
+    else
+      this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, 
src_offset,
+                                                     size, 
this->PrintExpr(op->args[5]));

Review Comment:
   nit: always wrap conditional bodies in braces `{}`, per the code style guide
   
   ```suggestion
       if (op->args.size() == 5) {
         this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
       } else {
         this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset,
                                                        size, this->PrintExpr(op->args[5]));
       }
   ```



##########
tests/python/unittest/test_cp_async_in_if_then_else.py:
##########
@@ -0,0 +1,304 @@
+import tvm
+import numpy as np
+from tvm.script import tir as T
+
+
[email protected]_module
+class Module:
+    @T.prim_func
+    def main(
+        A: T.Buffer[(1012, 1014), "float32"],
+        B: T.Buffer[(1014, 1017), "float32"],
+        Y: T.Buffer[(1012, 1017), "float32"],
+    ):
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        Y_reindex_local = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="local")
+        A_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="shared")
+        B_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="shared")
+        A_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="local")
+        B_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="local")
+        for ax0_0_ax1_0_fused in T.thread_binding(
+            128,
+            thread="blockIdx.x",
+            annotations={"pragma_auto_unroll_max_step": 1024, 
"pragma_unroll_explicit": 1},
+        ):
+            for ax0_1_ax1_1_fused in T.thread_binding(4, thread="vthread.x"):
+                for ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused in 
T.thread_binding(
+                    64, thread="threadIdx.x"
+                ):
+                    for ax0_3_init, ax1_3_init, ax0_4_init, ax1_4_init in 
T.grid(4, 4, 2, 1):
+                        with T.block("Y_init"):
+                            v0 = T.axis.spatial(
+                                1024,
+                                ax0_0_ax1_0_fused // 8 * 64
+                                + ax0_1_ax1_1_fused // 2 * 32
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                + ax0_3_init * 2
+                                + ax0_4_init,
+                            )
+                            v1 = T.axis.spatial(
+                                1024,
+                                ax1_4_init
+                                + ax0_0_ax1_0_fused % 8 * 128
+                                + ax0_1_ax1_1_fused % 2 * 64
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+                                + ax1_3_init,
+                            )
+                            T.reads()
+                            T.writes(Y_reindex_local[v0, v1])
+                            T.block_attr(
+                                {
+                                    
"meta_schedule.thread_extent_high_inclusive": 1024,
+                                    
"meta_schedule.thread_extent_low_inclusive": 32,
+                                    "meta_schedule.tiling_structure": 
"SSSRRSRS",
+                                }
+                            )
+                            Y_reindex_local[v0, v1] = T.float32(0)
+                    for ax2_0_fused in T.serial(
+                        256,
+                        annotations={
+                            "software_pipeline_async_stages": [0, 1],
+                            "software_pipeline_order": [0, 1, 3, 2, 4],
+                            "software_pipeline_stage": [0, 0, 2, 3, 3],
+                        },
+                    ):
+                        for ax0_ax1_fused_0 in T.serial(4):
+                            for ax0_ax1_fused_1 in T.thread_binding(64, 
thread="threadIdx.x"):
+                                with T.block("A_reindex_shared"):
+                                    v0 = T.axis.spatial(
+                                        1024,
+                                        ax0_0_ax1_0_fused // 8 * 64
+                                        + (ax0_ax1_fused_0 * 64 + 
ax0_ax1_fused_1) // 4,
+                                    )
+                                    v1 = T.axis.spatial(
+                                        1024,
+                                        ax2_0_fused * 4
+                                        + (ax0_ax1_fused_0 * 64 + 
ax0_ax1_fused_1) % 4,
+                                    )
+                                    T.reads(A[v0, v1])
+                                    T.writes(
+                                        A_reindex_shared[
+                                            v1,
+                                            v0 // 32 * 32
+                                            + v0 % 8 // 4 * 16
+                                            + v0 % 32 // 8 * 4
+                                            + v0 % 4,
+                                        ]
+                                    )
+                                    A_reindex_shared[
+                                        v1,
+                                        v0 // 32 * 32
+                                        + v0 % 8 // 4 * 16
+                                        + v0 % 32 // 8 * 4
+                                        + v0 % 4,
+                                    ] = T.if_then_else(
+                                        v0 < 1012 and v1 < 1014,
+                                        A[v0, v1],
+                                        T.float32(0),
+                                        dtype="float32",
+                                    )
+                        for ax0_ax1_fused_0 in T.serial(8):
+                            for ax0_ax1_fused_1 in T.thread_binding(64, 
thread="threadIdx.x"):
+                                with T.block("B_reindex_shared"):
+                                    v0 = T.axis.spatial(
+                                        1024,
+                                        ax2_0_fused * 4
+                                        + (ax0_ax1_fused_0 * 64 + 
ax0_ax1_fused_1) // 128,
+                                    )
+                                    v1 = T.axis.spatial(
+                                        1024,
+                                        ax0_0_ax1_0_fused % 8 * 128
+                                        + (ax0_ax1_fused_0 * 64 + 
ax0_ax1_fused_1) % 128,
+                                    )
+                                    T.reads(B[v0, v1])
+                                    T.writes(
+                                        B_reindex_shared[
+                                            v0,
+                                            v1 // 64 * 64
+                                            + v1 % 8 // 4 * 32
+                                            + v1 % 64 // 8 * 4
+                                            + v1 % 4,
+                                        ]
+                                    )
+                                    B_reindex_shared[
+                                        v0,
+                                        v1 // 64 * 64
+                                        + v1 % 8 // 4 * 32
+                                        + v1 % 64 // 8 * 4
+                                        + v1 % 4,
+                                    ] = T.if_then_else(
+                                        v0 < 1014 and v1 < 1017,
+                                        B[v0, v1],
+                                        T.float32(0),
+                                        dtype="float32",
+                                    )
+                        for ax2_1_fused in T.unroll(
+                            4,
+                            annotations={
+                                "software_pipeline_order": [0, 1, 2],
+                                "software_pipeline_stage": [0, 0, 1],
+                            },
+                        ):
+                            for ax0_ax1_fused_0 in T.unroll(2):
+                                for ax0_ax1_fused_1 in T.vectorized(4):
+                                    with T.block("A_reindex_shared_local"):
+                                        v0 = T.axis.spatial(1024, ax2_0_fused 
* 4 + ax2_1_fused)
+                                        v1 = T.axis.spatial(
+                                            1024,
+                                            ax0_0_ax1_0_fused // 8 * 64
+                                            + ax0_1_ax1_1_fused // 2 * 32
+                                            + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+                                            // 32
+                                            * 16
+                                            + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                            + ax0_ax1_fused_0 * 4
+                                            + ax0_ax1_fused_1,
+                                        )
+                                        T.reads(
+                                            A_reindex_shared[
+                                                v0,
+                                                v1 // 32 * 32
+                                                + v1 % 8 // 4 * 16
+                                                + v1 % 32 // 8 * 4
+                                                + v1 % 4,
+                                            ]
+                                        )
+                                        T.writes(A_reindex_shared_local[v0, 
v1])
+                                        A_reindex_shared_local[v0, v1] = 
A_reindex_shared[
+                                            v0,
+                                            v1 // 32 * 32
+                                            + v1 % 8 // 4 * 16
+                                            + v1 % 32 // 8 * 4
+                                            + v1 % 4,
+                                        ]
+                            for ax0_ax1_fused_0 in T.unroll(2):
+                                for ax0_ax1_fused_1 in T.vectorized(2):
+                                    with T.block("B_reindex_shared_local"):
+                                        v0 = T.axis.spatial(1024, ax2_0_fused 
* 4 + ax2_1_fused)
+                                        v1 = T.axis.spatial(
+                                            1024,
+                                            ax0_0_ax1_0_fused % 8 * 128
+                                            + ax0_1_ax1_1_fused % 2 * 64
+                                            + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+                                            % 32
+                                            // 2
+                                            * 4
+                                            + ax0_ax1_fused_0 * 2
+                                            + ax0_ax1_fused_1,
+                                        )
+                                        T.reads(
+                                            B_reindex_shared[
+                                                v0,
+                                                v1 // 64 * 64
+                                                + v1 % 8 // 4 * 32
+                                                + v1 % 64 // 8 * 4
+                                                + v1 % 4,
+                                            ]
+                                        )
+                                        T.writes(B_reindex_shared_local[v0, 
v1])
+                                        B_reindex_shared_local[v0, v1] = 
B_reindex_shared[
+                                            v0,
+                                            v1 // 64 * 64
+                                            + v1 % 8 // 4 * 32
+                                            + v1 % 64 // 8 * 4
+                                            + v1 % 4,
+                                        ]
+                            for ax0_3, ax1_3, ax2_2, ax0_4, ax1_4 in T.grid(4, 
4, 1, 2, 1):
+                                with T.block("Y_update"):
+                                    v0 = T.axis.spatial(
+                                        1024,
+                                        ax0_0_ax1_0_fused // 8 * 64
+                                        + ax0_1_ax1_1_fused // 2 * 32
+                                        + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+                                        + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                        + ax0_3 * 2
+                                        + ax0_4,
+                                    )
+                                    v1 = T.axis.spatial(
+                                        1024,
+                                        ax1_4
+                                        + ax0_0_ax1_0_fused % 8 * 128
+                                        + ax0_1_ax1_1_fused % 2 * 64
+                                        + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+                                        % 32
+                                        // 2
+                                        * 4
+                                        + ax1_3,
+                                    )
+                                    v2 = T.axis.reduce(1024, ax2_0_fused * 4 + 
ax2_1_fused + ax2_2)
+                                    T.reads(
+                                        Y_reindex_local[v0, v1],
+                                        A_reindex_shared_local[v2, v0],
+                                        B_reindex_shared_local[v2, v1],
+                                    )
+                                    T.writes(Y_reindex_local[v0, v1])
+                                    T.block_attr(
+                                        {
+                                            
"meta_schedule.thread_extent_high_inclusive": 1024,
+                                            
"meta_schedule.thread_extent_low_inclusive": 32,
+                                            "meta_schedule.tiling_structure": 
"SSSRRSRS",
+                                        }
+                                    )
+                                    Y_reindex_local[v0, v1] = (
+                                        Y_reindex_local[v0, v1]
+                                        + A_reindex_shared_local[v2, v0]
+                                        * B_reindex_shared_local[v2, v1]
+                                    )
+                    for ax0, ax1 in T.grid(8, 4):
+                        with T.block("Y_reindex_local"):
+                            T.where(
+                                ax0_0_ax1_0_fused // 8 * 64
+                                + ax0_1_ax1_1_fused // 2 * 32
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                + ax0
+                                < 1012
+                                and ax0_0_ax1_0_fused % 8 * 128
+                                + ax0_1_ax1_1_fused % 2 * 64
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+                                + ax1
+                                < 1017
+                            )
+                            v0 = T.axis.spatial(
+                                1024,
+                                ax0_0_ax1_0_fused // 8 * 64
+                                + ax0_1_ax1_1_fused // 2 * 32
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                + ax0,
+                            )
+                            v1 = T.axis.spatial(
+                                1024,
+                                ax0_0_ax1_0_fused % 8 * 128
+                                + ax0_1_ax1_1_fused % 2 * 64
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+                                + ax1,
+                            )
+                            T.reads(Y_reindex_local[v0, v1])
+                            T.writes(Y[v0, v1])
+                            Y[v0, v1] = Y_reindex_local[v0, v1]
+

Review Comment:
   Could we also add an assertion on the generated CUDA code here, by the way?



##########
src/tir/transforms/inject_ptx_async_copy.cc:
##########
@@ -47,74 +47,104 @@ class PTXAsyncCopyInjector : public StmtMutator {
     return StmtMutator::VisitStmt_(attr);
   }
 
-  Stmt VisitStmt_(const BufferStoreNode* store) {
-    if (in_async && (store->buffer.scope() == "shared" || 
store->buffer.scope() == "shared.dyn")) {
-      if (auto* load = store->value.as<BufferLoadNode>()) {
-        if (load->buffer.scope() == "global") {
-          ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
-          ICHECK(load->indices[0]->dtype.lanes() == 
store->indices[0]->dtype.lanes());
-
-          const int indices_lanes = load->indices[0]->dtype.lanes();
-          const int bytes = indices_lanes * load->buffer->dtype.bytes();
-
-          if (bytes == 4 || bytes == 8 || bytes == 16) {
-            auto dst_elem_type = 
GetPointerType(store->buffer->data->type_annotation);
-            auto src_elem_type = 
GetPointerType(load->buffer->data->type_annotation);
-            ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
-                << "Both store and load buffer should have a pointer type 
annotation.";
-
-            int index_factor = 1;
-            if (dst_elem_type.value() != src_elem_type.value()) {
-              // The only case where src and dst have different dtypes is when 
the dst shared memory
-              // is a byte buffer generated by merging dynamic shared memory.
-              ICHECK(store->buffer.scope() == "shared.dyn");
-              ICHECK(dst_elem_type.value() == DataType::UInt(8));
-              // BufferStore/Load have the "pointer reinterpret" semantics 
according to their
-              // "value" dtype. Their "indices" are supposed to be applied 
after such pointer cast,
-              // for example: ((*float16)(byte_buffer))[buffer->indices] = 
fp16_value;
-              // To replace BufferStore/Load with cp.async, we need to 
multiply the store index by
-              // the byte size of the "value" dtype, to get the correct offset 
into the byte buffer.
-              index_factor = src_elem_type->bytes();
-            }
+  Stmt injectPTX(const BufferLoadNode* load, const BufferStoreNode* store,

Review Comment:
   nit: naming style — member function names should use `PascalCase` (i.e. `InjectPTX`, not `injectPTX`)
   
   ```suggestion
     Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store,
   ```



##########
tests/python/unittest/test_cp_async_in_if_then_else.py:
##########
@@ -0,0 +1,304 @@
+import tvm
+import numpy as np
+from tvm.script import tir as T
+
+
[email protected]_module
+class Module:
+    @T.prim_func
+    def main(
+        A: T.Buffer[(1012, 1014), "float32"],
+        B: T.Buffer[(1014, 1017), "float32"],
+        Y: T.Buffer[(1012, 1017), "float32"],
+    ):
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        Y_reindex_local = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="local")
+        A_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="shared")
+        B_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="shared")
+        A_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="local")
+        B_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32", 
scope="local")
+        for ax0_0_ax1_0_fused in T.thread_binding(
+            128,
+            thread="blockIdx.x",
+            annotations={"pragma_auto_unroll_max_step": 1024, 
"pragma_unroll_explicit": 1},
+        ):
+            for ax0_1_ax1_1_fused in T.thread_binding(4, thread="vthread.x"):
+                for ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused in 
T.thread_binding(
+                    64, thread="threadIdx.x"
+                ):
+                    for ax0_3_init, ax1_3_init, ax0_4_init, ax1_4_init in 
T.grid(4, 4, 2, 1):
+                        with T.block("Y_init"):
+                            v0 = T.axis.spatial(
+                                1024,
+                                ax0_0_ax1_0_fused // 8 * 64
+                                + ax0_1_ax1_1_fused // 2 * 32
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                + ax0_3_init * 2
+                                + ax0_4_init,
+                            )
+                            v1 = T.axis.spatial(
+                                1024,
+                                ax1_4_init
+                                + ax0_0_ax1_0_fused % 8 * 128
+                                + ax0_1_ax1_1_fused % 2 * 64
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+                                + ax1_3_init,
+                            )
+                            T.reads()
+                            T.writes(Y_reindex_local[v0, v1])
+                            T.block_attr(
+                                {
+                                    
"meta_schedule.thread_extent_high_inclusive": 1024,
+                                    
"meta_schedule.thread_extent_low_inclusive": 32,
+                                    "meta_schedule.tiling_structure": 
"SSSRRSRS",
+                                }
+                            )
+                            Y_reindex_local[v0, v1] = T.float32(0)
+                    for ax2_0_fused in T.serial(
+                        256,
+                        annotations={
+                            "software_pipeline_async_stages": [0, 1],
+                            "software_pipeline_order": [0, 1, 3, 2, 4],
+                            "software_pipeline_stage": [0, 0, 2, 3, 3],
+                        },
+                    ):
+                        for ax0_ax1_fused_0 in T.serial(4):
+                            for ax0_ax1_fused_1 in T.thread_binding(64, 
thread="threadIdx.x"):
+                                with T.block("A_reindex_shared"):
+                                    v0 = T.axis.spatial(
+                                        1024,
+                                        ax0_0_ax1_0_fused // 8 * 64
+                                        + (ax0_ax1_fused_0 * 64 + 
ax0_ax1_fused_1) // 4,
+                                    )
+                                    v1 = T.axis.spatial(
+                                        1024,
+                                        ax2_0_fused * 4
+                                        + (ax0_ax1_fused_0 * 64 + 
ax0_ax1_fused_1) % 4,
+                                    )
+                                    T.reads(A[v0, v1])
+                                    T.writes(
+                                        A_reindex_shared[
+                                            v1,
+                                            v0 // 32 * 32
+                                            + v0 % 8 // 4 * 16
+                                            + v0 % 32 // 8 * 4
+                                            + v0 % 4,
+                                        ]
+                                    )
+                                    A_reindex_shared[
+                                        v1,
+                                        v0 // 32 * 32
+                                        + v0 % 8 // 4 * 16
+                                        + v0 % 32 // 8 * 4
+                                        + v0 % 4,
+                                    ] = T.if_then_else(
+                                        v0 < 1012 and v1 < 1014,
+                                        A[v0, v1],
+                                        T.float32(0),
+                                        dtype="float32",
+                                    )
+                        for ax0_ax1_fused_0 in T.serial(8):
+                            for ax0_ax1_fused_1 in T.thread_binding(64, 
thread="threadIdx.x"):
+                                with T.block("B_reindex_shared"):
+                                    v0 = T.axis.spatial(
+                                        1024,
+                                        ax2_0_fused * 4
+                                        + (ax0_ax1_fused_0 * 64 + 
ax0_ax1_fused_1) // 128,
+                                    )
+                                    v1 = T.axis.spatial(
+                                        1024,
+                                        ax0_0_ax1_0_fused % 8 * 128
+                                        + (ax0_ax1_fused_0 * 64 + 
ax0_ax1_fused_1) % 128,
+                                    )
+                                    T.reads(B[v0, v1])
+                                    T.writes(
+                                        B_reindex_shared[
+                                            v0,
+                                            v1 // 64 * 64
+                                            + v1 % 8 // 4 * 32
+                                            + v1 % 64 // 8 * 4
+                                            + v1 % 4,
+                                        ]
+                                    )
+                                    B_reindex_shared[
+                                        v0,
+                                        v1 // 64 * 64
+                                        + v1 % 8 // 4 * 32
+                                        + v1 % 64 // 8 * 4
+                                        + v1 % 4,
+                                    ] = T.if_then_else(
+                                        v0 < 1014 and v1 < 1017,
+                                        B[v0, v1],
+                                        T.float32(0),
+                                        dtype="float32",
+                                    )
+                        for ax2_1_fused in T.unroll(
+                            4,
+                            annotations={
+                                "software_pipeline_order": [0, 1, 2],
+                                "software_pipeline_stage": [0, 0, 1],
+                            },
+                        ):
+                            for ax0_ax1_fused_0 in T.unroll(2):
+                                for ax0_ax1_fused_1 in T.vectorized(4):
+                                    with T.block("A_reindex_shared_local"):
+                                        v0 = T.axis.spatial(1024, ax2_0_fused 
* 4 + ax2_1_fused)
+                                        v1 = T.axis.spatial(
+                                            1024,
+                                            ax0_0_ax1_0_fused // 8 * 64
+                                            + ax0_1_ax1_1_fused // 2 * 32
+                                            + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+                                            // 32
+                                            * 16
+                                            + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                            + ax0_ax1_fused_0 * 4
+                                            + ax0_ax1_fused_1,
+                                        )
+                                        T.reads(
+                                            A_reindex_shared[
+                                                v0,
+                                                v1 // 32 * 32
+                                                + v1 % 8 // 4 * 16
+                                                + v1 % 32 // 8 * 4
+                                                + v1 % 4,
+                                            ]
+                                        )
+                                        T.writes(A_reindex_shared_local[v0, 
v1])
+                                        A_reindex_shared_local[v0, v1] = 
A_reindex_shared[
+                                            v0,
+                                            v1 // 32 * 32
+                                            + v1 % 8 // 4 * 16
+                                            + v1 % 32 // 8 * 4
+                                            + v1 % 4,
+                                        ]
+                            for ax0_ax1_fused_0 in T.unroll(2):
+                                for ax0_ax1_fused_1 in T.vectorized(2):
+                                    with T.block("B_reindex_shared_local"):
+                                        v0 = T.axis.spatial(1024, ax2_0_fused 
* 4 + ax2_1_fused)
+                                        v1 = T.axis.spatial(
+                                            1024,
+                                            ax0_0_ax1_0_fused % 8 * 128
+                                            + ax0_1_ax1_1_fused % 2 * 64
+                                            + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+                                            % 32
+                                            // 2
+                                            * 4
+                                            + ax0_ax1_fused_0 * 2
+                                            + ax0_ax1_fused_1,
+                                        )
+                                        T.reads(
+                                            B_reindex_shared[
+                                                v0,
+                                                v1 // 64 * 64
+                                                + v1 % 8 // 4 * 32
+                                                + v1 % 64 // 8 * 4
+                                                + v1 % 4,
+                                            ]
+                                        )
+                                        T.writes(B_reindex_shared_local[v0, 
v1])
+                                        B_reindex_shared_local[v0, v1] = 
B_reindex_shared[
+                                            v0,
+                                            v1 // 64 * 64
+                                            + v1 % 8 // 4 * 32
+                                            + v1 % 64 // 8 * 4
+                                            + v1 % 4,
+                                        ]
+                            for ax0_3, ax1_3, ax2_2, ax0_4, ax1_4 in T.grid(4, 
4, 1, 2, 1):
+                                with T.block("Y_update"):
+                                    v0 = T.axis.spatial(
+                                        1024,
+                                        ax0_0_ax1_0_fused // 8 * 64
+                                        + ax0_1_ax1_1_fused // 2 * 32
+                                        + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+                                        + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                        + ax0_3 * 2
+                                        + ax0_4,
+                                    )
+                                    v1 = T.axis.spatial(
+                                        1024,
+                                        ax1_4
+                                        + ax0_0_ax1_0_fused % 8 * 128
+                                        + ax0_1_ax1_1_fused % 2 * 64
+                                        + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+                                        % 32
+                                        // 2
+                                        * 4
+                                        + ax1_3,
+                                    )
+                                    v2 = T.axis.reduce(1024, ax2_0_fused * 4 + 
ax2_1_fused + ax2_2)
+                                    T.reads(
+                                        Y_reindex_local[v0, v1],
+                                        A_reindex_shared_local[v2, v0],
+                                        B_reindex_shared_local[v2, v1],
+                                    )
+                                    T.writes(Y_reindex_local[v0, v1])
+                                    T.block_attr(
+                                        {
+                                            
"meta_schedule.thread_extent_high_inclusive": 1024,
+                                            
"meta_schedule.thread_extent_low_inclusive": 32,
+                                            "meta_schedule.tiling_structure": 
"SSSRRSRS",
+                                        }
+                                    )
+                                    Y_reindex_local[v0, v1] = (
+                                        Y_reindex_local[v0, v1]
+                                        + A_reindex_shared_local[v2, v0]
+                                        * B_reindex_shared_local[v2, v1]
+                                    )
+                    for ax0, ax1 in T.grid(8, 4):
+                        with T.block("Y_reindex_local"):
+                            T.where(
+                                ax0_0_ax1_0_fused // 8 * 64
+                                + ax0_1_ax1_1_fused // 2 * 32
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                + ax0
+                                < 1012
+                                and ax0_0_ax1_0_fused % 8 * 128
+                                + ax0_1_ax1_1_fused % 2 * 64
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+                                + ax1
+                                < 1017
+                            )
+                            v0 = T.axis.spatial(
+                                1024,
+                                ax0_0_ax1_0_fused // 8 * 64
+                                + ax0_1_ax1_1_fused // 2 * 32
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+                                + ax0,
+                            )
+                            v1 = T.axis.spatial(
+                                1024,
+                                ax0_0_ax1_0_fused % 8 * 128
+                                + ax0_1_ax1_1_fused % 2 * 64
+                                + 
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+                                + ax1,
+                            )
+                            T.reads(Y_reindex_local[v0, v1])
+                            T.writes(Y[v0, v1])
+                            Y[v0, v1] = Y_reindex_local[v0, v1]
+
+
+def test_matmul():
+    with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
+        rt_mod = tvm.build(Module, target="cuda")

Review Comment:
   Are async copy instructions always available? In particular, are they 
supported on previous generations of CUDA GPUs?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to