eric-haibin-lin closed pull request #11643: Added the diag() operator
URL: https://github.com/apache/incubator-mxnet/pull/11643
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

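For reference, a minimal usage sketch of the operator this PR adds (names are
taken from the diff below; this assumes an MXNet build that includes the
change):

  import mxnet as mx

  x = mx.nd.array([[1, 2, 3],
                   [4, 5, 6]])
  mx.nd.diag(x)       # main diagonal of a 2-D input -> [1, 5]
  mx.nd.diag(x, k=1)  # diagonal above the main one -> [2, 6]
  x.diag(k=-1)        # fluent form added to NDArray -> [4]

  v = mx.nd.array([1, 2, 3])
  mx.nd.diag(v)       # 1-D input is expanded into a 3x3 diagonal matrix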

diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md
index dda534151a1..849412021e1 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -131,6 +131,7 @@ The `ndarray` package provides several classes:
     NDArray.flatten
     NDArray.expand_dims
     NDArray.split
+    NDArray.diag
 ```
 
 ### Array expand elements
@@ -364,6 +365,7 @@ The `ndarray` package provides several classes:
     ones_like
     full
     arange
+    diag
     load
     save
 ```
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index 304b17803ed..a59a92745c7 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -182,6 +182,7 @@ Composite multiple symbols into a new one by an operator.
 
     Symbol.zeros_like
     Symbol.ones_like
+    Symbol.diag
 ```
 
 ### Changing shape and type
@@ -381,6 +382,7 @@ Composite multiple symbols into a new one by an operator.
     reshape_like
     flatten
     expand_dims
+    diag
 ```
 
 ### Expanding elements
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 09395e2ec82..ff9aac05c7c 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1302,6 +1302,14 @@ def flip(self, *args, **kwargs):
         """
         return op.flip(self, *args, **kwargs)
 
+    def diag(self, k=0, **kwargs):
+        """Convenience fluent method for :py:func:`diag`.
+
+        The arguments are the same as for :py:func:`diag`, with
+        this array as data.
+        """
+        return op.diag(self, k, **kwargs)
+
     def sum(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sum`.
 
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index b041f4ef646..88f92cde0fe 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -2038,6 +2038,14 @@ def flip(self, *args, **kwargs):
         """
         return op.flip(self, *args, **kwargs)
 
+    def diag(self, k=0, **kwargs):
+        """Convenience fluent method for :py:func:`diag`.
+
+        The arguments are the same as for :py:func:`diag`, with
+        this array as data.
+        """
+        return op.diag(self, k, **kwargs)
+
     def sum(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sum`.
 
diff --git a/src/operator/tensor/diag_op-inl.h b/src/operator/tensor/diag_op-inl.h
new file mode 100644
index 00000000000..3bc240f206b
--- /dev/null
+++ b/src/operator/tensor/diag_op-inl.h
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* Copyright (c) 2015 by Contributors
+* \file diag_op-inl.h
+* \brief CPU Implementation of the diag op
+* \author Istvan Fehervari
+*/
+
+#ifndef MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
+#define MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
+
+#include <dmlc/parameter.h>
+#include <vector>
+#include <algorithm>
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct DiagParam : public dmlc::Parameter<DiagParam> {
+  dmlc::optional<int> k;
+  DMLC_DECLARE_PARAMETER(DiagParam) {
+    DMLC_DECLARE_FIELD(k)
+    .set_default(dmlc::optional<int>(0))
+    .describe("Diagonal in question. The default is 0. "
+              "Use k>0 for diagonals above the main diagonal, "
+              "and k<0 for diagonals below the main diagonal. "
+              "If input has shape (S0 S1) k must be between -S0 and S1");
+  }
+};
+
+inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) {
+  if (ishape.ndim() == 1) {
+    auto s = ishape[0] + std::abs(k);
+    return TShape({s, s});
+  }
+
+  auto h = ishape[0];
+  auto w = ishape[1];
+
+  if (k > 0) {
+    w -= k;
+  } else if (k < 0) {
+    h += k;
+  }
+
+  auto s = std::min(h, w);
+  if (s < 0) {
+    s = 0;
+  }
+
+  return TShape({s});
+}
+
+inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
+                             std::vector<TShape>* in_attrs,
+                             std::vector<TShape>* out_attrs) {
+    CHECK_EQ(in_attrs->size(), 1U);
+    CHECK_EQ(out_attrs->size(), 1U);
+
+    const TShape& ishape = (*in_attrs)[0];
+    if (ishape.ndim() == 0) return false;
+    if (ishape.ndim() > 2) LOG(FATAL) << "Input must be 1- or 2-d.";
+
+    const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
+
+    TShape oshape = DiagShapeImpl(ishape, param.k.value());
+    if (shape_is_none(oshape)) {
+      LOG(FATAL) << "Diagonal does not exist.";
+    }
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+
+    return out_attrs->at(0).ndim() != 0U;
+}
+
+inline bool DiagOpType(const nnvm::NodeAttrs& attrs,
+                       std::vector<int> *in_attrs,
+                       std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]);
+  return (*out_attrs)[0] != -1;
+}
+
+template<int req>
+struct diag {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
+                                  mshadow::Shape<2> ishape, int k) {
+    using namespace mxnet_op;
+    int j = 0;
+    if (k > 0) {
+      j = ravel(mshadow::Shape2(i, i + k), ishape);
+    } else if (k < 0) {
+      j = ravel(mshadow::Shape2(i - k, i), ishape);
+    } else {
+      j = ravel(mshadow::Shape2(i, i), ishape);
+    }
+
+    KERNEL_ASSIGN(out[i], req, a[j]);
+  }
+};
+
+template<int req>
+struct diag_gen {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
+                                  mshadow::Shape<2> oshape, int k) {
+    using namespace mxnet_op;
+
+    auto j = unravel(i, oshape);
+    if (j[1] == (j[0] + k)) {
+      auto l = j[0] < j[1] ? j[0] : j[1];
+      KERNEL_ASSIGN(out[i], req, a[l]);
+    } else {
+      KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
+    }
+  }
+};
+
+template<typename xpu>
+void DiagOpForward(const nnvm::NodeAttrs& attrs,
+                   const OpContext& ctx,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<OpReqType>& req,
+                   const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(req[0], kWriteTo);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  const TShape& ishape = inputs[0].shape_;
+  const TShape& oshape = outputs[0].shape_;
+  const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
+
+  if (ishape.ndim() == 2) {
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
+                            in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
+      });
+    });
+  } else {
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
+                            in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
+      });
+    });
+  }
+}
+
+template<typename xpu>
+void DiagOpBackward(const nnvm::NodeAttrs& attrs,
+                  const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  const TShape& ishape = inputs[0].shape_;
+  const TShape& oshape = outputs[0].shape_;
+  const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
+
+  if (oshape.ndim() == 2) {
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
+                            in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
+      });
+    });
+  } else {
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
+                            in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
+      });
+    });
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
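The shape logic in DiagShapeImpl above mirrors numpy.diag: a 1-D input of
length n produces an (n + |k|, n + |k|) output, while a 2-D (h, w) input
produces a diagonal of length min(h, w - k) for k > 0 and min(h + k, w) for
k < 0, clamped at zero. A quick NumPy restatement of the same rules
(illustrative only, not part of the patch):

  import numpy as np

  def diag_shape(ishape, k=0):
      # Mirrors DiagShapeImpl from diag_op-inl.h.
      if len(ishape) == 1:
          s = ishape[0] + abs(k)  # 1-D: embed into a square matrix
          return (s, s)
      h, w = ishape
      if k > 0:
          w -= k                  # diagonals above the main one lose width
      elif k < 0:
          h += k                  # diagonals below lose height
      return (max(min(h, w), 0),)

  assert diag_shape((3,), k=1) == (4, 4)
  assert diag_shape((2, 3), k=1) == np.diag(np.zeros((2, 3)), k=1).shape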
diff --git a/src/operator/tensor/diag_op.cc b/src/operator/tensor/diag_op.cc
new file mode 100644
index 00000000000..1ad3b8adc02
--- /dev/null
+++ b/src/operator/tensor/diag_op.cc
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* Copyright (c) 2015 by Contributors
+* \file diag_op.cc
+* \brief CPU registration of the diag op
+* \author Istvan Fehervari
+*/
+
+#include "./diag_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(DiagParam);
+
+NNVM_REGISTER_OP(diag)
+.describe(R"code(Extracts a diagonal or constructs a diagonal array.
+
+``diag``'s behavior depends on the input array dimensions:
+
+- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero
+- 2-D arrays: returns elements in the diagonal as a new 1-D array
+- N-D arrays: not supported yet
+
+Examples::
+
+  x = [[1, 2, 3],
+       [4, 5, 6]]
+
+  diag(x) = [1, 5]
+
+  diag(x, k=1) = [2, 6]
+
+  diag(x, k=-1) = [4]
+
+  x = [1, 2, 3]
+
+  diag(x) = [[1, 0, 0],
+             [0, 2, 0],
+             [0, 0, 3]]
+
+  diag(x, k=1) = [[0, 1, 0],
+                  [0, 0, 2],
+                  [0, 0, 0]]
+
+  diag(x, k=-1) = [[0, 0, 0],
+                   [1, 0, 0],
+                   [0, 2, 0]]
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<DiagParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", DiagOpShape)
+.set_attr<nnvm::FInferType>("FInferType", DiagOpType)
+.set_attr<FCompute>("FCompute<cpu>", DiagOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_diag"})
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(DiagParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(_backward_diag)
+.set_attr_parser(ParamParser<DiagParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", DiagOpBackward<cpu>);
+
+
+}  // namespace op
+}  // namespace mxnet
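Note how the backward pass is wired up: _backward_diag reuses the forward
kernels with their roles swapped, so the gradient of extracting the k-th
diagonal scatters the incoming gradient back onto that diagonal (and vice
versa for the 1-D case). A small autograd sketch of the resulting gradient
(again assuming a build that includes this PR):

  import mxnet as mx

  x = mx.nd.array([[1., 2., 3.],
                   [4., 5., 6.]])
  x.attach_grad()
  with mx.autograd.record():
      y = mx.nd.diag(x, k=1)  # [2., 6.]
  y.backward()                # head gradient defaults to ones
  print(x.grad)               # ones on diagonal k=1, zeros elsewhere:
                              # [[0., 1., 0.],
                              #  [0., 0., 1.]]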
diff --git a/src/operator/tensor/diag_op.cu b/src/operator/tensor/diag_op.cu
new file mode 100644
index 00000000000..a3928f76386
--- /dev/null
+++ b/src/operator/tensor/diag_op.cu
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* Copyright (c) 2015 by Contributors
+* \file diag_op.cu
+* \brief GPU Implementation of the diag op
+* \author Istvan Fehervari
+*/
+
+#include "./diag_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(diag)
+.set_attr<FCompute>("FCompute<gpu>", DiagOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_diag)
+.set_attr<FCompute>("FCompute<gpu>", DiagOpBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
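Because the kernels in diag_op-inl.h are written against the device-generic
Kernel<..., xpu>::Launch interface, the GPU side needs only these two
registrations; the same templated code is compiled for both devices. Usage is
unchanged apart from the context (sketch, assuming a CUDA-enabled build):

  import mxnet as mx

  a = mx.nd.array([[1, 2], [3, 4]], ctx=mx.gpu(0))
  print(mx.nd.diag(a))  # [1., 4.], computed on the GPU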
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 814266ad9aa..a763037409a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -27,7 +27,7 @@
 from numpy.testing import assert_allclose, assert_array_equal
 from mxnet.test_utils import *
 from mxnet.base import py_str, MXNetError, _as_list
-from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled
+from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled, assertRaises
 import unittest
 
 def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req):
@@ -7033,6 +7033,79 @@ def test_roi_align_autograd(sampling_ratio=0):
     test_roi_align_value(2)
     test_roi_align_autograd()
 
+@with_seed()
+def test_diag():
+
+    # Test 2d input
+    h = np.random.randint(2,9)
+    w = np.random.randint(2,9)
+    a_np = np.random.random((h, w)).astype(np.float32)
+    a = mx.nd.array(a_np).astype('float32')
+
+    # k == 0
+    r = mx.nd.diag(a)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np))
+
+    # k == 1
+    k = 1
+    r = mx.nd.diag(a, k=k)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+    # k == -1
+    k = -1
+    r = mx.nd.diag(a, k=k)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+    # random k
+    k = np.random.randint(-min(h,w) + 1, min(h,w))
+    r = mx.nd.diag(a, k=k)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+    # invalid k
+    k = max(h,w) + 1
+    assertRaises(MXNetError, mx.nd.diag, a, k=k)
+
+    # Test 2d backward, k=0
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 2d backward, k=1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=1)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 2d backward, k=-1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=-1)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 1d input
+    d = np.random.randint(2,9)
+    a_np = np.random.random((d))
+    a = mx.nd.array(a_np)
+
+    # k is random
+    k = np.random.randint(-d,d)
+    r = mx.nd.diag(a, k=k)
+
+    assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+    # Test 1d backward, k=0
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 1d backward, k=1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=1)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 1d backward, k=-1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=-1)
+    check_numeric_gradient(diag_sym, [a_np])
+
 
 if __name__ == '__main__':
     import nose
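To exercise just the new test, something like the following should work (a
sketch based on the repository's nose-based test layout):

  import nose

  # argv[0] is the program name; argv[1] selects the single test to run
  nose.run(argv=['nosetests', 'tests/python/unittest/test_operator.py:test_diag'])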


 
