szha commented on a change in pull request #12804: CudnnFind() usage improvements
URL: https://github.com/apache/incubator-mxnet/pull/12804#discussion_r227974194
 
 

 ##########
 File path: src/operator/nn/cudnn/cudnn_convolution-inl.h
 ##########
 @@ -611,236 +612,274 @@ class CuDNNConvolutionOp {
     }
   }
 
-  void SelectAlgo(const RunContext& rctx,
+  void CuDNNAlgoSetter(const RunContext& rctx,
                   const std::vector<TShape>& in_shape,
                   const std::vector<TShape>& out_shape,
                   cudnnDataType_t cudnn_forward_compute_type,
-                  cudnnDataType_t cudnn_backward_compute_type) {
-    if (!CuDNNConvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_,
-                                       cudnn_forward_compute_type, 
cudnn_backward_compute_type,
-                                       SMArch(rctx.ctx.dev_id), add_to_weight_,
-                                       &forward_algo_, &back_algo_, 
&back_algo_w_)) {
-      mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
-      CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
-      size_t workspace_byte = static_cast<size_t>(param_.workspace * 
sizeof(DType));
-      #if CUDNN_MAJOR >= 7
-      // Starting with cuDNNv7, the algo number returned by *Get*() is not the 
entire
-      // story: the notion of whether the algo ran in Tensor Core mode is not 
known.
-      // Since we want to report the Tensor Core mode in the verbose output, 
we switch
-      // to using the new *Get*_v7() call.  Since the function signature of 
*Get*_v7() matches
-      // that of *Find*(), we can unify the find-vs-get logic by using 
function pointers.
-
-      // Forward Algorithm Find/Get() v7
-      std::vector<cudnnConvolutionFwdAlgoPerf_t> 
fwd_results(MaxForwardAlgos(s->dnn_handle_));
-      int actual_fwd_algos = 0;
-      auto fwd_algo_discoverer =
-        param_.cudnn_tune.value() == conv::kOff ? 
cudnnGetConvolutionForwardAlgorithm_v7
-                                                : 
cudnnFindConvolutionForwardAlgorithm;
-      CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
-                                        in_desc_,
-                                        filter_desc_,
-                                        forward_conv_desc_,
-                                        out_desc_,
-                                        fwd_results.size(),
-                                        &actual_fwd_algos,
-                                        fwd_results.data()));
-      fwd_results.resize(actual_fwd_algos);
-      AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
-                      cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
-                                                 workspace_byte, 
&forward_algo_);
-
-      // Backprop-to-Filter Algorithm Find/Get() v7
-      auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
-      std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> 
bwd_filt_results(max_bwd_filt_algos);
-      int actual_bwd_filter_algos = 0;
-      // In cudnn v7.1.4, find() returned wgrad algos that could fail for 
large c if we
-      // were summing into the output (i.e. beta != 0).  Get() returned OK 
algos though.
-      auto bwd_filter_algo_discoverer =
-        param_.cudnn_tune.value() == conv::kOff ? 
cudnnGetConvolutionBackwardFilterAlgorithm_v7
-                                                : 
cudnnFindConvolutionBackwardFilterAlgorithm;
-      CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
+                  cudnnDataType_t cudnn_backward_compute_type,
+                  CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
+                  CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
+                  CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
+    // Not in algo registry, must determine via *Get*() or *Find*()
+    mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
+    CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+    size_t workspace_byte = static_cast<size_t>(param_.workspace * 
sizeof(DType));
+#if CUDNN_MAJOR >= 7
+    // Starting with cuDNNv7, the algo number returned by *Get*() is not the 
entire
+    // story: the notion of whether the algo ran in Tensor Core mode is not 
known.
+    // Since we want to report the Tensor Core mode in the verbose output, we 
switch
+    // to using the new *Get*_v7() call.  Since the function signature of 
*Get*_v7() matches
+    // that of *Find*(), we can unify the find-vs-get logic by using function 
pointers.
+
+    // Forward Algorithm Find/Get() v7
+    std::vector<cudnnConvolutionFwdAlgoPerf_t> 
fwd_results(MaxForwardAlgos(s->dnn_handle_));
+    int actual_fwd_algos = 0;
+    auto fwd_algo_discoverer =
+      param_.cudnn_tune.value() == conv::kOff ? 
cudnnGetConvolutionForwardAlgorithm_v7
+                                              : 
cudnnFindConvolutionForwardAlgorithm;
+    CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
+                                      in_desc_,
+                                      filter_desc_,
+                                      forward_conv_desc_,
+                                      out_desc_,
+                                      fwd_results.size(),
+                                      &actual_fwd_algos,
+                                      fwd_results.data()));
+    fwd_results.resize(actual_fwd_algos);
+    AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
+                    cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
+                                               workspace_byte, fwd);
+
+    // Backprop-to-Filter Algorithm Find/Get() v7
+    auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
+    std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> 
bwd_filt_results(max_bwd_filt_algos);
+    int actual_bwd_filter_algos = 0;
+    // In cudnn v7.1.4, find() returned wgrad algos that could fail for large 
c if we
+    // were summing into the output (i.e. beta != 0).  Get() returned OK algos 
though.
+    auto bwd_filter_algo_discoverer =
+      param_.cudnn_tune.value() == conv::kOff ? 
cudnnGetConvolutionBackwardFilterAlgorithm_v7
+                                              : 
cudnnFindConvolutionBackwardFilterAlgorithm;
+    CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
+                                             in_desc_,
+                                             out_desc_,
+                                             back_conv_desc_w_,
+                                             filter_desc_,
+                                             bwd_filt_results.size(),
+                                             &actual_bwd_filter_algos,
+                                             bwd_filt_results.data()));
+    bwd_filt_results.resize(actual_bwd_filter_algos);
+    AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
+                    cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, 
"backprop-to-filter",
+                                                     workspace_byte, flt);
+
+    // Backprop-to-Data Algorithm Find/Get() v7
+    auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
+    std::vector<cudnnConvolutionBwdDataAlgoPerf_t> 
bwd_data_results(max_bwd_data_algos);
+    int actual_bwd_data_algos = 0;
+    auto bwd_data_algo_discoverer =
+      param_.cudnn_tune.value() == conv::kOff ? 
cudnnGetConvolutionBackwardDataAlgorithm_v7
+                                              : 
cudnnFindConvolutionBackwardDataAlgorithm;
+    CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
+                                           filter_desc_,
+                                           out_desc_,
+                                           back_conv_desc_,
+                                           in_desc_,
+                                           bwd_data_results.size(),
+                                           &actual_bwd_data_algos,
+                                           bwd_data_results.data()));
+    bwd_data_results.resize(actual_bwd_data_algos);
+    AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
+                    cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, 
"backprop-to-data",
+                                                   workspace_byte, bwd);
+#else
+    // CUDNN_MAJOR < 7
+    const int kMaxAlgos = 10;
+    int nalgo = kMaxAlgos;
+    int i = 0;
+    size_t min_memory_needs = 0;
+    // Forward Algorithm Find/Get, v6 and earlier
+    if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
+      // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM 
is
+      // supported.  Hard-coded this since the algo find() or get() throws an 
FPE.
+      fwd->Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
+    } else if (!param_.cudnn_tune.value()) {
+      cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
+      CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
                                                in_desc_,
-                                               out_desc_,
-                                               back_conv_desc_w_,
                                                filter_desc_,
-                                               bwd_filt_results.size(),
-                                               &actual_bwd_filter_algos,
-                                               bwd_filt_results.data()));
-      bwd_filt_results.resize(actual_bwd_filter_algos);
-      AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
-                      cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, 
"backprop-to-filter",
-                                   workspace_byte, &back_algo_w_);
-
-      // Backprop-to-Data Algorithm Find/Get() v7
-      auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
-      std::vector<cudnnConvolutionBwdDataAlgoPerf_t> 
bwd_data_results(max_bwd_data_algos);
-      int actual_bwd_data_algos = 0;
-      auto bwd_data_algo_discoverer =
-        param_.cudnn_tune.value() == conv::kOff ? 
cudnnGetConvolutionBackwardDataAlgorithm_v7
-                                                : 
cudnnFindConvolutionBackwardDataAlgorithm;
-      CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
-                                             filter_desc_,
-                                             out_desc_,
-                                             back_conv_desc_,
-                                             in_desc_,
-                                             bwd_data_results.size(),
-                                             &actual_bwd_data_algos,
-                                             bwd_data_results.data()));
-      bwd_data_results.resize(actual_bwd_data_algos);
-      AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
-                      cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, 
"backprop-to-data",
-                                    workspace_byte, &back_algo_);
-      #else
-      // CUDNN_MAJOR < 7
-      const int kMaxAlgos = 10;
-      int nalgo = kMaxAlgos;
-      int i = 0;
-      size_t min_memory_needs = 0;
-      // Forward Algorithm Find/Get, v6 and earlier
-      if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
-        // In cuDNNv6, for kNHWC, only 
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
-        // supported.  Hard-coded this since the algo find() or get() throws 
an FPE.
-        forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
-      } else if (!param_.cudnn_tune.value()) {
-        cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
-        CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
-                                                 in_desc_,
-                                                 filter_desc_,
-                                                 forward_conv_desc_,
-                                                 out_desc_,
-                                                 
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
-                                                 workspace_byte,
-                                                 &fastest_fwd_algo));
-        forward_algo_.Set(fastest_fwd_algo, false);
-      } else {
-        cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
-        CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
-                                                        in_desc_,
-                                                        filter_desc_,
-                                                        forward_conv_desc_,
-                                                        out_desc_,
-                                                        kMaxAlgos,
-                                                        &nalgo,
-                                                        fwd_algo));
-        i = 0;
-        while (i < nalgo
-               && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
-                   || (param_.cudnn_tune.value() == conv::kLimited
-                       && fwd_algo[i].memory > workspace_byte))) {
-          ++i;
-          min_memory_needs =
-            (i == 0) ? fwd_algo[i].memory : std::min(min_memory_needs, 
fwd_algo[i].memory);
-        }
-        if (i == nalgo) {
-          LOG(FATAL) << nalgo << " forward algorithms with minimum memory 
requirement "
-                     << min_memory_needs << " bytes have been tried. Workspace 
size is set to "
-                     << workspace_byte << " bytes, please consider reducing 
the batch/model size, "
-                     << "or increasing workspace size.";
-        } else {
-          forward_algo_.Set(fwd_algo[i].algo, false);
-        }
+                                               forward_conv_desc_,
+                                               out_desc_,
+                                               
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+                                               workspace_byte,
+                                               &fastest_fwd_algo));
+      fwd->Set(fastest_fwd_algo, false);
+    } else {
+      cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
+      CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
+                                                      in_desc_,
+                                                      filter_desc_,
+                                                      forward_conv_desc_,
+                                                      out_desc_,
+                                                      kMaxAlgos,
+                                                      &nalgo,
+                                                      fwd_algo));
+      i = 0;
+      while (i < nalgo
+             && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
+                 || (param_.cudnn_tune.value() == conv::kLimited
+                     && fwd_algo[i].memory > workspace_byte))) {
+        ++i;
+        min_memory_needs =
+          (i == 0) ? fwd_algo[i].memory : std::min(min_memory_needs, 
fwd_algo[i].memory);
       }
-      // Backprop-to-Filter Algorithm Find/Get, v6 and earlier
-      if (!param_.cudnn_tune.value()) {
-        cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo;
-        CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
-                                          in_desc_,
-                                          out_desc_,
-                                          back_conv_desc_w_,
-                                          filter_desc_,
-                                          
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
-                                          workspace_byte,
-                                          &fastest_bwd_filt_algo));
-        back_algo_w_.Set(fastest_bwd_filt_algo, false);
+      if (i == nalgo) {
+        LOG(FATAL) << nalgo << " forward algorithms with minimum memory 
requirement "
+                   << min_memory_needs << " bytes have been tried. Workspace 
size is set to "
+                   << workspace_byte << " bytes, please consider reducing the 
batch/model size, "
+                   << "or increasing workspace size.";
 
 Review comment:
  There are several messages that look alike. Consider wrapping them in a macro.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to