junrushao commented on code in PR #13966:
URL: https://github.com/apache/tvm/pull/13966#discussion_r1106063096


##########
tests/python/unittest/test_cp_async_in_if_then_else.py:
##########
@@ -0,0 +1,205 @@
+import tvm
+import numpy as np
+
+from tvm.script import tir as T
+import tvm.testing
+
+expected_cuda_script = r"""
+#ifdef _WIN32
+  using uint = unsigned int;
+  using uchar = unsigned char;
+  using ushort = unsigned short;
+  using int64_t = long long;
+  using uint64_t = unsigned long long;
+#else
+  #define uint unsigned int
+  #define uchar unsigned char
+  #define ushort unsigned short
+  #define int64_t long long
+  #define uint64_t unsigned long long
+#endif
+extern "C" __global__ void __launch_bounds__(16) main_kernel0(float* 
__restrict__ A, float* __restrict__ B, float* __restrict__ C) {
+  __shared__ float A_shared[64];
+  __shared__ float B_shared[64];
+  A_shared[((int)threadIdx.x)] = 0.000000e+00f;
+  B_shared[((int)threadIdx.x)] = 0.000000e+00f;
+__asm__ __volatile__("cp.async.commit_group;");
+
+
+  {
+    unsigned int addr;
+    __asm__ __volatile__(
+      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; 
}\n"
+      : "=r"(addr)
+      : "l"((void *)(A_shared + (((int)threadIdx.x) + 16)))
+    );
+    __asm__ __volatile__(
+      "cp.async.ca.shared.global [%0], [%1], %2;"
+       :: "r"(addr), "l"((void*)(A + (((int)threadIdx.x) * 14))), "n"(4)
+    );
+  }
+
+  {
+    unsigned int addr;
+    __asm__ __volatile__(
+      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; 
}\n"
+      : "=r"(addr)
+      : "l"((void *)(B_shared + (((int)threadIdx.x) + 16)))
+    );
+    __asm__ __volatile__(
+      "cp.async.ca.shared.global [%0], [%1], %2;"
+       :: "r"(addr), "l"((void*)(B + (((int)threadIdx.x) * 14))), "n"(4)
+    );
+  }
+__asm__ __volatile__("cp.async.commit_group;");
+
+
+  {
+    unsigned int addr;
+    __asm__ __volatile__(
+      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; 
}\n"
+      : "=r"(addr)
+      : "l"((void *)(A_shared + (((int)threadIdx.x) + 32)))
+    );
+    __asm__ __volatile__(
+      "cp.async.ca.shared.global [%0], [%1], %2;"
+       :: "r"(addr), "l"((void*)(A + ((((int)threadIdx.x) * 14) + 1))), "n"(4)
+    );
+  }
+
+  {
+    unsigned int addr;
+    __asm__ __volatile__(
+      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; 
}\n"
+      : "=r"(addr)
+      : "l"((void *)(B_shared + (((int)threadIdx.x) + 32)))
+    );
+    __asm__ __volatile__(
+      "cp.async.ca.shared.global [%0], [%1], %2;"
+       :: "r"(addr), "l"((void*)(B + ((((int)threadIdx.x) * 14) + 1))), "n"(4)
+    );
+  }
+__asm__ __volatile__("cp.async.commit_group;");
+
+  for (int i = 0; i < 13; ++i) {
+    bool cse_var_1 = (i < 12);
+
+  {
+    unsigned int addr;
+    __asm__ __volatile__(
+      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; 
}\n"
+      : "=r"(addr)
+      : "l"((void *)(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))))
+    );
+    int src_bytes = cse_var_1 ? 4 : 0;
+    __asm__ __volatile__(
+      "cp.async.ca.shared.global [%0], [%1], %2, %3;"
+       :: "r"(addr), "l"((void*)(A + (((((int)threadIdx.x) * 14) + i) + 2))), 
"n"(4), "r"(src_bytes)
+    );
+  }
+__asm__ __volatile__("cp.async.commit_group;");
+
+__asm__ __volatile__("cp.async.wait_group 5;");
+
+    __syncthreads();
+    C[((((int)threadIdx.x) * 16) + i)] = (A_shared[(((i & 3) * 16) + 
((int)threadIdx.x))] + B_shared[(((i & 3) * 16) + ((int)threadIdx.x))]);
+    __syncthreads();
+
+  {
+    unsigned int addr;
+    __asm__ __volatile__(
+      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; 
}\n"
+      : "=r"(addr)
+      : "l"((void *)(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))))
+    );
+    int src_bytes = cse_var_1 ? 4 : 0;
+    __asm__ __volatile__(
+      "cp.async.ca.shared.global [%0], [%1], %2, %3;"
+       :: "r"(addr), "l"((void*)(B + (((((int)threadIdx.x) * 14) + i) + 2))), 
"n"(4), "r"(src_bytes)
+    );
+  }
+__asm__ __volatile__("cp.async.commit_group;");
+
+  }
+__asm__ __volatile__("cp.async.wait_group 2;");
+
+  __syncthreads();
+  C[((((int)threadIdx.x) * 16) + 13)] = (A_shared[(((int)threadIdx.x) + 16)] + 
B_shared[(((int)threadIdx.x) + 16)]);
+__asm__ __volatile__("cp.async.wait_group 1;");
+
+  __syncthreads();
+  C[((((int)threadIdx.x) * 16) + 14)] = (A_shared[(((int)threadIdx.x) + 32)] + 
B_shared[(((int)threadIdx.x) + 32)]);
+__asm__ __volatile__("cp.async.wait_group 0;");
+
+  __syncthreads();
+  C[((((int)threadIdx.x) * 16) + 15)] = (A_shared[(((int)threadIdx.x) + 48)] + 
B_shared[(((int)threadIdx.x) + 48)]);
+}
+
+"""
+
+
[email protected]_cuda
+def test_cp_async_in_if_then_else():
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+    if major < 8:
+        # At least sm80 is required
+        return

Review Comment:
   Note: we can still assert the generated code even when the compute version is less than 
8, because we are only checking the emitted source, not compiling or running it



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to