haojin2 commented on a change in pull request #11179: [MXNET-404] elemwise_add/sub between rsp and rsp on GPU
URL: https://github.com/apache/incubator-mxnet/pull/11179#discussion_r194243373
 
 

 ##########
 File path: src/operator/tensor/elemwise_binary_op_basic.cu
 ##########
 @@ -22,12 +22,141 @@
  * \file elemwise_binary_op_basic.cu
  * \brief GPU implementation of elementwise binary operators.
  */
+#include <cub/cub.cuh>
 #include "./elemwise_binary_op.h"
 #include "./elemwise_binary_op-inl.h"
 
 namespace mxnet {
 namespace op {
 
+template<typename OP>
+struct RspElemwiseKernel {
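+  // lookup_table maps each original row id to a 1-based compacted output row,
+  // so lookup_table[indices[row]] - 1 is the output row to accumulate into.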
+  template<typename DType, typename IType>
+  static MSHADOW_XINLINE void Map(int i, DType* out, const IType* lookup_table,
+                                  const DType* data, const IType* indices,
+                                  const nnvm::dim_t nz_rows, const nnvm::dim_t num_cols) {
+    if (i < nz_rows * num_cols) {
+      const nnvm::dim_t row = i / num_cols;
+      const nnvm::dim_t col = i % num_cols;
+      const nnvm::dim_t out_row = lookup_table[indices[row]] - 1;
+      const nnvm::dim_t out_idx = out_row * num_cols + col;
+      out[out_idx] = OP::Map(out[out_idx], data[i]);
+    }
+  }
+};
+
+template<typename OP>
+void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<gpu> *s,
+                                const nnvm::NodeAttrs &attrs,
+                                const OpContext &ctx,
+                                const NDArray &lhs,
+                                const NDArray &rhs,
+                                const OpReqType req,
+                                const NDArray &output,
+                                const bool lhs_may_be_dense,
+                                const bool rhs_may_be_dense,
+                                const bool allow_inplace,
+                                const bool scatter) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace mshadow::expr;
+  using namespace rowsparse;
+
+  CHECK(!scatter) << "scatter is not supported in RspRspOp on GPU yet...";
+  CHECK(lhs.storage_type() == kRowSparseStorage && rhs.storage_type() == kRowSparseStorage);
+  CHECK(output.storage_type() == kRowSparseStorage);
+
+  const nnvm::dim_t num_rows = output.shape()[0];
+  MSHADOW_TYPE_SWITCH(lhs.data().type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(lhs.aux_data(kIdx).type_flag_, IType, {
+      if (lhs.storage_initialized() && rhs.storage_initialized()) {
+        const nnvm::dim_t lhs_nz_rows = lhs.storage_shape()[0];
+        const nnvm::dim_t rhs_nz_rows = rhs.storage_shape()[0];
+        const nnvm::dim_t num_cols = lhs.data().Size() / lhs_nz_rows;
+        // Optimize for the case where one of the rsps is actually dense
+        if ((lhs_nz_rows == num_rows || rhs_nz_rows == num_rows) && req == kWriteInplace) {
+          const NDArray& dns = (output.IsSame(lhs)) ? lhs : rhs;
+          const NDArray& rsp = (output.IsSame(lhs)) ? rhs : lhs;
+          const bool reverse = (lhs_nz_rows != num_rows);
+          ElemwiseBinaryOp::DnsRspDnsOp<gpu, OP>(s, attrs, ctx, dns, rsp, req, output, reverse);
+          return;
+        }
+        CHECK(req == kWriteTo) << "Should be kWriteTo but got " << req;
+        const TBlob& lhs_indices = lhs.aux_data(kIdx);
+        const TBlob& rhs_indices = rhs.aux_data(kIdx);
+        size_t common_row_table_bytes = num_rows * sizeof(IType);
+        IType* common_row_table = NULL;
+        void* temp_storage_ptr = NULL;
+        size_t temp_storage_bytes = 0;
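+        // The first InclusiveSum call passes a NULL temp pointer: cub only
+        // computes the scratch size it needs and writes it to
+        // temp_storage_bytes; no scan is performed yet.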
+        cub::DeviceScan::InclusiveSum(temp_storage_ptr,
+                                      temp_storage_bytes,
+                                      common_row_table,
+                                      common_row_table,
+                                      num_rows,
+                                      mshadow::Stream<gpu>::GetStream(s));
+        size_t workspace_bytes = common_row_table_bytes + temp_storage_bytes;
+        Tensor<gpu, 1, char> workspace =
+          ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_bytes), s);
+        common_row_table = reinterpret_cast<IType*>(workspace.dptr_);
+        temp_storage_ptr = workspace.dptr_ + common_row_table_bytes;
+        mxnet_op::Kernel<set_zero, gpu>::Launch(s, num_rows, common_row_table);
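+        // Flag each row that has non-zeros in lhs or rhs; the inclusive
+        // prefix sum below turns those 0/1 flags into 1-based compacted
+        // output row indices.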
+        Kernel<MarkRspRowFlgKernel, gpu>::Launch(
+          s, lhs_nz_rows, common_row_table, lhs_indices.dptr<IType>(), lhs_nz_rows);
+        Kernel<MarkRspRowFlgKernel, gpu>::Launch(
+          s, rhs_nz_rows, common_row_table, rhs_indices.dptr<IType>(), rhs_nz_rows);
+        cub::DeviceScan::InclusiveSum(temp_storage_ptr,
+                                      temp_storage_bytes,
+                                      common_row_table,
+                                      common_row_table,
+                                      num_rows,
+                                      mshadow::Stream<gpu>::GetStream(s));
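+        // After the scan, the last entry of common_row_table holds the
+        // number of non-zero rows in the output.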
+        IType nnr_out = 0;
+        CUDA_CALL(cudaMemcpy(&nnr_out, &common_row_table[num_rows - 1], sizeof(IType),
+                             cudaMemcpyDeviceToHost));
+        output.CheckAndAlloc({mshadow::Shape1(nnr_out)});
+        Kernel<FillRspRowIdxKernel, gpu>::Launch(
+          s, num_rows, output.aux_data(kIdx).dptr<IType>(), common_row_table, num_rows);
+        Kernel<set_zero, gpu>::Launch(s, nnr_out * num_cols, output.data().dptr<DType>());
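+        // lhs rows are accumulated with plus, rhs rows with OP, so the output
+        // is lhs OP rhs (e.g. lhs - rhs for minus) over the union of rows.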
+        Kernel<RspElemwiseKernel<mshadow_op::plus>, gpu>::Launch(
+          s, lhs_nz_rows * num_cols, output.data().dptr<DType>(), common_row_table,
+          lhs.data().dptr<DType>(), lhs_indices.dptr<IType>(), lhs_nz_rows, num_cols);
+        Kernel<RspElemwiseKernel<OP>, gpu>::Launch(
+          s, rhs_nz_rows * num_cols, output.data().dptr<DType>(), common_row_table,
+          rhs.data().dptr<DType>(), rhs_indices.dptr<IType>(), rhs_nz_rows, num_cols);
+      } else {
+        if (lhs.storage_initialized()) {
+          if (req == kWriteTo) {
+            output.CheckAndAlloc({lhs.aux_shape(kIdx)});
+            Copy(output.data().FlatTo1D<gpu, DType>(),
+                 lhs.data().FlatTo1D<gpu, DType>(), s);
+            Copy(output.aux_data(kIdx).FlatTo1D<gpu, IType>(),
+                 lhs.aux_data(kIdx).FlatTo1D<gpu, IType>(), s);
+          }
 
 Review comment:
   Extra checks and tests added.
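
   For readers skimming the hunk: the kernels above implement a flag-and-scan
   row union. A minimal host-side C++ sketch of that pattern (illustrative
   only; the function name and layout are hypothetical, simplified from the
   CUDA code in this hunk):

   #include <cstdint>
   #include <vector>

   // Flag rows present in either operand, then inclusive-scan the flags
   // to obtain 1-based compacted output row ids.
   std::vector<int64_t> BuildRowLookup(const std::vector<int64_t>& lhs_idx,
                                       const std::vector<int64_t>& rhs_idx,
                                       int64_t num_rows, int64_t* nnr_out) {
     std::vector<int64_t> table(num_rows, 0);
     for (int64_t r : lhs_idx) table[r] = 1;  // MarkRspRowFlgKernel (lhs)
     for (int64_t r : rhs_idx) table[r] = 1;  // MarkRspRowFlgKernel (rhs)
     for (int64_t i = 1; i < num_rows; ++i)   // cub::DeviceScan::InclusiveSum
       table[i] += table[i - 1];
     *nnr_out = table[num_rows - 1];          // non-zero rows in the output
     return table;  // table[row] - 1 is the compacted output row index
   }

   With the table built, each operand's values are scattered into a
   zero-filled output of nnr_out rows, which is what the two
   RspElemwiseKernel launches above do.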

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
