eric-haibin-lin commented on a change in pull request #7226: Extending the GPU dot operator
URL: https://github.com/apache/incubator-mxnet/pull/7226#discussion_r131818493
 
 

 ##########
 File path: src/operator/tensor/dot-inl.cuh
 ##########
 @@ -353,27 +640,308 @@ inline void DotCsrDnsDnsImpl(mshadow::Stream<gpu>* s,
 }
 
 /*!
- * \brief Impl of dot(csr.T, dns) = rsp
+ * \brief GPU Impl of dot(csr, dns) = rsp and dot(csr.T, dns) = rsp
  */
-inline void DotCsrDnsRspImpl(mshadow::Stream<gpu>* s,
+inline void DotCsrDnsRspImpl(const OpContext& ctx,
+                             const gpu& gpu_dev,
                              const NDArray& lhs,
                              const TBlob& rhs,
                              const OpReqType req,
                              const bool trans_lhs,
                              NDArray* ret) {
-  LOG(FATAL) << "DotCsrDnsRspImpl gpu version is not implemented.";
+  if (kNullOp == req) return;
+  CHECK_EQ(lhs.storage_type(), kCSRStorage);
+  CHECK_EQ(ret->storage_type(), kRowSparseStorage);
+  if (!lhs.storage_initialized()) return;
+
+  using mshadow::Shape1;
+  using mxnet_op::Kernel;
+  using mxnet_op::set_zero;
+  using nnvm::dim_t;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+
+  const TBlob data_l = lhs.data();
+  const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
+  const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
+  const TBlob& data_r = rhs;
+
+  const dim_t num_rows_l = lhs.shape()[0];
+  const dim_t num_cols_l = lhs.shape()[1];
+  const dim_t num_cols_r = rhs.shape_[1];
+  const dim_t threads_per_warp = mxnet_op::cuda_get_device_prop().warpSize;
+  dim_t num_threads;
+  // TODO: remove kernel dependency on warpSize=32
+  if (threads_per_warp != 32) {
+    LOG(FATAL) << "DotCsrDnsRspImpl GPU kernels expect warpSize=32";
+  }
+
+  MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
+    MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
+        if (trans_lhs) {
+          // Compute number of non-zero rows (nnr) of output matrix
+          // - alloc temp storage for row_flg array and for cub's prefix sum
+          // - mark non-zero columns of csr matrix in row_flg
+          // - compute inclusive prefix sum over marked array
+          // - copy last value (nnr_out) from device to host
+          dim_t* row_flg_out = NULL;
+          void* d_temp_storage = NULL;
+          size_t temp_storage_bytes = 0;
+          cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                        temp_storage_bytes,
+                                        row_flg_out,
+                                        row_flg_out,
+                                        num_cols_l,
+                                        mshadow::Stream<gpu>::GetStream(s));
+          mshadow::Tensor<gpu, 1, char> workspace = ctx.requested[0]
+              .get_space_typed<gpu, 1, char>(Shape1(num_cols_l*sizeof(dim_t)+temp_storage_bytes), s);
+          row_flg_out = reinterpret_cast<dim_t*>(workspace.dptr_);
+          d_temp_storage = workspace.dptr_ + num_cols_l*sizeof(dim_t);
+          num_threads = num_cols_l;
+          Kernel<set_zero, gpu>::Launch(s, num_threads, row_flg_out);
+          num_threads = num_rows_l * threads_per_warp;
+          Kernel<MarkCsrZeroColsWarpKernel, gpu>::Launch(s, num_threads,
+              row_flg_out, col_idx_l.dptr<CType>(), indptr_l.dptr<IType>(),
+              num_rows_l, num_cols_l);
+          cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                        temp_storage_bytes,
+                                        row_flg_out,
+                                        row_flg_out,
+                                        num_cols_l,
+                                        mshadow::Stream<gpu>::GetStream(s));
+          dim_t nnr_out = 0;
+          CUDA_CALL(cudaMemcpy(&nnr_out, &row_flg_out[num_cols_l-1], sizeof(dim_t),
+                               cudaMemcpyDeviceToHost));
+
+          // Allocate output matrix space
+          ret->CheckAndAlloc({Shape1(nnr_out)});
+          const TBlob data_out_blob = ret->data();
+          const TBlob row_idx_out_blob = ret->aux_data(rowsparse::kIdx);
+          MSHADOW_IDX_TYPE_SWITCH(row_idx_out_blob.type_flag_, RType, {  // row idx type
+            DType* data_out = data_out_blob.dptr<DType>();
+            RType* row_idx_out = row_idx_out_blob.dptr<RType>();
+            if (kWriteTo == req) {
+              num_threads = nnr_out * num_cols_r;
+              Kernel<set_zero, gpu>::Launch(s, num_threads, data_out);
+            }
+            num_threads = nnr_out;
+            Kernel<set_zero, gpu>::Launch(s, num_threads, row_idx_out);
+
+            // Fill row_idx array of output matrix, using the row_flg values
+            num_threads = num_cols_l;
+            Kernel<FillRspRowIdxKernel, gpu>::Launch(s, num_threads,
+                row_idx_out, row_flg_out, num_cols_l);
+
+            // Perform matrix-matrix multiply
+            num_threads = threads_per_warp * num_rows_l * num_cols_r;
+            Kernel<DotCsrTransDnsRspWarpKernel, gpu>::Launch(s, num_threads,
+                data_out, row_flg_out,
+                data_l.dptr<DType>(), indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(),
+                data_r.dptr<DType>(), num_cols_r);
+          });
+        } else {
+          LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns) = rsp yet.";
+        }
+      });
+    });
+  });
 }
 
 /*!
- * \brief Impl of dot(csr.T, rsp) = rsp2
+ * \brief GPU Impl of dot(csr, rsp1) = rsp2 and dot(csr.T, rsp1) = rsp2
+ * TODO: Optimize for GPU; this is a baseline implementation that provides
+ *       the operator functionality but is not yet tuned for GPU performance.
  */
-inline void DotCsrRspRspImpl(mshadow::Stream<gpu>* s,
+inline void DotCsrRspRspImpl(const OpContext& ctx,
+                             const gpu& gpu_dev,
                              const NDArray& lhs,
                              const NDArray& rhs,
                              const OpReqType req,
                              const bool trans_lhs,
                              NDArray* ret) {
-  LOG(FATAL) << "DotCsrRspRspImpl gpu version is not implemented.";
+  // Reuse dot(csr, dns) implementation if rhs rsp matrix is in fact dense
+  if (rhs.storage_shape()[0] == rhs.shape()[0]) {
+    DotCsrDnsRspImpl(ctx, gpu_dev, lhs, rhs.data(), req, trans_lhs, ret);
+    return;
+  }
+  if (kNullOp == req) return;
+  CHECK_EQ(lhs.storage_type(), kCSRStorage);
 
 Review comment:
   Also add CHECK_NE(req, kAddTo), since the result is rsp.
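   A minimal sketch of that suggested guard, placed next to the existing checks at the
   top of DotCsrRspRspImpl (the error-message wording below is illustrative, not part
   of the patch):

     if (kNullOp == req) return;
     CHECK_EQ(lhs.storage_type(), kCSRStorage);
     // The row_sparse output is (re)allocated by this function, so there are no
     // existing values to accumulate into; reject kAddTo explicitly.
     CHECK_NE(req, kAddTo) << "DotCsrRspRspImpl does not support kAddTo for a row_sparse output";

   Rejecting kAddTo up front keeps the failure mode explicit instead of silently
   overwriting the output when the request cannot be honored.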
 