eric-haibin-lin commented on a change in pull request #10371: [MXNET-263] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU
URL: https://github.com/apache/incubator-mxnet/pull/10371#discussion_r183603340
 
 

 ##########
 File path: src/operator/tensor/dot-inl.cuh
 ##########
 @@ -895,6 +991,151 @@ inline void DotCsrRspDnsImpl(const OpContext& ctx,
   });
 }
 
+/*
+ * \brief GPU Impl of dot(dns, csr) = csr
+ */
+inline void DotDnsCsrCsrImpl(const OpContext& ctx, const gpu& gpu_dev,
+                             const TBlob& lhs, const NDArray& rhs,
+                             const OpReqType req, NDArray* ret) {
+  LOG(FATAL) << "dot(dense, csr) = csr is not implemented on GPU";
+}
+
+/*
+ * \brief GPU Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns
+ */
+inline void DotDnsCsrDnsImpl(const OpContext& ctx, const gpu& gpu_dev,
+                             const TBlob& dns, const NDArray& rhs,
+                             const OpReqType req, NDArray* ret,
+                             const bool transpose_b) {
+  if (req == kNullOp) {
+    return;
+  }
+  CHECK_EQ(req, kWriteTo);
+  CHECK_EQ(rhs.storage_type(), kCSRStorage);
+
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using nnvm::dim_t;
+
+  /* Initialize data structures */
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+  TBlob csr_data = rhs.data();
+  TBlob csr_indices = rhs.aux_data(csr::kIdx);
+  TBlob csr_indptr = rhs.aux_data(csr::kIndPtr);
+  if (!rhs.storage_initialized()) {
+    FillZerosCsrImpl(s, *ret);
+    return;
+  }
+
+  MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {     // data type
+    MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {   // col idx type
+      MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {  // indptr type
+        const nnvm::dim_t out_num_rows = ret->shape()[0];
+        const nnvm::dim_t out_num_cols = ret->shape()[1];
+        // if dot(dense, csr) = dns, transform to csc first
+        if (!transpose_b) {
+          const nnvm::dim_t num_csr_rows = rhs.shape()[0];
+          const nnvm::dim_t num_csr_cols = rhs.shape()[1];
+          const nnvm::dim_t num_dns_rows = dns.shape_[0];
+          const nnvm::dim_t nnz = rhs.storage_shape().Size();
+
+          IType* original_idx_ptr = nullptr;
+          IType* csc_indices_ptr = nullptr;
+          IType* csc_cols_ptr = nullptr;
+          CType* csr_rows_ptr = nullptr;
+          CType* csc_indptr_ptr = nullptr;
+          DType* csc_data_ptr = nullptr;
+          char* temp_storage_ptr = nullptr;
+          size_t original_idx_bytes = nnz*sizeof(IType);
+          size_t csc_indices_bytes = nnz*sizeof(IType);
+          size_t csc_cols_bytes = nnz*sizeof(IType);
+          size_t csr_rows_bytes = nnz*sizeof(CType);
+          size_t csc_indptr_bytes = (num_csr_cols+1)*sizeof(CType);
+          size_t csc_data_bytes = nnz*sizeof(DType);
+          size_t scan_temp_storage_bytes = 0;
+          size_t temp_storage_bytes = SortByKeyWorkspaceSize<dim_t, dim_t, gpu>(nnz);
 
 Review comment:
   nit: shouldn't it be `SortByKeyWorkspaceSize<IType, IType, gpu>(nnz);`? We might add a 32-bit int index dtype in the future to speed up the sort and reduce memory usage.
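
   For context, the workspace returned by a sort-by-key size query generally depends on the width of the key and value elements, so templating the query on `IType` rather than `dim_t` matters once 32-bit indices are in play. Below is a minimal standalone sketch (using CUB's `DeviceRadixSort::SortPairs` size query directly rather than MXNet's `SortByKeyWorkspaceSize`, with a hypothetical `nnz`) that illustrates how the required temporary-storage bytes shrink when 32-bit key/value types are used instead of 64-bit ones:

   ```cpp
   // Standalone illustration (not MXNet code): ask CUB how much temporary
   // storage a device-wide sort-by-key needs for 32-bit vs 64-bit key/value
   // types. Passing a null d_temp_storage pointer makes CUB report the
   // required size without touching device memory.
   #include <cub/cub.cuh>
   #include <cstdint>
   #include <cstdio>

   template <typename KeyT, typename ValueT>
   size_t SortPairsWorkspaceBytes(int num_items) {
     size_t temp_bytes = 0;
     cub::DeviceRadixSort::SortPairs<KeyT, ValueT>(
         nullptr, temp_bytes,
         static_cast<const KeyT*>(nullptr), static_cast<KeyT*>(nullptr),
         static_cast<const ValueT*>(nullptr), static_cast<ValueT*>(nullptr),
         num_items);
     return temp_bytes;
   }

   int main() {
     const int nnz = 1 << 20;  // hypothetical number of non-zeros
     printf("32-bit keys/values: %zu bytes\n",
            SortPairsWorkspaceBytes<int32_t, int32_t>(nnz));
     printf("64-bit keys/values: %zu bytes\n",
            SortPairsWorkspaceBytes<int64_t, int64_t>(nnz));
     return 0;
   }
   ```

   Sizing the workspace with the actual sorted index type keeps the allocation proportional to that type's width, which is the memory saving the nit points at.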
