This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new f31df01  fix lower_warp_memory (#5247)
f31df01 is described below

commit f31df01e470d41a09221921059327360c43c2472
Author: Tang, Shizhi <[email protected]>
AuthorDate: Mon Apr 6 23:43:38 2020 +0800

    fix lower_warp_memory (#5247)
---
 src/tir/transforms/lower_warp_memory.cc            |  6 +--
 .../test_tir_transform_lower_warp_memory.py        | 51 +++++++++++++++++++++-
 2 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 0361100..1921db5 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -219,13 +219,13 @@ class WarpAccessRewriter : protected StmtExprMutator {
   }
 
  protected:
-  PrimExpr Mutate_(const VarNode* op) {
+  PrimExpr VisitExpr_(const VarNode* op) override {
     CHECK(op != buffer_)
         << "Cannot access address of warp memory directly";
     return StmtExprMutator::VisitExpr_(op);
   }
 
-  Stmt VisitStmt_(const StoreNode* op) {
+  Stmt VisitStmt_(const StoreNode* op) override {
     if (op->buffer_var.get() == buffer_) {
       PrimExpr local_index, group;
       std::tie(local_index, group) = SplitIndexByGroup(op->index);
@@ -235,7 +235,7 @@ class WarpAccessRewriter : protected StmtExprMutator {
     }
   }
 
-  PrimExpr Mutate_(const LoadNode* op) {
+  PrimExpr VisitExpr_(const LoadNode* op) override {
     if (op->buffer_var.get() == buffer_) {
       PrimExpr local_index, group;
       std::tie(local_index, group) = SplitIndexByGroup(op->index);
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index cf6ef72..25204eb 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -16,8 +16,11 @@
 # under the License.
 import tvm
 from tvm import te
+from tvm.contrib.nvcc import have_fp16
 
-def test_lower_warp_mem():
+import numpy as np
+
+def test_lower_warp_memory_local_scope():
     m = 128
     A = te.placeholder((m,), name='A')
     B = te.compute((m,), lambda i: A[i] + 3, name='B')
@@ -44,6 +47,50 @@ def test_lower_warp_mem():
     assert(fdevice.body.body.value.value == "local")
     assert(fdevice.body.body.body.extents[0].value == 2)
 
+def test_lower_warp_memory_cuda_end_to_end():
+    def check_cuda(dtype):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        m = 128
+        A = te.placeholder((m,), name='A', dtype=dtype)
+        B = te.compute((m,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name='B')
+
+        cuda_target = tvm.target.create("cuda")
+        assert cuda_target.thread_warp_size == 32
+        with cuda_target:
+            s = te.create_schedule(B.op)
+            AA = s.cache_read(A, "warp", [B])
+            xo, xi = s[B].split(B.op.axis[0], 64)
+            xi0, xi1 = s[B].split(xi, factor=32)
+            tx = te.thread_axis("threadIdx.x")
+            s[B].bind(xi1, tx)
+            s[B].bind(xo, te.thread_axis("blockIdx.x"))
+            s[AA].compute_at(s[B], xo)
+            xo, xi = s[AA].split(s[AA].op.axis[0], 32)
+            s[AA].bind(xi, tx)
+
+            ctx = tvm.gpu(0)
+            func = tvm.build(s, [A, B], "cuda")
+            A_np = np.array(list(range(m)), dtype=dtype)
+            B_np = np.array(
+                    list(range(1, 32)) + [0] +
+                    list(range(33, 64)) + [32] +
+                    list(range(65, 96)) + [64] +
+                    list(range(97, 128)) + [96],
+                    dtype=dtype)
+            A_nd = tvm.nd.array(A_np, ctx)
+            B_nd = tvm.nd.array(np.zeros(B_np.shape, dtype=B_np.dtype), ctx)
+            func(A_nd, B_nd)
+            tvm.testing.assert_allclose(B_nd.asnumpy(), B_np, rtol=1e-3)
+
+    check_cuda("float32")
+    check_cuda("float16")
 
 if __name__ == "__main__":
-    test_lower_warp_mem()
+    test_lower_warp_memory_local_scope()
+    test_lower_warp_memory_cuda_end_to_end()

Reply via email to