KellenSunderland closed pull request #13398: Use dynamic omp schedule for sparse dot with large matrix
URL: https://github.com/apache/incubator-mxnet/pull/13398
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index e77569671eb..11fd1e63c36 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -524,6 +524,37 @@ struct Kernel<OP, cpu> {
     return true;
   }
 
+  /*!
+   * \brief Launch a generic CPU kernel with dynamic schedule. This is recommended
+   * for irregular workloads such as spmv.
+   * When using this for a new kernel op, add declaration and tuning objects to
+   * operator_tune.cc
+   * \tparam Args Varargs type to eventually pass to the OP::Map() function
+   * \param N Number of iterations
+   * \param args Varargs to eventually pass to the OP::Map() function
+   */
+  template<typename ...Args>
+  inline static bool LaunchDynamic(mshadow::Stream<cpu> *, const int64_t N, Args... args) {
+#ifdef _OPENMP
+    const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(false);
+    if (omp_threads < 2) {
+      for (int64_t i = 0; i < N; ++i) {
+        OP::Map(i, args...);
+      }
+    } else {
+      #pragma omp parallel for num_threads(omp_threads) schedule(dynamic)
+      for (int64_t i = 0; i < N; ++i) {
+        OP::Map(i, args...);
+      }
+    }
+#else
+    for (int64_t i = 0; i < N; ++i) {
+      OP::Map(i, args...);
+    }
+#endif
+    return true;
+  }
+
   /*!
    * \brief Launch CPU kernel which has OMP tuning data available.
    * When using this for a new kernel op, add declaration and tuning objects to
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 5e469108eda..69c35f85c64 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -791,6 +791,14 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
               s, num_threads, data_out.dptr<DType>());
         }
         num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
+        bool dynamic = false;
+        const dim_t large_matrix_threshold = 1024 * 10;
+        if (data_out.shape_[0] > large_matrix_threshold) {
+          dynamic = true;
+          // each unit of work processes at least 1024 elements in the output
+          const dim_t unit_work_per_thread = 1024;
+          num_threads = data_out.Size() / unit_work_per_thread;
+        }
         dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
         if (trans_lhs) {
           mxnet_op::Kernel<DotCsrTransDnsDnsByRowBlocks, cpu>::Launch(s, num_threads,
@@ -798,10 +806,17 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
               col_idx_l.dptr<CType>(), data_r.dptr<DType>(), seg_len,
               lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]);
         } else {
-          mxnet_op::Kernel<DotCsrDnsDnsByRowBlocks, cpu>::Launch(s, num_threads,
-              data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
-              col_idx_l.dptr<CType>(), data_r.dptr<DType>(), seg_len,
-              data_out.shape_[0], data_out.shape_[1]);
+          if (dynamic) {
+            mxnet_op::Kernel<DotCsrDnsDnsByRowBlocks, cpu>::LaunchDynamic(s, num_threads,
+                data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                col_idx_l.dptr<CType>(), data_r.dptr<DType>(), seg_len,
+                data_out.shape_[0], data_out.shape_[1]);
+          } else {
+            mxnet_op::Kernel<DotCsrDnsDnsByRowBlocks, cpu>::Launch(s, num_threads,
+                data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                col_idx_l.dptr<CType>(), data_r.dptr<DType>(), seg_len,
+                data_out.shape_[0], data_out.shape_[1]);
+          }
         }
       });
     });


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to