reminisce commented on a change in pull request #7226: Extending the GPU dot 
operator
URL: https://github.com/apache/incubator-mxnet/pull/7226#discussion_r132098902
 
 

 ##########
 File path: src/operator/tensor/dot-inl.cuh
 ##########
 @@ -199,37 +316,203 @@ struct DotCsrTransDnsDnsThreadBlockKernel {
 };
 
 /*!
- * \brief Warp block kernel of dot(csr.T(), dns1) = dns2
+ * \brief GPU warp block kernel of dot(csr.T, dns1) = dns2
  * Parallelization by columns: 1 warp computes one lhs column for all rhs 
columns
  */
-template<int req>
 struct DotCsrTransDnsDnsWarpBlockKernel {
+  /*!
+   * \brief see DotCsrTransDnsDnsScalarKernel Map for documentation.
+   */
   template<typename DType, typename IType, typename CType>
-  __device__ __forceinline__ static void Map(int tid, DType* out, const DType* 
data_l, const IType* indptr_l,
-                                             const CType* col_idx_l, const 
DType* data_r,
-                                             const int num_cols_r) {
-    const int warp_id = tid / 32;   // global warp id
-    const int lane = tid & (32-1);  // local thread id within warp
-    const int icol = warp_id;       // lhs column that this warp computes
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const DType* data_l,
+                                             const IType* indptr_l,
+                                             const CType* col_idx_l,
+                                             const DType* data_r,
+                                             const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t warp_id = tid / 32;   // global warp id
+    const dim_t lane = tid & (32-1);  // local thread id within warp
+    const dim_t icol = warp_id;       // lhs column that this warp computes
 
     // Compute range of nnz elements in this column
-    const int low  = static_cast<int>(indptr_l[icol]);
-    const int high = static_cast<int>(indptr_l[icol+1]);
+    const dim_t low  = static_cast<dim_t>(indptr_l[icol]);
+    const dim_t high = static_cast<dim_t>(indptr_l[icol+1]);
 
     // Iterate through the nnz elements in lhs column
-    for (int j = low+lane; j < high; j+=32) {
-      const int irow = static_cast<int>(col_idx_l[j]);
+    for (dim_t j = low+lane; j < high; j+=32) {
+      const dim_t irow = static_cast<dim_t>(col_idx_l[j]);
       const DType datum_l = data_l[j];
       // Iterate over all rhs columns
-      for (int k = 0; k < num_cols_r; k++) {
+      for (dim_t k = 0; k < num_cols_r; k++) {
         const DType val = datum_l*data_r[icol*num_cols_r+k];
         atomicAdd(static_cast<DType *>(&(out[irow*num_cols_r+k])), val);
       }
     }
   }
 };
 
-inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
+/*!
+ * \brief GPU warp kernel of dot(csr.T, dns) = rsp
+ * Parallelization by columns: 1 warp computes one lhs column for one rhs 
column
+ */
+struct DotCsrTransDnsRspWarpKernel {
+  /*!
+   * \brief
+   * \param tid              global thread id
+   * \param out              output rsp matrix data
+   * \param row_flg_sum_out  inclusive prefix sum array over 0/1 marked row 
flag array
+   * \param data_l           csr matrix data
+   * \param indptr_l         csr matrix row index pointer
+   * \param col_idx_l        csr matrix column indices
+   * \param data_r           dns matrix data
+   * \param num_cols_r       dns matrix number of columns
+   */
+  template<typename DType, typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const nnvm::dim_t* 
row_flg_sum_out,
+                                             const DType* data_l,
+                                             const IType* indptr_l,
+                                             const CType* col_idx_l,
+                                             const DType* data_r,
+                                             const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t warp_id = tid / 32;           // global warp id
+    const dim_t lane = tid & (32-1);          // local thread id within warp
+    const dim_t icol = warp_id / num_cols_r;  // lhs column that this warp 
computes
+    const dim_t kcol = warp_id % num_cols_r;  // rhs column that this warp 
computes
+
+    // Compute range of nnz elements in this column
+    const dim_t low  = static_cast<dim_t>(indptr_l[icol]);
 
 Review comment:
   I'm confused here. If `icol` is the column id of lhs, how come it is applied 
in `indptr_l` (indexed by the row id of lhs)?
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to