eric-haibin-lin commented on a change in pull request #8938: Add operator for dot(dns, csr) = csr
URL: https://github.com/apache/incubator-mxnet/pull/8938#discussion_r156471898
 
 

 ##########
 File path: src/operator/tensor/dot-inl.h
 ##########
 @@ -811,6 +891,94 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
   });
 }
 
+/*!
+ * \brief CPU Impl of dot(dns, csr) = csr
+ */
+inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev,
+                             const TBlob& lhs, const NDArray& rhs,
+                             const OpReqType req, NDArray* ret) {
+  if (kNullOp == req) return;
+  CHECK_EQ(rhs.storage_type(), kCSRStorage);
+  if (!rhs.storage_initialized()) return;
+
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using nnvm::dim_t;
+
+  /* Initialize data structures */
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const NDArray& out = *ret;
+  const TBlob data_l = lhs;
+  const TBlob data_r = rhs.data();
+  const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
+  const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
+
+  MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, {     // data type
+    MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, {     // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, {  // colidx type
+        /* Allocate workspace */
+        CType num_cols_out = out.shape()[1];
+        CType rhs_data_size = static_cast<CType>(col_idx_r.shape_.Size());
+        size_t workspace_size = 2 * num_cols_out * sizeof(CType);
+        Tensor<cpu, 1, char> workspace =
+            ctx.requested[0].get_space_typed<cpu, 1, char>(
+                Shape1(workspace_size), s);
+        CType* col_flg = reinterpret_cast<CType*>(workspace.dptr_);
+
+        CType* prefix_sum = col_flg;
+        CType* nnc_idx = prefix_sum + num_cols_out;
+
+        /* Set the column flags for columns holding non-zero entries */
+        mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, num_cols_out,
+                                                          col_flg);
+        mxnet_op::Kernel<MarkRowFlgKernel, cpu>::Launch(
+            s, rhs_data_size, col_flg, col_idx_r.dptr<CType>());
+
+        /* 1. Calculate the prefix sum from col_flg
+         * 2. Store all non-zero column indices in nnc_idx
+         */
+        CType cur = 0;
+        prefix_sum[0] = col_flg[0];
+        if (prefix_sum[0]) nnc_idx[cur++] = 0;
+        for (CType i = 1; i < num_cols_out; i++) {
+          prefix_sum[i] = prefix_sum[i - 1] + col_flg[i];
+          if (prefix_sum[i] > prefix_sum[i - 1]) nnc_idx[cur++] = i;
+        }
+
+        /* Allocate aux data for out */
+        IType num_rows_l = lhs.shape_[0];
+        dim_t nnc = prefix_sum[num_cols_out - 1];
+        dim_t nnz = nnc * num_rows_l;
+        out.CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows_l + 1));
+        out.CheckAndAllocAuxData(csr::kIdx, Shape1(nnz));
+        out.CheckAndAllocData(Shape1(nnz));
+
+        /* Set the csr indptr and column indices according to nnc_idx */
+        IType* indptr_out = out.aux_data(csr::kIndPtr).dptr<IType>();
+        CType* col_idx_out = out.aux_data(csr::kIdx).dptr<CType>();
+        DType* data_out = out.data().dptr<DType>();
+        mxnet_op::Kernel<PopulateCsrForNNC, cpu>::Launch(
+            s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l);
+        mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, nnz, data_out);
+
+        if (nnc == 0) {
 
 Review comment:
   Because you already checked `rhs.storage_initialized()` in line 922?
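  For context, a minimal standalone sketch of how `nnc` falls out of the column-flag prefix sum (hypothetical toy values, not the PR kernel): once `rhs.storage_initialized()` has passed, `col_idx_r` is non-empty, so at least one flag is set, the final prefix-sum entry is at least 1, and the `nnc == 0` branch looks unreachable.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Toy stand-ins for the kernel's inputs (hypothetical values).
      const int64_t num_cols_out = 6;
      // Column indices of the CSR rhs; non-empty once storage_initialized() holds.
      const std::vector<int64_t> col_idx_r = {1, 3, 3, 5};

      // Mark columns that hold at least one non-zero (MarkRowFlgKernel analogue).
      std::vector<int64_t> col_flg(num_cols_out, 0);
      for (int64_t c : col_idx_r) col_flg[c] = 1;

      // Prefix sum over the flags; collect the non-zero column ids into nnc_idx.
      std::vector<int64_t> prefix_sum(num_cols_out, 0);
      std::vector<int64_t> nnc_idx;
      prefix_sum[0] = col_flg[0];
      if (prefix_sum[0]) nnc_idx.push_back(0);
      for (int64_t i = 1; i < num_cols_out; ++i) {
        prefix_sum[i] = prefix_sum[i - 1] + col_flg[i];
        if (prefix_sum[i] > prefix_sum[i - 1]) nnc_idx.push_back(i);
      }

      const int64_t nnc = prefix_sum[num_cols_out - 1];
      std::cout << "nnc = " << nnc << '\n';  // prints 3 for this toy input
      // With any stored element, some col_flg entry is 1, so nnc >= 1 here.
      return 0;
    }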

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
