eric-haibin-lin closed pull request #10400: Add support for cast storage on same stypes
URL: https://github.com/apache/incubator-mxnet/pull/10400
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h index 46de10ac9cc..f905bf8f722 100644 --- a/src/operator/tensor/cast_storage-inl.h +++ b/src/operator/tensor/cast_storage-inl.h @@ -30,6 +30,7 @@ #include <algorithm> #include "../mxnet_op.h" #include "../operator_common.h" +#include "../../src/operator/tensor/init_op.h" #ifdef __CUDACC__ #include "./cast_storage-inl.cuh" #endif // __CUDACC__ @@ -328,6 +329,50 @@ void CastStorageCsrDnsImpl(const OpContext& ctx, }); } +/*! + * \brief Casts a csr matrix to another csr. + */ +template <typename xpu> +void CastStorageCsrCsrImpl(const OpContext& ctx, const NDArray& csr, + NDArray* output) { + mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); + if (!csr.storage_initialized()) { + FillZerosCsrImpl(s, *output); + return; + } + std::vector<TShape> aux_shapes({csr.aux_shape(csr::kIndPtr), csr.aux_shape(csr::kIdx)}); + output->CheckAndAlloc(aux_shapes); + const TBlob& val = output->data(); + const TBlob& indptr = output->aux_data(csr::kIndPtr); + const TBlob& idx = output->aux_data(csr::kIdx); + mxnet_op::copy(s, val, csr.data()); + mxnet_op::copy(s, indptr, csr.aux_data(csr::kIndPtr)); + mxnet_op::copy(s, idx, csr.aux_data(csr::kIdx)); +} + +/*! + * \brief Casts a rsp matrix to another rsp. 
+ */ +template <typename xpu> +void CastStorageRspRspImpl(const OpContext& ctx, const NDArray& rsp, + NDArray* output) { + CHECK_EQ(rsp.storage_type(), output->storage_type()) + << "Copying with different storage type"; + mshadow::Stream<xpu>* s = ctx.get_stream<xpu>(); + if (!rsp.storage_initialized()) { + FillZerosRspImpl(s, *output); + return; + } + auto aux_shape = rsp.aux_shape(rowsparse::kIdx); + output->CheckAndAlloc({aux_shape}); + const TBlob& val = output->data(); + const TBlob& idx = output->aux_data(rowsparse::kIdx); + const TBlob& from_val = rsp.data(); + const TBlob& from_idx = rsp.aux_data(rowsparse::kIdx); + mxnet_op::copy(s, val, from_val); + mxnet_op::copy(s, idx, from_idx); +} + template<typename xpu> void CastStorageComputeImpl(const OpContext& ctx, const NDArray& input, @@ -346,6 +391,12 @@ void CastStorageComputeImpl(const OpContext& ctx, } else if (src_stype == kCSRStorage && dst_stype == kDefaultStorage) { TBlob ret = output.data(); CastStorageCsrDnsImpl<xpu>(ctx, input, &ret); + } else if (src_stype == kCSRStorage && dst_stype == kCSRStorage) { + NDArray ret = output; + CastStorageCsrCsrImpl<xpu>(ctx, input, &ret); + } else if (src_stype == kRowSparseStorage && dst_stype == kRowSparseStorage) { + NDArray ret = output; + CastStorageRspRspImpl<xpu>(ctx, input, &ret); #if MXNET_USE_MKLDNN == 1 } else if (src_stype == kDefaultStorage && dst_stype == kDefaultStorage) { CHECK_EQ(output.ctx().dev_type, input.ctx().dev_type); diff --git a/src/operator/tensor/cast_storage.cc b/src/operator/tensor/cast_storage.cc index 9f257b140f7..f77a50a7e70 100644 --- a/src/operator/tensor/cast_storage.cc +++ b/src/operator/tensor/cast_storage.cc @@ -46,6 +46,8 @@ The storage type of ``cast_storage`` output depends on stype parameter: - cast_storage(row_sparse, 'default') = default - cast_storage(default, 'csr') = csr - cast_storage(default, 'row_sparse') = row_sparse +- cast_storage(csr, 'csr') = csr +- cast_storage(row_sparse, 'row_sparse') = row_sparse 
Example:: diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 9417df31748..5ad5215036d 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -1177,10 +1177,13 @@ def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad= shape_3d = rand_shape_3d() check_cast_storage(shape_2d, d, 'csr', 'default') check_cast_storage(shape_2d, d, 'default', 'csr') + check_cast_storage(shape_2d, d, 'csr', 'csr') check_cast_storage(shape_2d, d, 'row_sparse', 'default') check_cast_storage(shape_2d, d, 'default', 'row_sparse') + check_cast_storage(shape_2d, d, 'row_sparse', 'row_sparse') check_cast_storage(shape_3d, d, 'row_sparse', 'default') check_cast_storage(shape_3d, d, 'default', 'row_sparse') + check_cast_storage(shape_3d, d, 'row_sparse', 'row_sparse') for i in range(4, 6): shape = rand_shape_nd(i, 5) check_cast_storage(shape, d, 'default', 'row_sparse') ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services