This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new d50ba72  [CodeGen][CUDA] Fix issues in cuda codegen (#4876)
d50ba72 is described below

commit d50ba721eb5f7c0dbeceeaa78335d6f4c8cf2973
Author: wpan11nv <[email protected]>
AuthorDate: Sat Feb 15 19:47:36 2020 -0800

    [CodeGen][CUDA] Fix issues in cuda codegen (#4876)
    
    - Do not emit __shared__ etc. as part of type for casting
    
    - Fix fp16 reduction kernels with compiler errors:
    
      "no operator "+" matches these operands, volatile half + volatile half"
    
      This patch inserts casts to remove volatile type qualifier following
      volatile loads (fp16 only). CUDA fp16 library headers should add
      volatile member functions.
    
    - Update have_fp16 to include compute 6.1 GPUs, which do support fp16,
      although their fp16 throughput is low. Updated tests.
    
    Signed-off-by: Wei Pan <[email protected]>
---
 python/tvm/contrib/nvcc.py                 |  6 +----
 src/target/source/codegen_c.cc             | 13 +++++-----
 src/target/source/codegen_c.h              | 34 ++++++++++++++++++++++++-
 src/target/source/codegen_cuda.cc          | 28 ++++++++++----------
 src/target/source/codegen_cuda.h           |  9 +++++++
 tests/python/unittest/test_codegen_cuda.py | 41 +++++++++++++++++++++++++-----
 topi/tests/python/test_topi_relu.py        | 14 +++-------
 topi/tests/python/test_topi_tensor.py      | 14 +++-------
 8 files changed, 105 insertions(+), 54 deletions(-)

diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index c50a9ce..8712f73 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -232,11 +232,7 @@ def have_fp16(compute_version):
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#arithmetic-instructions
     if major == 5 and minor == 3:
         return True
-    # NOTE: exclude compute capability 6.1 devices although it is actually available
-    #       to compute fp16, because these devices only have low-rate fp16 performance.
-    if major == 6 and minor != 1:
-        return True
-    if major == 7:
+    if major >= 6:
         return True
 
     return False
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index b871b26..7f89307 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -153,14 +153,15 @@ std::string CodeGenC::GetBufferRef(
   if (alloc_storage_scope_.count(buffer)) {
     scope = alloc_storage_scope_.at(buffer);
   }
-  bool is_vol = volatile_buf_.count(buffer) != 0;
+  bool is_vol = IsVolatile(buffer);
   if (t.lanes() == 1) {
     if (!HandleTypeMatch(buffer, t) || is_vol) {
       os << "((";
       if (is_vol) {
         os << "volatile ";
       }
-      if (scope.length() != 0) {
+      // Scope may not be part of type.
+      if (!scope.empty() && IsScopePartOfType()) {
         PrintStorageScope(scope, os);
       }
       os << ' ';
@@ -189,7 +190,7 @@ std::string CodeGenC::GetBufferRef(
     if (is_vol) {
       os << "volatile ";
     }
-    if (scope.length() != 0) {
+    if (!scope.empty() && IsScopePartOfType()) {
       PrintStorageScope(scope, os);
     }
     os << ' ';
@@ -197,7 +198,7 @@ std::string CodeGenC::GetBufferRef(
     os << "*)(";
     if (!HandleTypeMatch(buffer, t.element_of())) {
       os << '(';
-      if (scope.length() != 0) {
+      if (!scope.empty() && IsScopePartOfType()) {
         PrintStorageScope(scope, os);
       }
       os << ' ';
@@ -620,14 +621,14 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) {  // NOLINT(*)
   // delcare type.
   if (op->dtype.lanes() == 1) {
     std::string ref = GetBufferRef(op->dtype, op->buffer_var.get(), op->index);
-    os << ref;
+    HandleVolatileLoads(ref, op, os);
   } else {
     CHECK(is_one(op->predicate))
         << "predicated load is not supported";
     PrimExpr base;
     if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) {
       std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);
-      os << ref;
+      HandleVolatileLoads(ref, op, os);
     } else {
       // The assignment below introduces side-effect, and the resulting value cannot
       // be reused across multiple expression, thus a new scope is needed
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 00ed912..c6da1c4 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -178,9 +178,36 @@ class CodeGenC :
   // Print reference to struct location
   std::string GetStructRef(
       DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind);
-  // print reference to a buffer as type t in index.
+  // Print reference to a buffer as type t in index.
   virtual std::string GetBufferRef(
       DataType t, const VarNode* buffer, PrimExpr index);
+
+  /*!
+   * \brief Handle volatile loads.
+   *
+   * This is to workaround a bug in CUDA cuda_fp16.h. Volatile accesses
+   * to shared memory are required for reductions. However, __half class
+   * does not implement volatile member functions. CUDA codegen will cast
+   * away volatile qualifier from CUDA __half types.
+   */
+  virtual void HandleVolatileLoads(const std::string& value, const LoadNode* 
op,
+                                   std::ostream& os) {
+    // By default, do nothing but print the loaded value.
+    os << value;
+  }
+
+  /*!
+   * \brief Check if scope is part of type in the target language.
+   *
+   * **NOTE** In OpenCL, __local is part of type, so "__local int *"
+   * is legal. This is not the case for CUDA, where "__shared__"
+   * or "__constant__" is not part of type but a storage class (like
+   * C/C++ static).
+   */
+  virtual bool IsScopePartOfType() const {
+    return true;
+  }
+
   /*!
    * \brief If buffer is allocated as type t.
    * \param buf_var The buffer variable.
@@ -205,6 +232,11 @@ class CodeGenC :
   /*! \brief reserves common C keywords */
   void ReserveKeywordsAsUnique();
 
+  /*! \brief Check if buf_var is volatile or not. */
+  bool IsVolatile(const VarNode *buf_var) const {
+    return volatile_buf_.count(buf_var) != 0;
+  }
+
  private:
   /*! \brief whether to print in SSA form */
   bool print_ssa_form_{false};
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 0b2c54e..889d8b6 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -57,20 +57,6 @@ std::string CodeGenCUDA::Finish() {
                 << "{\n  return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
     decl_stream << "__device__ half min(half a, half b)\n"
                 << "{\n  return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
-    // FIXME(tvm-team): "volatile" is used to enable cross thread reduction,
-    // which is needed by operations such as softmax.
-    // However, volatile overloading is not supported in NVRTC and CUDA < 9.2.
-    // We need to figure out a solution which can satisfy both scenario.
-    // decl_stream << "__device__ half operator<="
-    //             << "(const volatile __half &a,  const volatile __half &b)\n"
-    //             << "{\n  return __hlt(a, b);\n}\n";
-    // decl_stream << "__device__ half operator+"
-    //             << "(const volatile __half &a,  const volatile __half &b)\n"
-    //             <<"{\n  return __hadd(a, b);\n}\n";
-    // decl_stream << "__device__ half operator*"
-    //             << "(const volatile __half &a, const volatile __half &b)\n"
-    //             <<   "{\n  return __hmul(a, b);\n}\n";
-    // otherwise simulate computation via float32
     decl_stream << "#else\n";
     decl_stream << _cuda_half_t_def;
     decl_stream << "#endif\n\n";
@@ -605,5 +591,19 @@ int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope,
   return 0;
 }
 
+void CodeGenCUDA::HandleVolatileLoads(const std::string& value,
+                                      const LoadNode* op, std::ostream& os) {
+  // Cast away volatile qualifier for fp16 types. That is, only loads and
+  // stores are volatile. The loaded objects are not marked as volatile.
+  //
+  if (op->dtype.is_float16() && IsVolatile(op->buffer_var.get())) {
+    os << "(";
+    PrintType(op->dtype, os);
+    os << ")(" << value << ")";
+  } else {
+    os << value;
+  }
+}
+
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index eca6871..d0a98a6 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -66,6 +66,15 @@ class CodeGenCUDA final : public CodeGenC {
   void VisitStmt_(const AttrStmtNode *op) final;
 
  private:
+  // Handle volatile loads
+  void HandleVolatileLoads(const std::string& value, const LoadNode* op,
+                           std::ostream& os) final;
+
+  // Whether scope such as "__shared__" or "__constant__"  is part of type.
+  bool IsScopePartOfType() const final {
+    return false;
+  }
+
   // Whether global barrier is needed.
   bool need_global_barrier_{false};
   // Global barrier state
diff --git a/tests/python/unittest/test_codegen_cuda.py b/tests/python/unittest/test_codegen_cuda.py
index 79b3544..ec36a5f 100644
--- a/tests/python/unittest/test_codegen_cuda.py
+++ b/tests/python/unittest/test_codegen_cuda.py
@@ -17,8 +17,9 @@
 # under the License.
 import tvm
 import numpy as np
+import topi
 import unittest
-from tvm.contrib.nvcc import parse_compute_version, have_int8
+from tvm.contrib.nvcc import have_fp16, have_int8
 from tvm.contrib import nvcc
 
 tx = tvm.thread_axis("threadIdx.x")
@@ -30,11 +31,8 @@ def test_cuda_vectorize_add():
         if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
             print("skip because cuda is not enabled..")
             return
-        if dtype == "float16":
-            major, minor = parse_compute_version(tvm.gpu(0).compute_version)
-            # fp16 starts from 5.3
-            if major < 6 or (major == 5 and minor < 3):
-                print("skip because gpu does not support fp16")
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
             return
         if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
             print("skip because gpu does not support int8")
@@ -291,6 +289,36 @@ def test_cuda_const_float_to_half():
     func(a, c)
     np.testing.assert_equal(c.asnumpy(), a_np > b.value)
 
+def test_cuda_reduction():
+    def check_cuda(dtype, m=32, n=32):
+        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        a = tvm.placeholder((m, n), name="a", dtype=dtype)
+        b = tvm.placeholder((m, n), name="b", dtype=dtype)
+        c = a + b
+        d = a * b
+        e = topi.elemwise_sum([c, d])
+        g = topi.sum(e)
+        with tvm.target.cuda():
+            sg = topi.generic.schedule_reduce(g)
+            ctx = tvm.gpu(0)
+            func = tvm.build(sg, [a, b, g], 'cuda')
+            a_np = np.random.uniform(size=(m, n)).astype(a.dtype)
+            b_np = np.random.uniform(size=(m, n)).astype(b.dtype)
+            g_np = np.sum(np.add(a_np * b_np, a_np + b_np))
+            a_nd = tvm.nd.array(a_np, ctx)
+            b_nd = tvm.nd.array(b_np, ctx)
+            g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), ctx)
+            func(a_nd, b_nd, g_nd)
+            tvm.testing.assert_allclose(g_nd.asnumpy(), g_np, rtol=1e-3)
+
+    check_cuda("float32")
+    check_cuda("float16")
 
 if __name__ == "__main__":
     test_cuda_vectorize_add()
@@ -302,3 +330,4 @@ if __name__ == "__main__":
     test_cuda_reducition_binding()
     test_rfactor_predicates()
     test_cuda_const_float_to_half()
+    test_cuda_reduction()
\ No newline at end of file
diff --git a/topi/tests/python/test_topi_relu.py b/topi/tests/python/test_topi_relu.py
index 414edbc..8868d4e 100644
--- a/topi/tests/python/test_topi_relu.py
+++ b/topi/tests/python/test_topi_relu.py
@@ -20,18 +20,9 @@ import numpy as np
 import tvm
 import topi
 from topi.util import get_const_tuple
-from tvm.contrib.nvcc import parse_compute_version
+from tvm.contrib.nvcc import have_fp16
 from common import get_all_backend
 
-def skip_test(dtype, device):
-    if dtype == "float16" and device == "cuda":
-        major, minor = parse_compute_version(tvm.gpu(0).compute_version)
-        # fp16 starts from 5.3
-        if major < 6 or (major == 5 and minor < 3):
-            print("skip because gpu does not support fp16")
-            return True
-    return False
-
 def verify_relu(m, n, dtype="float32"):
     A = tvm.placeholder((m, n), name='A', dtype=dtype)
     B = topi.nn.relu(A)
@@ -44,7 +35,8 @@ def verify_relu(m, n, dtype="float32"):
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
-        if skip_test(dtype, device):
+        if dtype == "float16" and device == "cuda" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because %s does not have fp16 support" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
diff --git a/topi/tests/python/test_topi_tensor.py b/topi/tests/python/test_topi_tensor.py
index 84718ff..8e7073f 100644
--- a/topi/tests/python/test_topi_tensor.py
+++ b/topi/tests/python/test_topi_tensor.py
@@ -19,16 +19,7 @@ import numpy as np
 import tvm
 import topi
 from tvm.contrib.pickle_memoize import memoize
-from tvm.contrib.nvcc import parse_compute_version
-
-def skip_test(dtype, device):
-    if dtype == "float16" and device == "cuda":
-        major, minor = parse_compute_version(tvm.gpu(0).compute_version)
-        # fp16 starts from 5.3
-        if major < 6 or (major == 5 and minor < 3):
-            print("skip because gpu does not support fp16")
-            return True
-    return False
+from tvm.contrib.nvcc import have_fp16
 
 def verify_elemwise_sum(num_args, dtype):
     shape = (3,5,4)
@@ -99,7 +90,8 @@ def verify_vectorization(n, m, dtype):
         if not tvm.runtime.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        if skip_test(dtype, device):
+        if dtype == "float16" and device == "cuda" and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
             return
         with tvm.target.create(device):
             ctx = tvm.context(device, 0)

Reply via email to