This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a2c4b0a  add guide for implementing sparse ops (#10081)
a2c4b0a is described below

commit a2c4b0a0ccaec40e22216aa83873a49d7f7506ef
Author: Haibin Lin <linhaibin.e...@gmail.com>
AuthorDate: Mon Mar 26 16:32:11 2018 -0700

    add guide for implementing sparse ops (#10081)
---
 docs/faq/index.md                             |  2 +
 src/operator/contrib/quadratic_op-inl.h       | 83 +++++++++++++++++++++++++++
 src/operator/contrib/quadratic_op.cc          |  7 ++
 src/operator/contrib/quadratic_op.cu          |  1 +
 tests/python/unittest/test_sparse_operator.py | 20 +++++++
 5 files changed, 113 insertions(+)

diff --git a/docs/faq/index.md b/docs/faq/index.md
index 099cd50..098d37f 100644
--- a/docs/faq/index.md
+++ b/docs/faq/index.md
@@ -56,6 +56,8 @@ and full working examples, visit the [tutorials section](../tutorials/index.md).
 
 * [How do I create new operators in MXNet?](http://mxnet.io/faq/new_op.html)
 
+* [How do I implement sparse operators in MXNet backend?](https://cwiki.apache.org/confluence/display/MXNET/A+Guide+to+Implementing+Sparse+Operators+in+MXNet+Backend)
+
 * [How do I contribute an example or tutorial?](https://github.com/apache/incubator-mxnet/tree/master/example#contributing)
 
 * [How do I set MXNet's environmental variables?](http://mxnet.io/faq/env_var.html)
diff --git a/src/operator/contrib/quadratic_op-inl.h b/src/operator/contrib/quadratic_op-inl.h
index 8d73a42..fe47781 100644
--- a/src/operator/contrib/quadratic_op-inl.h
+++ b/src/operator/contrib/quadratic_op-inl.h
@@ -32,6 +32,7 @@
 #include "../mxnet_op.h"
 #include "../operator_common.h"
 #include "../elemwise_op_common.h"
+#include "../tensor/init_op.h"
 
 namespace mxnet {
 namespace op {
@@ -73,6 +74,33 @@ inline bool QuadraticOpType(const nnvm::NodeAttrs& attrs,
   return out_attrs->at(0) != -1;
 }
 
+inline bool QuadraticOpStorageType(const nnvm::NodeAttrs& attrs,
+                                   const int dev_mask,
+                                   DispatchMode* dispatch_mode,
+                                   std::vector<int>* in_attrs,
+                                   std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const QuadraticParam& param = nnvm::get<QuadraticParam>(attrs.parsed);
+  const int in_stype = in_attrs->at(0);
+  int& out_stype = out_attrs->at(0);
+  bool dispatched = false;
+  if (!dispatched && in_stype == kDefaultStorage) {
+    // dns -> dns
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched && in_stype == kCSRStorage && param.c == 0.0) {
+    // csr -> csr
+    dispatched = storage_type_assign(&out_stype, kCSRStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  return dispatched;
+}
+
 template<int req>
 struct quadratic_forward {
   template<typename DType>
@@ -115,6 +143,61 @@ void QuadraticOpForward(const nnvm::NodeAttrs& attrs,
 }
 
 template<typename xpu>
+void QuadraticOpForwardCsrImpl(const QuadraticParam& param,
+                               const OpContext& ctx,
+                               const NDArray& input,
+                               const OpReqType req,
+                               const NDArray& output) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteTo) << "QuadraticOp with CSR only supports kWriteTo";
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (!input.storage_initialized()) {
+    FillZerosCsrImpl(s, output);
+    return;
+  }
+  const nnvm::dim_t nnz = input.storage_shape()[0];
+  const nnvm::dim_t num_rows = output.shape()[0];
+  output.CheckAndAlloc({Shape1(num_rows + 1), Shape1(nnz)});
+  CHECK_EQ(output.aux_type(kIdx), output.aux_type(kIndPtr))
+    << "The dtypes of indices and indptr don't match";
+  MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
+    MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+        Kernel<quadratic_forward<req_type>, xpu>::Launch(
+            s, nnz, output.data().dptr<DType>(), input.data().dptr<DType>(),
+            param.a, param.b, param.c);
+        Copy(output.aux_data(kIdx).FlatTo1D<xpu, IType>(s),
+             input.aux_data(kIdx).FlatTo1D<xpu, IType>(s), s);
+        Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, IType>(s),
+             input.aux_data(kIndPtr).FlatTo1D<xpu, IType>(s), s);
+      });
+    });
+  });
+}
+
+template<typename xpu>
+void QuadraticOpForwardEx(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<NDArray>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const QuadraticParam& param = nnvm::get<QuadraticParam>(attrs.parsed);
+  const auto in_stype = inputs[0].storage_type();
+  const auto out_stype = outputs[0].storage_type();
+  if (in_stype == kCSRStorage && out_stype == kCSRStorage && param.c == 0.0) {
+    QuadraticOpForwardCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+}
+
+template<typename xpu>
 void QuadraticOpBackward(const nnvm::NodeAttrs& attrs,
                          const OpContext& ctx,
                          const std::vector<TBlob>& inputs,
diff --git a/src/operator/contrib/quadratic_op.cc b/src/operator/contrib/quadratic_op.cc
index 5b2d84c..d8b2d78 100644
--- a/src/operator/contrib/quadratic_op.cc
+++ b/src/operator/contrib/quadratic_op.cc
@@ -38,6 +38,11 @@ Example::
   x = [[1, 2], [3, 4]]
   y = quadratic(data=x, a=1, b=2, c=3)
   y = [[6, 11], [18, 27]]
+
+The storage type of ``quadratic`` output depends on storage types of inputs
+  - quadratic(csr, a, b, 0) = csr
+  - quadratic(default, a, b, c) = default
+
 )code" ADD_FILELINE)
 .set_attr_parser(ParamParser<QuadraticParam>)
 .set_num_inputs(1)
@@ -48,6 +53,8 @@ Example::
   })
 .set_attr<nnvm::FInferShape>("FInferShape", QuadraticOpShape)
 .set_attr<nnvm::FInferType>("FInferType", QuadraticOpType)
+.set_attr<FInferStorageType>("FInferStorageType", QuadraticOpStorageType)
 .set_attr<FCompute>("FCompute<cpu>", QuadraticOpForward<cpu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", QuadraticOpForwardEx<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_contrib_backward_quadratic"})
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
diff --git a/src/operator/contrib/quadratic_op.cu b/src/operator/contrib/quadratic_op.cu
index ede773a..72d15ab 100644
--- a/src/operator/contrib/quadratic_op.cu
+++ b/src/operator/contrib/quadratic_op.cu
@@ -27,6 +27,7 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_contrib_quadratic)
+.set_attr<FComputeEx>("FComputeEx<gpu>", QuadraticOpForwardEx<gpu>)
 .set_attr<FCompute>("FCompute<gpu>", QuadraticOpForward<gpu>);
 
 NNVM_REGISTER_OP(_contrib_backward_quadratic)
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 407f776..9417df3 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1933,6 +1933,26 @@ def test_sparse_nd_where():
     test_where_helper((5, 9))
     test_where_numeric_gradient((5, 9))
 
+@with_seed()
+def test_sparse_quadratic_function():
+    def f(x, a, b, c):
+        return a * x**2 + b * x + c
+
+    def check_sparse_quadratic_function(a, b, c, expected_stype):
+        # check forward and compare the result with dense op
+        ndim = 2
+        shape = rand_shape_nd(ndim, 5)
+        data = rand_ndarray(shape=shape, stype='csr')
+        data_np = data.asnumpy()
+        expected = f(data_np, a, b, c)
+        output = mx.nd.contrib.quadratic(data, a=a, b=b, c=c)
+        assert(output.stype == expected_stype)
+        assert_almost_equal(output.asnumpy(), expected)
+
+    a = np.random.random_sample()
+    b = np.random.random_sample()
+    check_sparse_quadratic_function(a, b, 0.0, 'csr')
+    check_sparse_quadratic_function(a, b, 1.0, 'default')
 
 if __name__ == '__main__':
     import nose
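
For reference, here is a minimal usage sketch of the behavior this commit
adds, as seen from the Python frontend. It mirrors the new unit test; the
sample matrix and variable names are illustrative only, and it assumes a
build that includes this commit:

    import mxnet as mx
    import numpy as np

    # A small CSR input; zero entries are not stored.
    dense = np.array([[0., 2.], [3., 0.]])
    x = mx.nd.array(dense).tostype('csr')

    # c == 0: quadratic(0) == 0, so storage inference keeps the result
    # sparse and dispatches to the new csr -> csr kernel (FComputeEx);
    # only the stored (non-zero) values are transformed.
    y = mx.nd.contrib.quadratic(x, a=1, b=2, c=0)
    assert y.stype == 'csr'

    # c != 0: every zero entry would become non-zero, so inference falls
    # back to the dense FCompute path and returns default storage.
    y = mx.nd.contrib.quadratic(x, a=1, b=2, c=3)
    assert y.stype == 'default'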

-- 
To stop receiving notification emails like this one, please contact
hai...@apache.org.
