This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f15b1b8  Added the diag() operator (#11643)
f15b1b8 is described below

commit f15b1b88b9f055420ba19bb73e93b229bf03febd
Author: Istvan Fehervari <[email protected]>
AuthorDate: Thu Jul 19 13:26:46 2018 -0700

    Added the diag() operator (#11643)
    
    * Added np.diag as mxnet operator, WIP
    
    Done:
    2d input forward pass
    Missing:
    1d input forward
    all backward
    
    * Added a simple gradient transfer backwards operator for diag
    
    Fixed small typos as well
    
    * Finished backward operation
    
    * Added full support for k
    
    * Finished adding the 1D case to the diag operator
    
    Finished function documentation
    Added unit tests
    
    * Fixed cpplinter errors in the diag operator
    
    Issues were extra white spaces and include order
    
    * Fixed indentation in diag_op-inl.h
    
    * Changed diag operator tests to use np.diag() as comparison
    
    * Fixed kernel bug in gpu diag operator
    
    * Replaced the min operator with an inline if statement.
    
    * Added diag to ndarray and symbol
    
    * Replaced the type of parameter k from int32 to nnvm::dim
    
    * Added default argument to k in ndarray and symbol
    
    * Fixed ndarray and symbol diag calls
    
    * Fixed the optional k parameter
    
    * Fixed cpp linting error
    
    * Changed test data datatype to float32
    
    * k values resulting in 0-sized diagonals will now throw an exception.
    
    Added matching test case
    
    * Fixed unittest
    
    * Added diag to NDArray and Symbol api doc
    
    * Added missing api doc
---
 docs/api/python/ndarray/ndarray.md     |   2 +
 docs/api/python/symbol/symbol.md       |   2 +
 python/mxnet/ndarray/ndarray.py        |   8 ++
 python/mxnet/symbol/symbol.py          |   8 ++
 src/operator/tensor/diag_op-inl.h      | 217 +++++++++++++++++++++++++++++++++
 src/operator/tensor/diag_op.cc         |  93 ++++++++++++++
 src/operator/tensor/diag_op.cu         |  39 ++++++
 tests/python/unittest/test_operator.py |  75 +++++++++++-
 8 files changed, 443 insertions(+), 1 deletion(-)

diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md
index dda5341..8494120 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -131,6 +131,7 @@ The `ndarray` package provides several classes:
     NDArray.flatten
     NDArray.expand_dims
     NDArray.split
+    NDArray.diag
 ```
 
 ### Array expand elements
@@ -364,6 +365,7 @@ The `ndarray` package provides several classes:
     ones_like
     full
     arange
+    diag
     load
     save
 ```
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index 304b178..a59a927 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -182,6 +182,7 @@ Composite multiple symbols into a new one by an operator.
 
     Symbol.zeros_like
     Symbol.ones_like
+    Symbol.diag
 ```
 
 ### Changing shape and type
@@ -381,6 +382,7 @@ Composite multiple symbols into a new one by an operator.
     reshape_like
     flatten
     expand_dims
+    diag
 ```
 
 ### Expanding elements
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 09395e2..ff9aac0 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1302,6 +1302,14 @@ fixed-size items.
         """
         return op.flip(self, *args, **kwargs)
 
+    def diag(self, k=0, **kwargs):
+        """Convenience fluent method for :py:func:`diag`.
+
+        The arguments are the same as for :py:func:`diag`, with
+        this array as data.
+        """
+        return op.diag(self, k, **kwargs)
+
     def sum(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sum`.
 
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index b041f4e..88f92cd 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -2038,6 +2038,14 @@ class Symbol(SymbolBase):
         """
         return op.flip(self, *args, **kwargs)
 
+    def diag(self, k=0, **kwargs):
+        """Convenience fluent method for :py:func:`diag`.
+
+        The arguments are the same as for :py:func:`diag`, with
+        this array as data.
+        """
+        return op.diag(self, k, **kwargs)
+
     def sum(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sum`.
 
diff --git a/src/operator/tensor/diag_op-inl.h b/src/operator/tensor/diag_op-inl.h
new file mode 100644
index 0000000..3bc240f
--- /dev/null
+++ b/src/operator/tensor/diag_op-inl.h
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* Copyright (c) 2015 by Contributors
+* \file diag_op-inl.h
+* \brief CPU Implementation of the diag op
+* \author Istvan Fehervari
+*/
+
+#ifndef MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
+#define MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
+
+#include <dmlc/parameter.h>
+#include <vector>
+#include <algorithm>
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+/*! \brief Parameters of the diag operator. */
+struct DiagParam : public dmlc::Parameter<DiagParam> {
+  // Offset of the selected diagonal; 0 is the main diagonal.
+  dmlc::optional<int> k;
+  DMLC_DECLARE_PARAMETER(DiagParam) {
+    DMLC_DECLARE_FIELD(k)
+    .set_default(dmlc::optional<int>(0))
+    .describe("Diagonal in question. The default is 0. "
+              "Use k>0 for diagonals above the main diagonal, "
+              "and k<0 for diagonals below the main diagonal. "
+              "If the input is 2-D with shape (S0, S1), k must satisfy "
+              "-S0 < k < S1 so that the diagonal is not empty. "
+              "Any k is allowed for 1-D input.");
+  }
+};
+
+/*!
+ * \brief Compute the output shape of diag for input shape `ishape` and offset `k`.
+ *
+ * A 1-D input of length N yields an (N + |k|) x (N + |k|) matrix. A 2-D input
+ * yields the length of its k-th diagonal, clamped at 0 when the diagonal lies
+ * outside the matrix.
+ */
+inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) {
+  if (ishape.ndim() == 1) {
+    const nnvm::dim_t side = ishape[0] + std::abs(k);
+    return TShape({side, side});
+  }
+
+  // Moving the diagonal right (k > 0) drops k columns;
+  // moving it down (k < 0) drops |k| rows.
+  const nnvm::dim_t rows = (k < 0) ? ishape[0] + k : ishape[0];
+  const nnvm::dim_t cols = (k > 0) ? ishape[1] - k : ishape[1];
+
+  const nnvm::dim_t len = std::max<nnvm::dim_t>(std::min(rows, cols), 0);
+  return TShape({len});
+}
+
+/*!
+ * \brief Shape inference for diag: derives the output shape from the input
+ *        shape and parameter k.
+ * \return false while the input shape is unavailable (ndim 0); otherwise
+ *         whether the output shape has been set.
+ */
+inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
+                        std::vector<TShape>* in_attrs,
+                        std::vector<TShape>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  const TShape& ishape = in_attrs->at(0);
+  if (ishape.ndim() == 0) {
+    return false;
+  }
+  if (ishape.ndim() > 2) {
+    LOG(FATAL) << "Input must be 1- or 2-d.";
+  }
+
+  const int k = nnvm::get<DiagParam>(attrs.parsed).k.value();
+  const TShape oshape = DiagShapeImpl(ishape, k);
+  // An empty diagonal (k out of range for a 2-D input) is rejected.
+  if (shape_is_none(oshape)) {
+    LOG(FATAL) << "Diagonal does not exist.";
+  }
+
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+  return out_attrs->at(0).ndim() != 0U;
+}
+
+/*!
+ * \brief Type inference for diag: input and output share one dtype.
+ *
+ * The dtype is propagated in both directions, so whichever side is known
+ * fixes the other. Returns true once the output dtype is set (not -1).
+ */
+inline bool DiagOpType(const nnvm::NodeAttrs& attrs,
+                       std::vector<int> *in_attrs,
+                       std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  std::vector<int>& itypes = *in_attrs;
+  std::vector<int>& otypes = *out_attrs;
+  TYPE_ASSIGN_CHECK(otypes, 0, itypes[0]);
+  TYPE_ASSIGN_CHECK(itypes, 0, otypes[0]);
+  return otypes[0] != -1;
+}
+
+/*!
+ * \brief Kernel that reads the k-th diagonal of a 2-D array.
+ *
+ * Thread i copies diagonal element i of `a` (indexed via ravel with `ishape`)
+ * into out[i], honouring the write request `req`.
+ */
+template<int req>
+struct diag {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
+                                  mshadow::Shape<2> ishape, int k) {
+    using namespace mxnet_op;
+    // Diagonal k starts at (0, k) for k > 0 and at (-k, 0) for k < 0.
+    const int row = (k < 0) ? i - k : i;
+    const int col = (k > 0) ? i + k : i;
+    const int src = ravel(mshadow::Shape2(row, col), ishape);
+    KERNEL_ASSIGN(out[i], req, a[src]);
+  }
+};
+
+/*!
+ * \brief Kernel that builds a 2-D array holding `a` on its k-th diagonal.
+ *
+ * Thread i writes element i of the output (unravelled with `oshape`): the
+ * matching entry of `a` when it lies on diagonal k, zero otherwise.
+ */
+template<int req>
+struct diag_gen {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
+                                  mshadow::Shape<2> oshape, int k) {
+    using namespace mxnet_op;
+
+    const auto pos = unravel(i, oshape);
+    if (pos[1] != (pos[0] + k)) {
+      KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
+      return;
+    }
+    // On the diagonal, the index along it is the smaller coordinate.
+    const auto idx = (pos[0] < pos[1]) ? pos[0] : pos[1];
+    KERNEL_ASSIGN(out[i], req, a[idx]);
+  }
+};
+
+/*!
+ * \brief Forward pass of diag.
+ *
+ * A 2-D input has its k-th diagonal extracted into the 1-D output. A 1-D
+ * input is placed on diagonal k of the 2-D output, with zeros elsewhere.
+ */
+template<typename xpu>
+void DiagOpForward(const nnvm::NodeAttrs& attrs,
+                   const OpContext& ctx,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<OpReqType>& req,
+                   const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  // No restriction on req[0]: the kernels are templated on the request type
+  // and MXNET_ASSIGN_REQ_SWITCH dispatches it below.
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  const TShape& ishape = in_data.shape_;
+  const TShape& oshape = out_data.shape_;
+  const int k = nnvm::get<DiagParam>(attrs.parsed).k.value();
+
+  MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      if (ishape.ndim() == 2) {
+        // 2-D input: extract diagonal k into the 1-D output.
+        Kernel<diag<req_type>, xpu>::Launch(
+            s, out_data.Size(), out_data.dptr<DType>(), in_data.dptr<DType>(),
+            Shape2(ishape[0], ishape[1]), k);
+      } else {
+        // 1-D input: build the 2-D output with the input on diagonal k.
+        Kernel<diag_gen<req_type>, xpu>::Launch(
+            s, out_data.Size(), out_data.dptr<DType>(), in_data.dptr<DType>(),
+            Shape2(oshape[0], oshape[1]), k);
+      }
+    });
+  });
+}
+
+/*!
+ * \brief Backward pass of diag.
+ *
+ * inputs[0] is the gradient of the forward output. outputs[0] receives the
+ * gradient of the forward input.
+ */
+template<typename xpu>
+void DiagOpBackward(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  const TBlob& ograd = inputs[0];
+  const TBlob& igrad = outputs[0];
+  const TShape& ograd_shape = ograd.shape_;
+  const TShape& igrad_shape = igrad.shape_;
+  const int k = nnvm::get<DiagParam>(attrs.parsed).k.value();
+
+  MSHADOW_TYPE_SWITCH(igrad.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      if (igrad_shape.ndim() == 2) {
+        // Forward extracted a diagonal: scatter the 1-D gradient onto
+        // diagonal k of the 2-D input gradient, with zeros elsewhere.
+        Kernel<diag_gen<req_type>, xpu>::Launch(
+            s, igrad.Size(), igrad.dptr<DType>(), ograd.dptr<DType>(),
+            Shape2(igrad_shape[0], igrad_shape[1]), k);
+      } else {
+        // Forward built a diagonal matrix: the input gradient is diagonal k
+        // of the 2-D output gradient.
+        Kernel<diag<req_type>, xpu>::Launch(
+            s, igrad.Size(), igrad.dptr<DType>(), ograd.dptr<DType>(),
+            Shape2(ograd_shape[0], ograd_shape[1]), k);
+      }
+    });
+  });
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
diff --git a/src/operator/tensor/diag_op.cc b/src/operator/tensor/diag_op.cc
new file mode 100644
index 0000000..1ad3b8a
--- /dev/null
+++ b/src/operator/tensor/diag_op.cc
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* Copyright (c) 2015 by Contributors
+* \file diag_op.cc
+* \brief
+* \author Istvan Fehervari
+*/
+
+#include "./diag_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// Register DiagParam so the "k" attribute can be parsed by ParamParser.
+DMLC_REGISTER_PARAMETER(DiagParam);
+
+NNVM_REGISTER_OP(diag)
+.describe(R"code(Extracts a diagonal or constructs a diagonal array.
+
+``diag``'s behavior depends on the input array dimensions:
+
+- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero.
+  For an input of length N the result has shape (N + abs(k), N + abs(k)).
+- 2-D arrays: returns elements in the diagonal as a new 1-D array.
+  For an input of shape (S0, S1), k must satisfy -S0 < k < S1.
+- N-D arrays: not supported yet
+
+Examples::
+
+  x = [[1, 2, 3],
+       [4, 5, 6]]
+
+  diag(x) = [1, 5]
+
+  diag(x, k=1) = [2, 6]
+
+  diag(x, k=-1) = [4]
+
+  x = [1, 2, 3]
+
+  diag(x) = [[1, 0, 0],
+             [0, 2, 0],
+             [0, 0, 3]]
+
+  diag(x, k=1) = [[0, 1, 0, 0],
+                  [0, 0, 2, 0],
+                  [0, 0, 0, 3],
+                  [0, 0, 0, 0]]
+
+  diag(x, k=-1) = [[0, 0, 0, 0],
+                   [1, 0, 0, 0],
+                   [0, 2, 0, 0],
+                   [0, 0, 3, 0]]
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<DiagParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", DiagOpShape)
+.set_attr<nnvm::FInferType>("FInferType", DiagOpType)
+.set_attr<FCompute>("FCompute<cpu>", DiagOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_diag"})
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(DiagParam::__FIELDS__());
+
+
+// Backward operator of diag. It parses the same DiagParam as the forward op,
+// so it sees the same k.
+NNVM_REGISTER_OP(_backward_diag)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<DiagParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", DiagOpBackward<cpu>);
+
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/tensor/diag_op.cu b/src/operator/tensor/diag_op.cu
new file mode 100644
index 0000000..a3928f7
--- /dev/null
+++ b/src/operator/tensor/diag_op.cu
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* Copyright (c) 2015 by Contributors
+* \file diag_op.cu
+* \brief GPU Implementation of the diag op
+* \author Istvan Fehervari
+*/
+
+#include "./diag_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+// GPU compute functions. The remaining attributes of both operators are
+// registered in diag_op.cc.
+NNVM_REGISTER_OP(_backward_diag)
+.set_attr<FCompute>("FCompute<gpu>", DiagOpBackward<gpu>);
+
+NNVM_REGISTER_OP(diag)
+.set_attr<FCompute>("FCompute<gpu>", DiagOpForward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index c870709..0592e5a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -27,7 +27,7 @@ from distutils.version import LooseVersion
 from numpy.testing import assert_allclose, assert_array_equal
 from mxnet.test_utils import *
 from mxnet.base import py_str, MXNetError, _as_list
-from common import setup_module, with_seed, teardown, 
assert_raises_cudnn_disabled
+from common import setup_module, with_seed, teardown, 
assert_raises_cudnn_disabled, assertRaises
 import unittest
 
 def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req):
@@ -7045,6 +7045,79 @@ def test_op_roi_align():
     test_roi_align_value(2)
     test_roi_align_autograd()
 
+@with_seed()
+def test_diag():
+
+    # Test 2d input
+    h = np.random.randint(2,9)
+    w = np.random.randint(2,9)
+    a_np = np.random.random((h, w)).astype(np.float32)
+    a = mx.nd.array(a_np).astype('float32')
+    
+    # k == 0
+    r = mx.nd.diag(a)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np))
+
+    # k == 1
+    k = 1
+    r = mx.nd.diag(a, k=k)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+    # k == -1
+    k = -1
+    r = mx.nd.diag(a, k=k)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+    # random k
+    k = np.random.randint(-min(h,w) + 1, min(h,w))
+    r = mx.nd.diag(a, k=k)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+    # invalid k
+    k = max(h,w) + 1
+    assertRaises(MXNetError, mx.nd.diag, a, k=k)
+
+    # Test 2d backward, k=0
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 2d backward, k=1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=1)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 2d backward, k=-1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=-1)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # test 1d input
+    d = np.random.randint(2,9)
+    a_np = np.random.random((d))
+    a = mx.nd.array(a_np)
+    
+    # k is random
+    k = np.random.randint(-d,d)
+    r = mx.nd.diag(a, k=k)
+
+    assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+    # Test 2d backward, k=0
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 2d backward, k=1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=1)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 2d backward, k=-1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=-1)
+    check_numeric_gradient(diag_sym, [a_np])
+
 
 if __name__ == '__main__':
     import nose

Reply via email to