eric-haibin-lin closed pull request #10400: Add support for cast storage on same stypes
URL: https://github.com/apache/incubator-mxnet/pull/10400
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h
index 46de10ac9cc..f905bf8f722 100644
--- a/src/operator/tensor/cast_storage-inl.h
+++ b/src/operator/tensor/cast_storage-inl.h
@@ -30,6 +30,7 @@
 #include <algorithm>
 #include "../mxnet_op.h"
 #include "../operator_common.h"
+#include "../../src/operator/tensor/init_op.h"
 #ifdef __CUDACC__
 #include "./cast_storage-inl.cuh"
 #endif  // __CUDACC__
@@ -328,6 +329,50 @@ void CastStorageCsrDnsImpl(const OpContext& ctx,
   });
 }
 
+/*!
+ * \brief Casts a csr matrix to another csr.
+ */
+template <typename xpu>
+void CastStorageCsrCsrImpl(const OpContext& ctx, const NDArray& csr,
+                           NDArray* output) {
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  if (!csr.storage_initialized()) {
+    FillZerosCsrImpl(s, *output);
+    return;
+  }
+  std::vector<TShape> aux_shapes({csr.aux_shape(csr::kIndPtr), csr.aux_shape(csr::kIdx)});
+  output->CheckAndAlloc(aux_shapes);
+  const TBlob& val = output->data();
+  const TBlob& indptr = output->aux_data(csr::kIndPtr);
+  const TBlob& idx = output->aux_data(csr::kIdx);
+  mxnet_op::copy(s, val, csr.data());
+  mxnet_op::copy(s, indptr, csr.aux_data(csr::kIndPtr));
+  mxnet_op::copy(s, idx, csr.aux_data(csr::kIdx));
+}
+
+/*!
+ * \brief Casts a rsp matrix to another rsp.
+ */
+template <typename xpu>
+void CastStorageRspRspImpl(const OpContext& ctx, const NDArray& rsp,
+                           NDArray* output) {
+  CHECK_EQ(rsp.storage_type(), output->storage_type())
+      << "Copying with different storage type";
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  if (!rsp.storage_initialized()) {
+    FillZerosRspImpl(s, *output);
+    return;
+  }
+  auto aux_shape = rsp.aux_shape(rowsparse::kIdx);
+  output->CheckAndAlloc({aux_shape});
+  const TBlob& val = output->data();
+  const TBlob& idx = output->aux_data(rowsparse::kIdx);
+  const TBlob& from_val = rsp.data();
+  const TBlob& from_idx = rsp.aux_data(rowsparse::kIdx);
+  mxnet_op::copy(s, val, from_val);
+  mxnet_op::copy(s, idx, from_idx);
+}
+
 template<typename xpu>
 void CastStorageComputeImpl(const OpContext& ctx,
                             const NDArray& input,
@@ -346,6 +391,12 @@ void CastStorageComputeImpl(const OpContext& ctx,
   } else if (src_stype == kCSRStorage && dst_stype == kDefaultStorage) {
     TBlob ret = output.data();
     CastStorageCsrDnsImpl<xpu>(ctx, input, &ret);
+  } else if (src_stype == kCSRStorage && dst_stype == kCSRStorage) {
+    NDArray ret = output;
+    CastStorageCsrCsrImpl<xpu>(ctx, input, &ret);
+  } else if (src_stype == kRowSparseStorage && dst_stype == kRowSparseStorage) {
+    NDArray ret = output;
+    CastStorageRspRspImpl<xpu>(ctx, input, &ret);
 #if MXNET_USE_MKLDNN == 1
   } else if (src_stype == kDefaultStorage && dst_stype == kDefaultStorage) {
     CHECK_EQ(output.ctx().dev_type, input.ctx().dev_type);
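
The CSR-to-CSR branch above allocates the output's indptr/idx aux arrays to match the source and then copies values and indices verbatim; an uninitialized (all-zero) source short-circuits to FillZerosCsrImpl. A minimal sketch of the resulting user-facing behavior (previously unsupported), assuming MXNet's Python cast_storage operator and the NDArray tostype helper:

    import mxnet as mx
    import numpy as np

    # Build a mostly-zero matrix and cast it to CSR storage.
    dense = np.zeros((4, 5))
    dense[0, 1] = 1.0
    dense[2, 3] = 2.0
    csr = mx.nd.array(dense).tostype('csr')

    # With this change, casting to the same storage type is a copy
    # instead of an unsupported cast.
    csr_copy = mx.nd.cast_storage(csr, stype='csr')
    assert csr_copy.stype == 'csr'
    assert np.array_equal(csr_copy.asnumpy(), dense)
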
diff --git a/src/operator/tensor/cast_storage.cc b/src/operator/tensor/cast_storage.cc
index 9f257b140f7..f77a50a7e70 100644
--- a/src/operator/tensor/cast_storage.cc
+++ b/src/operator/tensor/cast_storage.cc
@@ -46,6 +46,8 @@ The storage type of ``cast_storage`` output depends on stype parameter:
 - cast_storage(row_sparse, 'default') = default
 - cast_storage(default, 'csr') = csr
 - cast_storage(default, 'row_sparse') = row_sparse
+- cast_storage(csr, 'csr') = csr
+- cast_storage(row_sparse, 'row_sparse') = row_sparse
 
 Example::
 
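
The row_sparse-to-row_sparse path behaves the same way. A short sketch of both the copy and the all-zero case, again assuming the Python sparse NDArray API (mx.nd.sparse.zeros for the uninitialized input):

    import mxnet as mx

    # row_sparse -> row_sparse is likewise a straight copy.
    rsp = mx.nd.ones((6, 2)).tostype('row_sparse')
    rsp_copy = mx.nd.cast_storage(rsp, stype='row_sparse')
    assert rsp_copy.stype == 'row_sparse'

    # An uninitialized (all-zero) input takes the FillZerosRspImpl
    # branch and yields an all-zero output of the same stype.
    zeros = mx.nd.sparse.zeros('row_sparse', (6, 2))
    zeros_copy = mx.nd.cast_storage(zeros, stype='row_sparse')
    assert zeros_copy.asnumpy().sum() == 0
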
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 9417df31748..5ad5215036d 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1177,10 +1177,13 @@ def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad=
         shape_3d = rand_shape_3d()
         check_cast_storage(shape_2d, d, 'csr', 'default')
         check_cast_storage(shape_2d, d, 'default', 'csr')
+        check_cast_storage(shape_2d, d, 'csr', 'csr')
         check_cast_storage(shape_2d, d, 'row_sparse', 'default')
         check_cast_storage(shape_2d, d, 'default', 'row_sparse')
+        check_cast_storage(shape_2d, d, 'row_sparse', 'row_sparse')
         check_cast_storage(shape_3d, d, 'row_sparse', 'default')
         check_cast_storage(shape_3d, d, 'default', 'row_sparse')
+        check_cast_storage(shape_3d, d, 'row_sparse', 'row_sparse')
         for i in range(4, 6):
             shape = rand_shape_nd(i, 5)
             check_cast_storage(shape, d, 'default', 'row_sparse')
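
For context, check_cast_storage (defined earlier in the test file) compares the cast result against a dense reference and can also check numeric gradients. A condensed, hypothetical version of what the new same-stype cases exercise:

    import mxnet as mx
    import numpy as np

    def check_same_stype_cast(shape, stype):
        # Hypothetical condensed check: cast to the same stype and
        # compare dense views of the source and the result.
        dense = mx.nd.array(np.random.rand(*shape))
        src = dense.tostype(stype)
        dst = mx.nd.cast_storage(src, stype=stype)
        assert dst.stype == stype
        assert np.allclose(dst.asnumpy(), src.asnumpy())

    check_same_stype_cast((4, 5), 'csr')
    check_same_stype_cast((4, 5), 'row_sparse')
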


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
