junrushao commented on code in PR #13966: URL: https://github.com/apache/tvm/pull/13966#discussion_r1105276027
##########
src/target/source/codegen_cuda.cc:
##########
@@ -914,7 +914,12 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op,
std::ostream& os) {
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
- this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset,
size);
+ // use size of argument list to indicate whether or not to use predicated
cp.async
+ if (op->args.size() == 5)
+ this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset,
size);
+ else
+ this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src,
src_offset,
+ size,
this->PrintExpr(op->args[5]));
Review Comment:
nit: always use braces `{}` around `if`/`else` bodies, even single statements
```suggestion
if (op->args.size() == 5) {
this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset,
size);
}
else {
this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src,
src_offset,
size,
this->PrintExpr(op->args[5]));
}
```
##########
tests/python/unittest/test_cp_async_in_if_then_else.py:
##########
@@ -0,0 +1,304 @@
+import tvm
+import numpy as np
+from tvm.script import tir as T
+
+
[email protected]_module
+class Module:
+ @T.prim_func
+ def main(
+ A: T.Buffer[(1012, 1014), "float32"],
+ B: T.Buffer[(1014, 1017), "float32"],
+ Y: T.Buffer[(1012, 1017), "float32"],
+ ):
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body
+ # with T.block("root")
+ Y_reindex_local = T.alloc_buffer([1024, 1024], dtype="float32",
scope="local")
+ A_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32",
scope="shared")
+ B_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32",
scope="shared")
+ A_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32",
scope="local")
+ B_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32",
scope="local")
+ for ax0_0_ax1_0_fused in T.thread_binding(
+ 128,
+ thread="blockIdx.x",
+ annotations={"pragma_auto_unroll_max_step": 1024,
"pragma_unroll_explicit": 1},
+ ):
+ for ax0_1_ax1_1_fused in T.thread_binding(4, thread="vthread.x"):
+ for ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused in
T.thread_binding(
+ 64, thread="threadIdx.x"
+ ):
+ for ax0_3_init, ax1_3_init, ax0_4_init, ax1_4_init in
T.grid(4, 4, 2, 1):
+ with T.block("Y_init"):
+ v0 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused // 8 * 64
+ + ax0_1_ax1_1_fused // 2 * 32
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+ + ax0_3_init * 2
+ + ax0_4_init,
+ )
+ v1 = T.axis.spatial(
+ 1024,
+ ax1_4_init
+ + ax0_0_ax1_0_fused % 8 * 128
+ + ax0_1_ax1_1_fused % 2 * 64
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+ + ax1_3_init,
+ )
+ T.reads()
+ T.writes(Y_reindex_local[v0, v1])
+ T.block_attr(
+ {
+
"meta_schedule.thread_extent_high_inclusive": 1024,
+
"meta_schedule.thread_extent_low_inclusive": 32,
+ "meta_schedule.tiling_structure":
"SSSRRSRS",
+ }
+ )
+ Y_reindex_local[v0, v1] = T.float32(0)
+ for ax2_0_fused in T.serial(
+ 256,
+ annotations={
+ "software_pipeline_async_stages": [0, 1],
+ "software_pipeline_order": [0, 1, 3, 2, 4],
+ "software_pipeline_stage": [0, 0, 2, 3, 3],
+ },
+ ):
+ for ax0_ax1_fused_0 in T.serial(4):
+ for ax0_ax1_fused_1 in T.thread_binding(64,
thread="threadIdx.x"):
+ with T.block("A_reindex_shared"):
+ v0 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused // 8 * 64
+ + (ax0_ax1_fused_0 * 64 +
ax0_ax1_fused_1) // 4,
+ )
+ v1 = T.axis.spatial(
+ 1024,
+ ax2_0_fused * 4
+ + (ax0_ax1_fused_0 * 64 +
ax0_ax1_fused_1) % 4,
+ )
+ T.reads(A[v0, v1])
+ T.writes(
+ A_reindex_shared[
+ v1,
+ v0 // 32 * 32
+ + v0 % 8 // 4 * 16
+ + v0 % 32 // 8 * 4
+ + v0 % 4,
+ ]
+ )
+ A_reindex_shared[
+ v1,
+ v0 // 32 * 32
+ + v0 % 8 // 4 * 16
+ + v0 % 32 // 8 * 4
+ + v0 % 4,
+ ] = T.if_then_else(
+ v0 < 1012 and v1 < 1014,
+ A[v0, v1],
+ T.float32(0),
+ dtype="float32",
+ )
+ for ax0_ax1_fused_0 in T.serial(8):
+ for ax0_ax1_fused_1 in T.thread_binding(64,
thread="threadIdx.x"):
+ with T.block("B_reindex_shared"):
+ v0 = T.axis.spatial(
+ 1024,
+ ax2_0_fused * 4
+ + (ax0_ax1_fused_0 * 64 +
ax0_ax1_fused_1) // 128,
+ )
+ v1 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused % 8 * 128
+ + (ax0_ax1_fused_0 * 64 +
ax0_ax1_fused_1) % 128,
+ )
+ T.reads(B[v0, v1])
+ T.writes(
+ B_reindex_shared[
+ v0,
+ v1 // 64 * 64
+ + v1 % 8 // 4 * 32
+ + v1 % 64 // 8 * 4
+ + v1 % 4,
+ ]
+ )
+ B_reindex_shared[
+ v0,
+ v1 // 64 * 64
+ + v1 % 8 // 4 * 32
+ + v1 % 64 // 8 * 4
+ + v1 % 4,
+ ] = T.if_then_else(
+ v0 < 1014 and v1 < 1017,
+ B[v0, v1],
+ T.float32(0),
+ dtype="float32",
+ )
+ for ax2_1_fused in T.unroll(
+ 4,
+ annotations={
+ "software_pipeline_order": [0, 1, 2],
+ "software_pipeline_stage": [0, 0, 1],
+ },
+ ):
+ for ax0_ax1_fused_0 in T.unroll(2):
+ for ax0_ax1_fused_1 in T.vectorized(4):
+ with T.block("A_reindex_shared_local"):
+ v0 = T.axis.spatial(1024, ax2_0_fused
* 4 + ax2_1_fused)
+ v1 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused // 8 * 64
+ + ax0_1_ax1_1_fused // 2 * 32
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+ // 32
+ * 16
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+ + ax0_ax1_fused_0 * 4
+ + ax0_ax1_fused_1,
+ )
+ T.reads(
+ A_reindex_shared[
+ v0,
+ v1 // 32 * 32
+ + v1 % 8 // 4 * 16
+ + v1 % 32 // 8 * 4
+ + v1 % 4,
+ ]
+ )
+ T.writes(A_reindex_shared_local[v0,
v1])
+ A_reindex_shared_local[v0, v1] =
A_reindex_shared[
+ v0,
+ v1 // 32 * 32
+ + v1 % 8 // 4 * 16
+ + v1 % 32 // 8 * 4
+ + v1 % 4,
+ ]
+ for ax0_ax1_fused_0 in T.unroll(2):
+ for ax0_ax1_fused_1 in T.vectorized(2):
+ with T.block("B_reindex_shared_local"):
+ v0 = T.axis.spatial(1024, ax2_0_fused
* 4 + ax2_1_fused)
+ v1 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused % 8 * 128
+ + ax0_1_ax1_1_fused % 2 * 64
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+ % 32
+ // 2
+ * 4
+ + ax0_ax1_fused_0 * 2
+ + ax0_ax1_fused_1,
+ )
+ T.reads(
+ B_reindex_shared[
+ v0,
+ v1 // 64 * 64
+ + v1 % 8 // 4 * 32
+ + v1 % 64 // 8 * 4
+ + v1 % 4,
+ ]
+ )
+ T.writes(B_reindex_shared_local[v0,
v1])
+ B_reindex_shared_local[v0, v1] =
B_reindex_shared[
+ v0,
+ v1 // 64 * 64
+ + v1 % 8 // 4 * 32
+ + v1 % 64 // 8 * 4
+ + v1 % 4,
+ ]
+ for ax0_3, ax1_3, ax2_2, ax0_4, ax1_4 in T.grid(4,
4, 1, 2, 1):
+ with T.block("Y_update"):
+ v0 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused // 8 * 64
+ + ax0_1_ax1_1_fused // 2 * 32
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+ + ax0_3 * 2
+ + ax0_4,
+ )
+ v1 = T.axis.spatial(
+ 1024,
+ ax1_4
+ + ax0_0_ax1_0_fused % 8 * 128
+ + ax0_1_ax1_1_fused % 2 * 64
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+ % 32
+ // 2
+ * 4
+ + ax1_3,
+ )
+ v2 = T.axis.reduce(1024, ax2_0_fused * 4 +
ax2_1_fused + ax2_2)
+ T.reads(
+ Y_reindex_local[v0, v1],
+ A_reindex_shared_local[v2, v0],
+ B_reindex_shared_local[v2, v1],
+ )
+ T.writes(Y_reindex_local[v0, v1])
+ T.block_attr(
+ {
+
"meta_schedule.thread_extent_high_inclusive": 1024,
+
"meta_schedule.thread_extent_low_inclusive": 32,
+ "meta_schedule.tiling_structure":
"SSSRRSRS",
+ }
+ )
+ Y_reindex_local[v0, v1] = (
+ Y_reindex_local[v0, v1]
+ + A_reindex_shared_local[v2, v0]
+ * B_reindex_shared_local[v2, v1]
+ )
+ for ax0, ax1 in T.grid(8, 4):
+ with T.block("Y_reindex_local"):
+ T.where(
+ ax0_0_ax1_0_fused // 8 * 64
+ + ax0_1_ax1_1_fused // 2 * 32
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+ + ax0
+ < 1012
+ and ax0_0_ax1_0_fused % 8 * 128
+ + ax0_1_ax1_1_fused % 2 * 64
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+ + ax1
+ < 1017
+ )
+ v0 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused // 8 * 64
+ + ax0_1_ax1_1_fused // 2 * 32
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+ + ax0,
+ )
+ v1 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused % 8 * 128
+ + ax0_1_ax1_1_fused % 2 * 64
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+ + ax1,
+ )
+ T.reads(Y_reindex_local[v0, v1])
+ T.writes(Y[v0, v1])
+ Y[v0, v1] = Y_reindex_local[v0, v1]
+
Review Comment:
By the way, can we also assert on the CUDA code generated for this test?
##########
src/tir/transforms/inject_ptx_async_copy.cc:
##########
@@ -47,74 +47,104 @@ class PTXAsyncCopyInjector : public StmtMutator {
return StmtMutator::VisitStmt_(attr);
}
- Stmt VisitStmt_(const BufferStoreNode* store) {
- if (in_async && (store->buffer.scope() == "shared" ||
store->buffer.scope() == "shared.dyn")) {
- if (auto* load = store->value.as<BufferLoadNode>()) {
- if (load->buffer.scope() == "global") {
- ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
- ICHECK(load->indices[0]->dtype.lanes() ==
store->indices[0]->dtype.lanes());
-
- const int indices_lanes = load->indices[0]->dtype.lanes();
- const int bytes = indices_lanes * load->buffer->dtype.bytes();
-
- if (bytes == 4 || bytes == 8 || bytes == 16) {
- auto dst_elem_type =
GetPointerType(store->buffer->data->type_annotation);
- auto src_elem_type =
GetPointerType(load->buffer->data->type_annotation);
- ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
- << "Both store and load buffer should have a pointer type
annotation.";
-
- int index_factor = 1;
- if (dst_elem_type.value() != src_elem_type.value()) {
- // The only case where src and dst have different dtypes is when
the dst shared memory
- // is a byte buffer generated by merging dynamic shared memory.
- ICHECK(store->buffer.scope() == "shared.dyn");
- ICHECK(dst_elem_type.value() == DataType::UInt(8));
- // BufferStore/Load have the "pointer reinterpret" semantics
according to their
- // "value" dtype. Their "indices" are supposed to be applied
after such pointer cast,
- // for example: ((*float16)(byte_buffer))[buffer->indices] =
fp16_value;
- // To replace BufferStore/Load with cp.async, we need to
multiply the store index by
- // the byte size of the "value" dtype, to get the correct offset
into the byte buffer.
- index_factor = src_elem_type->bytes();
- }
+ Stmt injectPTX(const BufferLoadNode* load, const BufferStoreNode* store,
Review Comment:
nit: style — method names should be PascalCase (`InjectPTX`, not `injectPTX`)
```suggestion
Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store,
```
##########
tests/python/unittest/test_cp_async_in_if_then_else.py:
##########
@@ -0,0 +1,304 @@
+import tvm
+import numpy as np
+from tvm.script import tir as T
+
+
[email protected]_module
+class Module:
+ @T.prim_func
+ def main(
+ A: T.Buffer[(1012, 1014), "float32"],
+ B: T.Buffer[(1014, 1017), "float32"],
+ Y: T.Buffer[(1012, 1017), "float32"],
+ ):
+ # function attr dict
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ # body
+ # with T.block("root")
+ Y_reindex_local = T.alloc_buffer([1024, 1024], dtype="float32",
scope="local")
+ A_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32",
scope="shared")
+ B_reindex_shared = T.alloc_buffer([1024, 1024], dtype="float32",
scope="shared")
+ A_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32",
scope="local")
+ B_reindex_shared_local = T.alloc_buffer([1024, 1024], dtype="float32",
scope="local")
+ for ax0_0_ax1_0_fused in T.thread_binding(
+ 128,
+ thread="blockIdx.x",
+ annotations={"pragma_auto_unroll_max_step": 1024,
"pragma_unroll_explicit": 1},
+ ):
+ for ax0_1_ax1_1_fused in T.thread_binding(4, thread="vthread.x"):
+ for ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused in
T.thread_binding(
+ 64, thread="threadIdx.x"
+ ):
+ for ax0_3_init, ax1_3_init, ax0_4_init, ax1_4_init in
T.grid(4, 4, 2, 1):
+ with T.block("Y_init"):
+ v0 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused // 8 * 64
+ + ax0_1_ax1_1_fused // 2 * 32
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+ + ax0_3_init * 2
+ + ax0_4_init,
+ )
+ v1 = T.axis.spatial(
+ 1024,
+ ax1_4_init
+ + ax0_0_ax1_0_fused % 8 * 128
+ + ax0_1_ax1_1_fused % 2 * 64
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+ + ax1_3_init,
+ )
+ T.reads()
+ T.writes(Y_reindex_local[v0, v1])
+ T.block_attr(
+ {
+
"meta_schedule.thread_extent_high_inclusive": 1024,
+
"meta_schedule.thread_extent_low_inclusive": 32,
+ "meta_schedule.tiling_structure":
"SSSRRSRS",
+ }
+ )
+ Y_reindex_local[v0, v1] = T.float32(0)
+ for ax2_0_fused in T.serial(
+ 256,
+ annotations={
+ "software_pipeline_async_stages": [0, 1],
+ "software_pipeline_order": [0, 1, 3, 2, 4],
+ "software_pipeline_stage": [0, 0, 2, 3, 3],
+ },
+ ):
+ for ax0_ax1_fused_0 in T.serial(4):
+ for ax0_ax1_fused_1 in T.thread_binding(64,
thread="threadIdx.x"):
+ with T.block("A_reindex_shared"):
+ v0 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused // 8 * 64
+ + (ax0_ax1_fused_0 * 64 +
ax0_ax1_fused_1) // 4,
+ )
+ v1 = T.axis.spatial(
+ 1024,
+ ax2_0_fused * 4
+ + (ax0_ax1_fused_0 * 64 +
ax0_ax1_fused_1) % 4,
+ )
+ T.reads(A[v0, v1])
+ T.writes(
+ A_reindex_shared[
+ v1,
+ v0 // 32 * 32
+ + v0 % 8 // 4 * 16
+ + v0 % 32 // 8 * 4
+ + v0 % 4,
+ ]
+ )
+ A_reindex_shared[
+ v1,
+ v0 // 32 * 32
+ + v0 % 8 // 4 * 16
+ + v0 % 32 // 8 * 4
+ + v0 % 4,
+ ] = T.if_then_else(
+ v0 < 1012 and v1 < 1014,
+ A[v0, v1],
+ T.float32(0),
+ dtype="float32",
+ )
+ for ax0_ax1_fused_0 in T.serial(8):
+ for ax0_ax1_fused_1 in T.thread_binding(64,
thread="threadIdx.x"):
+ with T.block("B_reindex_shared"):
+ v0 = T.axis.spatial(
+ 1024,
+ ax2_0_fused * 4
+ + (ax0_ax1_fused_0 * 64 +
ax0_ax1_fused_1) // 128,
+ )
+ v1 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused % 8 * 128
+ + (ax0_ax1_fused_0 * 64 +
ax0_ax1_fused_1) % 128,
+ )
+ T.reads(B[v0, v1])
+ T.writes(
+ B_reindex_shared[
+ v0,
+ v1 // 64 * 64
+ + v1 % 8 // 4 * 32
+ + v1 % 64 // 8 * 4
+ + v1 % 4,
+ ]
+ )
+ B_reindex_shared[
+ v0,
+ v1 // 64 * 64
+ + v1 % 8 // 4 * 32
+ + v1 % 64 // 8 * 4
+ + v1 % 4,
+ ] = T.if_then_else(
+ v0 < 1014 and v1 < 1017,
+ B[v0, v1],
+ T.float32(0),
+ dtype="float32",
+ )
+ for ax2_1_fused in T.unroll(
+ 4,
+ annotations={
+ "software_pipeline_order": [0, 1, 2],
+ "software_pipeline_stage": [0, 0, 1],
+ },
+ ):
+ for ax0_ax1_fused_0 in T.unroll(2):
+ for ax0_ax1_fused_1 in T.vectorized(4):
+ with T.block("A_reindex_shared_local"):
+ v0 = T.axis.spatial(1024, ax2_0_fused
* 4 + ax2_1_fused)
+ v1 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused // 8 * 64
+ + ax0_1_ax1_1_fused // 2 * 32
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+ // 32
+ * 16
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+ + ax0_ax1_fused_0 * 4
+ + ax0_ax1_fused_1,
+ )
+ T.reads(
+ A_reindex_shared[
+ v0,
+ v1 // 32 * 32
+ + v1 % 8 // 4 * 16
+ + v1 % 32 // 8 * 4
+ + v1 % 4,
+ ]
+ )
+ T.writes(A_reindex_shared_local[v0,
v1])
+ A_reindex_shared_local[v0, v1] =
A_reindex_shared[
+ v0,
+ v1 // 32 * 32
+ + v1 % 8 // 4 * 16
+ + v1 % 32 // 8 * 4
+ + v1 % 4,
+ ]
+ for ax0_ax1_fused_0 in T.unroll(2):
+ for ax0_ax1_fused_1 in T.vectorized(2):
+ with T.block("B_reindex_shared_local"):
+ v0 = T.axis.spatial(1024, ax2_0_fused
* 4 + ax2_1_fused)
+ v1 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused % 8 * 128
+ + ax0_1_ax1_1_fused % 2 * 64
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+ % 32
+ // 2
+ * 4
+ + ax0_ax1_fused_0 * 2
+ + ax0_ax1_fused_1,
+ )
+ T.reads(
+ B_reindex_shared[
+ v0,
+ v1 // 64 * 64
+ + v1 % 8 // 4 * 32
+ + v1 % 64 // 8 * 4
+ + v1 % 4,
+ ]
+ )
+ T.writes(B_reindex_shared_local[v0,
v1])
+ B_reindex_shared_local[v0, v1] =
B_reindex_shared[
+ v0,
+ v1 // 64 * 64
+ + v1 % 8 // 4 * 32
+ + v1 % 64 // 8 * 4
+ + v1 % 4,
+ ]
+ for ax0_3, ax1_3, ax2_2, ax0_4, ax1_4 in T.grid(4,
4, 1, 2, 1):
+ with T.block("Y_update"):
+ v0 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused // 8 * 64
+ + ax0_1_ax1_1_fused // 2 * 32
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+ + ax0_3 * 2
+ + ax0_4,
+ )
+ v1 = T.axis.spatial(
+ 1024,
+ ax1_4
+ + ax0_0_ax1_0_fused % 8 * 128
+ + ax0_1_ax1_1_fused % 2 * 64
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused
+ % 32
+ // 2
+ * 4
+ + ax1_3,
+ )
+ v2 = T.axis.reduce(1024, ax2_0_fused * 4 +
ax2_1_fused + ax2_2)
+ T.reads(
+ Y_reindex_local[v0, v1],
+ A_reindex_shared_local[v2, v0],
+ B_reindex_shared_local[v2, v1],
+ )
+ T.writes(Y_reindex_local[v0, v1])
+ T.block_attr(
+ {
+
"meta_schedule.thread_extent_high_inclusive": 1024,
+
"meta_schedule.thread_extent_low_inclusive": 32,
+ "meta_schedule.tiling_structure":
"SSSRRSRS",
+ }
+ )
+ Y_reindex_local[v0, v1] = (
+ Y_reindex_local[v0, v1]
+ + A_reindex_shared_local[v2, v0]
+ * B_reindex_shared_local[v2, v1]
+ )
+ for ax0, ax1 in T.grid(8, 4):
+ with T.block("Y_reindex_local"):
+ T.where(
+ ax0_0_ax1_0_fused // 8 * 64
+ + ax0_1_ax1_1_fused // 2 * 32
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+ + ax0
+ < 1012
+ and ax0_0_ax1_0_fused % 8 * 128
+ + ax0_1_ax1_1_fused % 2 * 64
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+ + ax1
+ < 1017
+ )
+ v0 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused // 8 * 64
+ + ax0_1_ax1_1_fused // 2 * 32
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused // 32 * 16
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 2 * 8
+ + ax0,
+ )
+ v1 = T.axis.spatial(
+ 1024,
+ ax0_0_ax1_0_fused % 8 * 128
+ + ax0_1_ax1_1_fused % 2 * 64
+ +
ax1_2_0_ax0_2_0_ax0_2_1_ax1_2_1_ax0_2_2_fused % 32 // 2 * 4
+ + ax1,
+ )
+ T.reads(Y_reindex_local[v0, v1])
+ T.writes(Y[v0, v1])
+ Y[v0, v1] = Y_reindex_local[v0, v1]
+
+
+def test_matmul():
+ with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
+ rt_mod = tvm.build(Module, target="cuda")
Review Comment:
Are async copy (`cp.async`) instructions always available — in particular, on previous
generations of CUDA GPUs?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
