eric-haibin-lin commented on a change in pull request #10371: [MXNET-263] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU
URL: https://github.com/apache/incubator-mxnet/pull/10371#discussion_r183613011
 
 

 ##########
 File path: src/operator/tensor/dot-inl.cuh
 ##########
 @@ -895,6 +997,150 @@ inline void DotCsrRspDnsImpl(const OpContext& ctx,
   });
 }
 
+/*
+ * \brief GPU Impl of dot(dns, csr) = csr
+ */
+template<typename gpu>
+inline void DotDnsCsrCsrImpl(const OpContext& ctx,
+                             const TBlob& lhs, const NDArray& rhs,
+                             const OpReqType req, NDArray* ret) {
+  LOG(FATAL) << "dot(dense, csr) = csr is not implemented on GPU";
+}
+
+/*
+ * \brief GPU Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns
+ */
+template<typename gpu>
+inline void DotDnsCsrDnsImpl(const OpContext& ctx,
+                             const TBlob& dns, const NDArray& rhs,
+                             const OpReqType req, NDArray* ret,
+                             const bool transpose_b) {
+  CHECK_EQ(req, kWriteTo);
+  CHECK_EQ(rhs.storage_type(), kCSRStorage);
+
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using nnvm::dim_t;
+
+  /* Initialize data structures */
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+  TBlob csr_data = rhs.data();
+  TBlob csr_indices = rhs.aux_data(csr::kIdx);
+  TBlob csr_indptr = rhs.aux_data(csr::kIndPtr);
+  if (!rhs.storage_initialized()) {
+    FillZerosCsrImpl(s, *ret);
+    return;
+  }
+
+  MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {     // data type
+    MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {    // col idx type
+      MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {   // indptr type
+        const CType out_num_rows = ret->shape()[0];
+        const CType out_num_cols = ret->shape()[1];
+        // if dot(dense, csr) = dns, transform to csc first
+        if (!transpose_b) {
+          const nnvm::dim_t num_csr_rows = rhs.shape()[0];
+          const nnvm::dim_t num_csr_cols = rhs.shape()[1];
+          const nnvm::dim_t num_dns_rows = dns.shape_[0];
+          const nnvm::dim_t nnz = rhs.storage_shape().Size();
+
+          IType* original_idx_ptr = nullptr;
+          IType* csc_indices_ptr = nullptr;
+          IType* csc_cols_ptr = nullptr;
+          CType* csr_rows_ptr = nullptr;
+          CType* csc_indptr_ptr = nullptr;
+          DType* csc_data_ptr = nullptr;
+          char* temp_storage_ptr = nullptr;
+          size_t original_idx_bytes = nnz*sizeof(IType);
+          size_t csc_indices_bytes = nnz*sizeof(IType);
+          size_t csc_cols_bytes = nnz*sizeof(IType);
+          size_t csr_rows_bytes = nnz*sizeof(CType);
+          size_t csc_indptr_bytes = (num_csr_cols+1)*sizeof(CType);
+          size_t csc_data_bytes = nnz*sizeof(DType);
+          size_t scan_temp_storage_bytes = 0;
+          size_t temp_storage_bytes = SortByKeyWorkspaceSize<dim_t, dim_t, gpu>(nnz);
+          IType* csr_indices_ptr = csr_indices.dptr<IType>();
+          cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
+                                        scan_temp_storage_bytes,
+                                        csc_indptr_ptr,
+                                        csc_indptr_ptr,
+                                        num_csr_cols+1,
+                                        mshadow::Stream<gpu>::GetStream(s));
+          temp_storage_bytes = std::max(temp_storage_bytes, scan_temp_storage_bytes);
+          temp_storage_bytes += (sizeof(dim_t) - temp_storage_bytes % sizeof(dim_t));
+          size_t total_workspace_bytes =
+            original_idx_bytes + csc_indices_bytes + csc_cols_bytes + csr_rows_bytes +
+            csc_indptr_bytes + csc_data_bytes + temp_storage_bytes;
+          total_workspace_bytes += (sizeof(IType) - total_workspace_bytes % sizeof(IType));
+          Tensor<gpu, 1, char> workspace = ctx.requested[0]
+              .get_space_typed<gpu, 1, char>(Shape1(total_workspace_bytes), s);
+          original_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_);
+          csc_indices_ptr = reinterpret_cast<IType*>(workspace.dptr_ + original_idx_bytes);
+          csc_cols_ptr = reinterpret_cast<IType*>(workspace.dptr_ + original_idx_bytes +
+                                                  csc_indices_bytes);
+          csr_rows_ptr = reinterpret_cast<CType*>(workspace.dptr_ + original_idx_bytes +
+                                                  csc_indices_bytes + csc_cols_bytes);
+          csc_indptr_ptr = reinterpret_cast<CType*>(workspace.dptr_ + original_idx_bytes +
+                                                    csc_indices_bytes + csc_cols_bytes +
+                                                    csr_rows_bytes);
+          temp_storage_ptr = workspace.dptr_ + original_idx_bytes + csc_indices_bytes +
+                             csc_cols_bytes + csr_rows_bytes + csc_indptr_bytes;
+          csc_data_ptr = reinterpret_cast<DType*>(
+                           workspace.dptr_ + total_workspace_bytes - csc_data_bytes);
+
+          // Fill original_idx
+          mxnet_op::Kernel<range_fwd, gpu>::Launch(
+            s, nnz, 1, IType(0), IType(1), kWriteTo, original_idx_ptr);
+          // Fill csc_cols with copy of csr_indices
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kWriteTo>, gpu>::Launch(
+            s, nnz, csc_cols_ptr, csr_indices_ptr);
+          // Allocate the tensors needed for SortByKey
+          Tensor<gpu, 1, IType> original_idx(original_idx_ptr, Shape1(nnz), s);
+          Tensor<gpu, 1, IType> csc_cols(csc_cols_ptr, Shape1(nnz), s);
+          Tensor<gpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s);
+
+          int num_bits = 1;
+          unsigned int a = num_csr_cols - 1;
+          while (a >>= 1) num_bits++;
 
 Review comment:
   `num_csr_cols` is of type `size_t`. If you change `ilog2(unsigned int)` to `ilog2(size_t)`, it should work.
   Note that you cannot have both `ilog2(unsigned int)` and `ilog2(size_t)`, since the two overloads will confuse the compiler when the input is unsigned (the call becomes ambiguous).
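   For illustration only, a minimal standalone sketch of the ambiguity (these are hypothetical overloads, not the actual mxnet helpers): when both `ilog2(unsigned int)` and `ilog2(size_t)` are visible, a call whose argument is neither exactly `unsigned int` nor exactly `size_t` (e.g. a signed 64-bit `int64_t`) needs an integral conversion for either candidate, so overload resolution fails as ambiguous. Keeping a single overload, or casting at the call site, avoids it:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hypothetical standalone overloads for illustration (not the mxnet cuda utils).
// Each returns the number of bits needed to represent `a`, mirroring the manual
// num_bits loop in the hunk above.
inline int ilog2(unsigned int a) {
  int k = 1;
  while (a >>= 1) ++k;
  return k;
}

inline int ilog2(size_t a) {
  int k = 1;
  while (a >>= 1) ++k;
  return k;
}

int main() {
  const int64_t num_cols = 1024;  // a dim_t-like signed 64-bit value
  // int num_bits = ilog2(num_cols - 1);  // error: ambiguous call -- both
  //                                      // overloads need an integral conversion
  const int num_bits = ilog2(static_cast<size_t>(num_cols - 1));  // exact match
  std::printf("num_bits = %d\n", num_bits);  // prints 10
  return 0;
}
```

   Giving the two helpers distinct names, or keeping only the `size_t` version as suggested above, sidesteps the ambiguity without requiring casts at every call site.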

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
