haojin2 commented on a change in pull request #10371: [MXNET-263] [WIP] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU
URL: https://github.com/apache/incubator-mxnet/pull/10371#discussion_r179234938
 
 

 ##########
 File path: src/operator/tensor/dot-inl.cuh
 ##########
 @@ -442,6 +442,99 @@ struct DotCsrRspDnsScalarKernel {
   }
 };
 
+/*!
+ * \brief GPU Kernel to rearrange the nnz elements of a csr matrix into csc order
+ * Parallelization by rows of the csr matrix: 1 thread/row
+ */
+struct CscDataIndicesKernel {
+  template<typename DType, typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             const DType* csr_data,
+                                             const IType* csr_indices,
+                                             const CType* csr_indptr,
+                                             DType* csc_data,
+                                             IType* csc_indices,
+                                             CType* csc_indptr,
+                                             int* workspace,
+                                             const nnvm::dim_t num_rows,
+                                             const nnvm::dim_t num_cols) {
+    if (tid < num_rows) {
+      for (CType i = csr_indptr[tid]; i < csr_indptr[tid + 1]; ++i) {
+        // target column
+        IType target_col = csr_indices[i];
+        int target_offset = atomicAdd(&workspace[target_col], 1);
+        CType new_pos = csc_indptr[target_col] + target_offset;
+        csc_data[new_pos] = csr_data[i];
+        csc_indices[new_pos] = tid;
+      }
+    }
+  }
+};
+
+/*!
+ * \brief GPU Kernel to count the non-zero elements in every column of a csr matrix
+ * Parallelization by non-zero elements: 1 thread/element
+ */
+struct CsrTransHistogramKernel {
+  /*!
+   * \brief
+   * \param tid          global thread id
+   * \param in_indices   csr matrix column indices
+   * \param out_indptr   csr matrix row pointer
+   * \param nnz          number of non-zero elements in csr
+   */
+  template<typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             const IType* in_indices,
+                                             CType* out_indptr,
+                                             const nnvm::dim_t nnz) {
+    if (tid < nnz) {
+      atomicAdd(&out_indptr[in_indices[tid] + 1], 1);
+    }
+  }
+};
+
+/*!
+ * \brief GPU Kernel of dot(dns, csr.T) = dns
+ * Parallelization by output elements: 1 thread/element
+ */
+struct DotDnsCsrTransDnsKernel {
+  /*!
+   * \brief
+   * \param tid          global thread id
+   * \param lhs_data     lhs dense matrix data
+   * \param rhs_data     csr matrix data
+   * \param rhs_indices  csr matrix column indices
+   * \param rhs_indptr   csr matrix row pointer
+   * \param out          output matrix data
+   * \param lhs_num_cols lhs dns matrix number of columns
+   * \param out_num_rows output dns matrix number of rows
+   * \param out_num_cols output dns matrix number of columns
+   */
+  template<typename DType, typename IType, typename CType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             const DType* lhs_data,
+                                             const DType* rhs_data,
+                                             const IType* rhs_indices,
+                                             const CType* rhs_indptr,
+                                             DType* out,
+                                             const nnvm::dim_t lhs_num_cols,
+                                             const nnvm::dim_t out_num_rows,
+                                             const nnvm::dim_t out_num_cols) {
+    using nnvm::dim_t;
+    if (tid < out_num_rows*out_num_cols) {
+      const dim_t i = static_cast<dim_t>(tid) / out_num_cols;  // i = row this thread computes
+      const dim_t k = static_cast<dim_t>(tid) % out_num_cols;  // k = col this thread computes
+      // Compute inner product of i-th row and k-th col
+      DType sum = 0;
+      for (CType col_id = rhs_indptr[k]; col_id < rhs_indptr[k + 1]; ++col_id) {
+        sum += lhs_data[i * lhs_num_cols + rhs_indices[col_id]] * rhs_data[col_id];
+      }
+      out[i*out_num_cols+k] = sum;
 
 Review comment:
   Good catch, done.
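
For context, here is a rough host-side sketch of how the three kernels above could be chained for the dot(dns, csr) case: the csr operand is first converted to csc order so that DotDnsCsrTransDnsKernel can be reused on the transposed view. This is an illustration only, not the PR's actual dispatch code; the function name, the Thrust scan call, and the assumption that csc_indptr and workspace are zero-initialized scratch buffers are all hypothetical.

#include <thrust/execution_policy.h>
#include <thrust/scan.h>

// Illustrative sketch only (assumed helper, not part of this PR).
// dns_data is (m, r), the csr operand is (r, c), out is (m, c).
template<typename DType, typename IType, typename CType>
void DotDnsCsrDnsSketch(mshadow::Stream<gpu>* s,
                        const DType* dns_data,
                        const DType* csr_data, const IType* csr_indices, const CType* csr_indptr,
                        DType* out,
                        DType* csc_data,      // scratch: nnz values in csc order
                        IType* csc_indices,   // scratch: nnz row indices in csc order
                        CType* csc_indptr,    // scratch: c + 1 entries, zero-initialized
                        int* workspace,       // scratch: c counters, zero-initialized
                        const nnvm::dim_t m, const nnvm::dim_t r,
                        const nnvm::dim_t c, const nnvm::dim_t nnz) {
  using namespace mxnet::op;
  // 1) Count the non-zeros of each csr column; counts land at csc_indptr[col + 1].
  mxnet_op::Kernel<CsrTransHistogramKernel, gpu>::Launch(
      s, nnz, csr_indices, csc_indptr, nnz);
  // 2) Prefix-sum the counts to obtain the csc column pointer.
  thrust::inclusive_scan(thrust::device, csc_indptr, csc_indptr + c + 1, csc_indptr);
  // 3) Scatter values and row indices into csc order, one thread per csr row.
  mxnet_op::Kernel<CscDataIndicesKernel, gpu>::Launch(
      s, r, csr_data, csr_indices, csr_indptr,
      csc_data, csc_indices, csc_indptr, workspace, r, c);
  // 4) The csc arrays are the csr representation of csr.T, so the transpose kernel
  //    now computes out = dot(dns, (csr.T).T) = dot(dns, csr), one thread per output element.
  mxnet_op::Kernel<DotDnsCsrTransDnsKernel, gpu>::Launch(
      s, m * c, dns_data, csc_data, csc_indices, csc_indptr, out, r, m, c);
}

For dot(dns, csr.T) no such conversion is needed: DotDnsCsrTransDnsKernel can be launched directly on the original csr arrays, since it already treats each row of the sparse operand as a column of the product.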
