This is an automated email from the ASF dual-hosted git repository. haibin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new a2c4b0a add guide for implementing sparse ops (#10081) a2c4b0a is described below commit a2c4b0a0ccaec40e22216aa83873a49d7f7506ef Author: Haibin Lin <linhaibin.e...@gmail.com> AuthorDate: Mon Mar 26 16:32:11 2018 -0700 add guide for implementing sparse ops (#10081) --- docs/faq/index.md | 2 + src/operator/contrib/quadratic_op-inl.h | 83 +++++++++++++++++++++++++++ src/operator/contrib/quadratic_op.cc | 9 ++- src/operator/contrib/quadratic_op.cu | 1 + tests/python/unittest/test_sparse_operator.py | 20 +++++++ 5 files changed, 114 insertions(+), 1 deletion(-) diff --git a/docs/faq/index.md b/docs/faq/index.md index 099cd50..098d37f 100644 --- a/docs/faq/index.md +++ b/docs/faq/index.md @@ -56,6 +56,8 @@ and full working examples, visit the [tutorials section](../tutorials/index.md). * [How do I create new operators in MXNet?](http://mxnet.io/faq/new_op.html) +* [How do I implement sparse operators in MXNet backend?](https://cwiki.apache.org/confluence/display/MXNET/A+Guide+to+Implementing+Sparse+Operators+in+MXNet+Backend) + * [How do I contribute an example or tutorial?](https://github.com/apache/incubator-mxnet/tree/master/example#contributing) * [How do I set MXNet's environmental variables?](http://mxnet.io/faq/env_var.html) diff --git a/src/operator/contrib/quadratic_op-inl.h b/src/operator/contrib/quadratic_op-inl.h index 8d73a42..fe47781 100644 --- a/src/operator/contrib/quadratic_op-inl.h +++ b/src/operator/contrib/quadratic_op-inl.h @@ -32,6 +32,7 @@ #include "../mxnet_op.h" #include "../operator_common.h" #include "../elemwise_op_common.h" +#include "../tensor/init_op.h" namespace mxnet { namespace op { @@ -73,6 +74,33 @@ inline bool QuadraticOpType(const nnvm::NodeAttrs& attrs, return out_attrs->at(0) != -1; } +inline bool QuadraticOpStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector<int>* in_attrs, + std::vector<int>* 
out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + const QuadraticParam& param = nnvm::get<QuadraticParam>(attrs.parsed); + const int in_stype = in_attrs->at(0); + int& out_stype = out_attrs->at(0); + bool dispatched = false; + if (!dispatched && in_stype == kDefaultStorage) { + // dns -> dns + dispatched = storage_type_assign(&out_stype, kDefaultStorage, + dispatch_mode, DispatchMode::kFCompute); + } + if (!dispatched && in_stype == kCSRStorage && param.c == 0.0) { + // csr -> csr + dispatched = storage_type_assign(&out_stype, kCSRStorage, + dispatch_mode, DispatchMode::kFComputeEx); + } + if (!dispatched) { + dispatched = dispatch_fallback(out_attrs, dispatch_mode); + } + return dispatched; +} + template<int req> struct quadratic_forward { template<typename DType> @@ -115,6 +143,61 @@ void QuadraticOpForward(const nnvm::NodeAttrs& attrs, } template<typename xpu> +void QuadraticOpForwardCsrImpl(const QuadraticParam& param, + const OpContext& ctx, + const NDArray& input, + const OpReqType req, + const NDArray& output) { + using namespace mshadow; + using namespace mxnet_op; + using namespace csr; + if (req == kNullOp) return; + CHECK_EQ(req, kWriteTo) << "QuadraticOp with CSR only supports kWriteTo"; + Stream<xpu> *s = ctx.get_stream<xpu>(); + if (!input.storage_initialized()) { + FillZerosCsrImpl(s, output); + return; + } + const nnvm::dim_t nnz = input.storage_shape()[0]; + const nnvm::dim_t num_rows = output.shape()[0]; + output.CheckAndAlloc({Shape1(num_rows + 1), Shape1(nnz)}); + CHECK_EQ(output.aux_type(kIdx), output.aux_type(kIndPtr)) + << "The dtypes of indices and indptr don't match"; + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), IType, { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + Kernel<quadratic_forward<req_type>, xpu>::Launch( + s, nnz, output.data().dptr<DType>(), input.data().dptr<DType>(), + param.a, param.b, param.c); + Copy(output.aux_data(kIdx).FlatTo1D<xpu, 
IType>(), + input.aux_data(kIdx).FlatTo1D<xpu, IType>(), s); + Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, IType>(), + input.aux_data(kIndPtr).FlatTo1D<xpu, IType>(), s); + }); + }); + }); +} + +template<typename xpu> +void QuadraticOpForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<NDArray>& inputs, + const std::vector<OpReqType>& req, + const std::vector<NDArray>& outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const QuadraticParam& param = nnvm::get<QuadraticParam>(attrs.parsed); + const auto in_stype = inputs[0].storage_type(); + const auto out_stype = outputs[0].storage_type(); + if (in_stype == kCSRStorage && out_stype == kCSRStorage && param.c == 0.0) { + QuadraticOpForwardCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]); + } else { + LogUnimplementedOp(attrs, ctx, inputs, req, outputs); + } +} + +template<typename xpu> void QuadraticOpBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector<TBlob>& inputs, diff --git a/src/operator/contrib/quadratic_op.cc b/src/operator/contrib/quadratic_op.cc index 5b2d84c..d8b2d78 100644 --- a/src/operator/contrib/quadratic_op.cc +++ b/src/operator/contrib/quadratic_op.cc @@ -38,6 +38,11 @@ Example:: x = [[1, 2], [3, 4]] y = quadratic(data=x, a=1, b=2, c=3) y = [[6, 11], [18, 27]] + +The storage type of ``quadratic`` output depends on storage types of inputs + - quadratic(csr, a, b, 0) = csr + - quadratic(default, a, b, c) = default + )code" ADD_FILELINE) .set_attr_parser(ParamParser<QuadraticParam>) .set_num_inputs(1) @@ -48,6 +53,7 @@ Example:: }) .set_attr<nnvm::FInferShape>("FInferShape", QuadraticOpShape) .set_attr<nnvm::FInferType>("FInferType", QuadraticOpType) +.set_attr<FInferStorageType>("FInferStorageType", QuadraticOpStorageType) .set_attr<FCompute>("FCompute<cpu>", QuadraticOpForward<cpu>) .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_contrib_backward_quadratic"}) 
.set_attr<nnvm::FInplaceOption>("FInplaceOption", @@ -62,7 +68,8 @@ NNVM_REGISTER_OP(_contrib_backward_quadratic) .set_num_inputs(2) .set_num_outputs(1) .set_attr<nnvm::TIsBackward>("TIsBackward", true) -.set_attr<FCompute>("FCompute<cpu>", QuadraticOpBackward<cpu>); +.set_attr<FCompute>("FCompute<cpu>", QuadraticOpBackward<cpu>) +.set_attr<FComputeEx>("FComputeEx<cpu>", QuadraticOpForwardEx<cpu>); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/quadratic_op.cu b/src/operator/contrib/quadratic_op.cu index ede773a..72d15ab 100644 --- a/src/operator/contrib/quadratic_op.cu +++ b/src/operator/contrib/quadratic_op.cu @@ -27,6 +27,7 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_contrib_quadratic) +.set_attr<FComputeEx>("FComputeEx<gpu>", QuadraticOpForwardEx<gpu>) .set_attr<FCompute>("FCompute<gpu>", QuadraticOpForward<gpu>); NNVM_REGISTER_OP(_contrib_backward_quadratic) diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 407f776..9417df3 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -1933,6 +1933,26 @@ def test_sparse_nd_where(): test_where_helper((5, 9)) test_where_numeric_gradient((5, 9)) +@with_seed() +def test_sparse_quadratic_function(): + def f(x, a, b, c): + return a * x**2 + b * x + c + + def check_sparse_quadratic_function(a, b, c, expected_stype): + # check forward and compare the result with dense op + ndim = 2 + shape = rand_shape_nd(ndim, 5) + data = rand_ndarray(shape=shape, stype='csr') + data_np = data.asnumpy() + expected = f(data_np, a, b, c) + output = mx.nd.contrib.quadratic(data, a=a, b=b, c=c) + assert(output.stype == expected_stype) + assert_almost_equal(output.asnumpy(), expected) + + a = np.random.random_sample() + b = np.random.random_sample() + check_sparse_quadratic_function(a, b, 0.0, 'csr') + check_sparse_quadratic_function(a, b, 1.0, 'default') if __name__ == '__main__': import 
nose -- To stop receiving notification emails like this one, please contact haibin@apache.org.