This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 2ccbcec  GPU gemms true fp16 (#17466) (#18023)
2ccbcec is described below

commit 2ccbcecc3cc4faa26af2a9e583122980c6f9a6bb
Author: MoisesHer <[email protected]>
AuthorDate: Wed Apr 15 14:20:31 2020 -0700

    GPU gemms true fp16 (#17466) (#18023)
    
    * Temporary solution for fp16 accumulation in BERT gemms
    
    * Resolve alpha/beta type issue
    
    * add documentation for env variable MXNET_FC_TRUE_FP16
    
    * Improve description of env variable
    
    * Add unit test checking the environment variable
    
    * keep pseudo-fp16 if architecture does not support Float16Compute
    
    * Fix cpplint
---
 docs/static_site/src/pages/api/faq/env_var.md |  4 ++
 src/operator/contrib/transformer.cu           | 30 +++++++++++++--
 src/operator/linalg_impl.h                    | 53 ++++++++++++++++++++++-----
 tests/python/gpu/test_gluon_gpu.py            | 21 +++++++++++
 4 files changed, 95 insertions(+), 13 deletions(-)

diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md
index b91d476..e0b70a6 100644
--- a/docs/static_site/src/pages/api/faq/env_var.md
+++ b/docs/static_site/src/pages/api/faq/env_var.md
@@ -358,6 +358,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
   - Values: 0(false) or 1(true) ```(default=1)```
  - This variable controls whether to use the MKL-DNN backend in fused RNN operator for CPU context. There are two fusion implementations of RNN operator in MXNet. The MKL-DNN implementation has a better performance than the naive one, but the latter is more stable in the backward operation currently.
 
+* MXNET_FC_TRUE_FP16
+  - Values: 0(false) or 1(true) ```(default=0)```
+  - If this variable is set to true, MXNet will perform fp16 accumulation when using cuBLAS and the input datatype is float16. This can increase the speed of the computation, but might result in a loss of accuracy, which makes this setting useful mainly for inference use cases.
+
 Settings for Minimum Memory Usage
 ---------------------------------
 - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
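
A minimal usage sketch for the new variable, assuming a CUDA-enabled MXNet build and following the pattern of the unit test added at the end of this patch (the layer sizes are illustrative only):

    import os
    import mxnet as mx
    from mxnet.gluon import nn

    os.environ["MXNET_FC_TRUE_FP16"] = "1"   # opt in to fp16 accumulation in cuBLAS GEMMs

    ctx = mx.gpu(0)
    net = nn.Dense(128, in_units=512, use_bias=False)
    net.cast('float16')                      # fp16 inputs/weights are required for this path
    net.initialize(ctx=ctx)

    x = mx.nd.random.uniform(shape=(1, 512), dtype='float16', ctx=ctx)
    y = net(x)                               # the GEMM now accumulates in fp16 instead of fp32
    print(y.dtype)                           # float16 in both modes; only the accumulation differs

The variable is read each time the GEMM runs, which is why the test below can flip it back to "0" afterwards to restore the default pseudo-fp16 (fp32 accumulation) behavior.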
diff --git a/src/operator/contrib/transformer.cu b/src/operator/contrib/transformer.cu
index 59029ea..44c8ebd 100644
--- a/src/operator/contrib/transformer.cu
+++ b/src/operator/contrib/transformer.cu
@@ -50,6 +50,28 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
       << "Must init CuBLAS handle in stream";
 
   cublasHandle_t blas_handle = mshadow::Stream<gpu>::GetBlasHandle(s);
+  auto err = CUBLAS_STATUS_SUCCESS;
+  using TrueFP16Type = DType;
+  using PseudoFP16Type = typename CublasType<DType>::ScaleType;
+  // Set up alpha and beta values in the possible formats needed (only different when dtype == half)
+  TrueFP16Type trueFP16_alpha = static_cast<TrueFP16Type>(alpha);
+  TrueFP16Type trueFP16_beta = static_cast<TrueFP16Type>(beta);
+  PseudoFP16Type pseudoFP16_alpha = static_cast<PseudoFP16Type>(alpha);
+  PseudoFP16Type pseudoFP16_beta = static_cast<PseudoFP16Type>(beta);
+  const void *alpha_ptr;
+  const void *beta_ptr;
+  cudaDataType_t computeType;
+  bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false);
+  if (use_true_fp16) {
+    alpha_ptr = &trueFP16_alpha;
+    beta_ptr = &trueFP16_beta;
+    computeType = CublasType<TrueFP16Type>::kCudaFlag;
+  } else {
+    alpha_ptr = &pseudoFP16_alpha;
+    beta_ptr = &pseudoFP16_beta;
+    computeType = CublasType<PseudoFP16Type>::kCudaFlag;
+  }
+
   // cublasGemmStridedBatchedEx is only supported for GPU with architecture
   // capabilities equal or greater than 5.0. Fall back to
   // cublasSgemmStridedBatched, which doesn't support implicit conversion
@@ -59,12 +81,12 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
     CUBLAS_CALL(cublasGemmStridedBatchedEx(
         blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
         static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
-        reinterpret_cast<void*>(&alpha),
+        alpha_ptr,
         a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
         b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
-        reinterpret_cast<void*>(&beta),
+        beta_ptr,
         c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
-        static_cast<int>(batchCount), CUDA_R_32F, algo));
+        static_cast<int>(batchCount), computeType, algo));
   } else {
     if (std::is_same<DType, float>::value) {
       CUBLAS_CALL(cublasSgemmStridedBatched(
@@ -124,7 +146,7 @@ void gemm_switch_fp32accum(mshadow::Stream<gpu>* s, bool transA, bool transB,
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
   if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
     CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
-      strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
+      strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
   } else {
     CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
       strideB, beta, c, ldc, strideC, batchCount);
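
The pointer/compute-type selection above is what flips cublasGemmStridedBatchedEx between fp32 accumulation (pseudo-fp16) and true fp16 accumulation. A rough NumPy model of that choice, assuming nothing beyond NumPy (batched_gemm_fp16_io is a made-up name for illustration, not an MXNet API):

    import os
    import numpy as np

    def batched_gemm_fp16_io(a, b):
        # a: (batch, m, k) float16, b: (batch, k, n) float16.
        # Mirror the env-var gate above: accumulate in fp16 or in fp32.
        acc_dtype = np.float16 if os.environ.get("MXNET_FC_TRUE_FP16", "0") == "1" else np.float32
        c = np.zeros((a.shape[0], a.shape[1], b.shape[2]), dtype=acc_dtype)
        for k in range(a.shape[2]):
            # Round every partial product and partial sum to the chosen accumulation type,
            # which is roughly what the GPU does with the selected computeType.
            c = c + (a[:, :, k:k+1].astype(acc_dtype) @ b[:, k:k+1, :].astype(acc_dtype))
        return c.astype(np.float16)   # I/O stays float16 in both modes; only the accumulation differs

With MXNET_FC_TRUE_FP16=1 every partial sum is rounded to fp16, which is where both the speed-up and the possible accuracy loss documented in env_var.md come from.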
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index d83eb0d..fd6800d 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -249,6 +249,7 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
                                              mshadow::half::half_t beta,
                                              bool tA, bool tB, Stream<gpu> *s) {
   using namespace mxnet;
+  using namespace mxnet::common::cuda;
   using mshadow::gpu;
   CHECK_NOTNULL(s);
   check_gemm(A, B, C, alpha, beta, tA, tB);
@@ -261,25 +262,59 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
   auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode);
 #endif
 
-  // pseudo-fp16 (fp32 math with fp16 I/O)
-  float alpha_f = float(alpha);  // NOLINT(*)
-  float beta_f = float(beta);  // NOLINT(*)
-
-  // As of cuda8, cublas adopted the cuda datatype, rather than maintaining its own datatype.
+// As of cuda8, cublas adopted the cuda datatype, rather than maintaining its own datatype.
 #if CUDA_VERSION >= 8000
   cudaDataType_t half_datatype = CUDA_R_16F;
 #else
   cublasDataType_t half_datatype = CUBLAS_DATA_HALF;
 #endif
-  CUBLAS_CALL(cublasSgemmEx(blas_handle,
+  auto algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+  using TrueFP16Type = mshadow::half::half_t;
+  using PseudoFP16Type = typename CublasType<mshadow::half::half_t>::ScaleType;
+  TrueFP16Type trueFP16_alpha = static_cast<TrueFP16Type>(alpha);
+  TrueFP16Type trueFP16_beta = static_cast<TrueFP16Type>(beta);
+  PseudoFP16Type pseudoFP16_alpha = static_cast<PseudoFP16Type>(alpha);
+  PseudoFP16Type pseudoFP16_beta = static_cast<PseudoFP16Type>(beta);
+  const void *alpha_ptr;
+  const void *beta_ptr;
+  cudaDataType_t computeType;
+  bool use_true_fp16 = dmlc::GetEnv("MXNET_FC_TRUE_FP16", false);
+  if (use_true_fp16) {
+    alpha_ptr = &trueFP16_alpha;
+    beta_ptr = &trueFP16_beta;
+    computeType = CublasType<TrueFP16Type>::kCudaFlag;
+  } else {
+    alpha_ptr = &pseudoFP16_alpha;
+    beta_ptr = &pseudoFP16_beta;
+    computeType = CublasType<PseudoFP16Type>::kCudaFlag;
+  }
+  if (SupportsFloat16Compute(s->dev_id)) {
+    CUBLAS_CALL(cublasGemmEx(blas_handle,
                             (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
                             (tA ? CUBLAS_OP_T : CUBLAS_OP_N),
                             C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
-                            &alpha_f,
+                            alpha_ptr,
                             B.dptr_, half_datatype, B.stride_,
                             A.dptr_, half_datatype, A.stride_,
-                            &beta_f,
-                            C.dptr_, half_datatype, C.stride_));
+                            beta_ptr,
+                            C.dptr_, half_datatype, C.stride_,
+                            computeType, algo));
+  } else {
+    // pseudo-fp16 (fp32 math with fp16 I/O)
+    if (use_true_fp16)
+      common::LogOnce("MXNET_FC_TRUE_FP16 was set but this architecture does not support it.");
+    float alpha_f = static_cast<float>(alpha);
+    float beta_f = static_cast<float>(beta);
+    CUBLAS_CALL(cublasSgemmEx(blas_handle,
+                             (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
+                             (tA ? CUBLAS_OP_T : CUBLAS_OP_N),
+                             C.size(1), C.size(0), (tB ? B.size(1) : B.size(0)),
+                             &alpha_f,
+                             B.dptr_, half_datatype, B.stride_,
+                             A.dptr_, half_datatype, A.stride_,
+                             &beta_f,
+                             C.dptr_, half_datatype, C.stride_));
+  }
 #if CUDA_VERSION >= 9000
   SetCublasMathMode(blas_handle, previous_math_mode);
 #endif
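
The SupportsFloat16Compute() fallback keeps the numerically safer pseudo-fp16 path, and the gap between the two accumulation modes is easy to see: with a 512-long reduction, as in the unit test below, fp16 partial sums already drift measurably from the fp32 result. A self-contained NumPy comparison with illustrative data:

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.standard_normal((1, 512)).astype(np.float16)
    w = rng.standard_normal((512, 128)).astype(np.float16)

    # Pseudo-fp16: fp32 accumulation with fp16 inputs/outputs (the default path).
    ref = (a.astype(np.float32) @ w.astype(np.float32)).astype(np.float16)

    # True fp16: every partial sum is rounded back to fp16.
    acc = np.zeros((1, 128), dtype=np.float16)
    for k in range(a.shape[1]):
        acc = acc + np.outer(a[:, k], w[k, :]).astype(np.float16)

    diff = np.abs(acc.astype(np.float32) - ref.astype(np.float32))
    print(diff.max())   # visible gap; this is why the new test uses atol=rtol=1e-2

On real hardware the size of the gap also depends on how cuBLAS orders the reduction, so the loose tolerances in the test are deliberate rather than a tight guarantee.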
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index aa56eee..42a2424 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -615,6 +615,27 @@ def test_symbol_block_symbolic_bn_fp16_cast():
         y1 = net(x)
         assert np.dtype(y1.dtype).name == 'float16'
 
+@with_seed()
+def test_gemms_true_fp16():
+    ctx = mx.gpu(0)
+    input = mx.nd.random.uniform(shape=(1, 512), dtype='float16', ctx=ctx)
+    weights = mx.nd.random.uniform(shape=(128, 512), ctx=ctx)
+
+    net = nn.Dense(128, in_units=512, use_bias=False)
+    net.cast('float16')
+    net.initialize(ctx=ctx)
+    net.weight.set_data(weights)
+    ref_results = net(input)
+
+    os.environ["MXNET_FC_TRUE_FP16"] = "1"
+    results_trueFP16 = net(input)
+    atol = 1e-2
+    rtol = 1e-2
+    assert_almost_equal(ref_results.asnumpy(), results_trueFP16.asnumpy(),
+                        atol=atol, rtol=rtol)
+    os.environ["MXNET_FC_TRUE_FP16"] = "0"
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
