This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new f15b1b8 Added the diag() operator (#11643)
f15b1b8 is described below
commit f15b1b88b9f055420ba19bb73e93b229bf03febd
Author: Istvan Fehervari <[email protected]>
AuthorDate: Thu Jul 19 13:26:46 2018 -0700
Added the diag() operator (#11643)
* Added np.diag as mxnet operator, WIP
Done:
2d input forward pass
Missing:
1d input forward
all backward
* Added a simple gradient transfer backwards operator for diag
Fixed small typos as well
* Finished backward operation
* Added full support for k
* Finished added the 1D case to the diag operator
Finished function documentation
Added unit tests
* Fixed cpplinter errors in the diag operator
Issues were extra white spaces and include order
* Fixed indentation in diag_op-inl.h
* Changed diag operator tests to use np.diag() as comparison
* Fixed kernel bug in gpu diag operator
* Replaced the min operator with an inline if statement.
* Added diag to ndarray and symbol
* Replaced the type of parameter k from int32 to nnvm::dim
* Added default argument to k in ndarray and symbol
* Fixed ndarray and symbol diag calls
* Fixed the optional k parameter
* Fixed cpp linting error
* Changed test data datatype to float32
* K values resulting into 0-sized diagonals will now throw an exception.
Added matching test case
* Fixed unittest
* Added diag to NDArray and Symbol api doc
* Added missing api doc
---
docs/api/python/ndarray/ndarray.md | 2 +
docs/api/python/symbol/symbol.md | 2 +
python/mxnet/ndarray/ndarray.py | 8 ++
python/mxnet/symbol/symbol.py | 8 ++
src/operator/tensor/diag_op-inl.h | 217 +++++++++++++++++++++++++++++++++
src/operator/tensor/diag_op.cc | 93 ++++++++++++++
src/operator/tensor/diag_op.cu | 39 ++++++
tests/python/unittest/test_operator.py | 75 +++++++++++-
8 files changed, 443 insertions(+), 1 deletion(-)
diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md
index dda5341..8494120 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -131,6 +131,7 @@ The `ndarray` package provides several classes:
NDArray.flatten
NDArray.expand_dims
NDArray.split
+ NDArray.diag
```
### Array expand elements
@@ -364,6 +365,7 @@ The `ndarray` package provides several classes:
ones_like
full
arange
+ diag
load
save
```
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index 304b178..a59a927 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -182,6 +182,7 @@ Composite multiple symbols into a new one by an operator.
Symbol.zeros_like
Symbol.ones_like
+ Symbol.diag
```
### Changing shape and type
@@ -381,6 +382,7 @@ Composite multiple symbols into a new one by an operator.
reshape_like
flatten
expand_dims
+ diag
```
### Expanding elements
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 09395e2..ff9aac0 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1302,6 +1302,14 @@ fixed-size items.
"""
return op.flip(self, *args, **kwargs)
+ def diag(self, k=0, **kwargs):
+ """Convenience fluent method for :py:func:`diag`.
+
+ The arguments are the same as for :py:func:`diag`, with
+ this array as data.
+ """
+ return op.diag(self, k, **kwargs)
+
def sum(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sum`.
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index b041f4e..88f92cd 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -2038,6 +2038,14 @@ class Symbol(SymbolBase):
"""
return op.flip(self, *args, **kwargs)
+ def diag(self, k=0, **kwargs):
+ """Convenience fluent method for :py:func:`diag`.
+
+ The arguments are the same as for :py:func:`diag`, with
+ this array as data.
+ """
+ return op.diag(self, k, **kwargs)
+
def sum(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sum`.
diff --git a/src/operator/tensor/diag_op-inl.h b/src/operator/tensor/diag_op-inl.h
new file mode 100644
index 0000000..3bc240f
--- /dev/null
+++ b/src/operator/tensor/diag_op-inl.h
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* Copyright (c) 2015 by Contributors
+* \file diag_op-inl.h
+* \brief CPU Implementation of the diag op
+* \author Istvan Fehervari
+*/
+
+#ifndef MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
+#define MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
+
+#include <dmlc/parameter.h>
+#include <vector>
+#include <algorithm>
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct DiagParam : public dmlc::Parameter<DiagParam> {
+ dmlc::optional<int> k;
+ DMLC_DECLARE_PARAMETER(DiagParam) {
+ DMLC_DECLARE_FIELD(k)
+ .set_default(dmlc::optional<int>(0))
+ .describe("Diagonal in question. The default is 0. "
+ "Use k>0 for diagonals above the main diagonal, "
+ "and k<0 for diagonals below the main diagonal. "
+ "If input has shape (S0 S1) k must be between -S0 and S1");
+ }
+};
+
+inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) {
+ if (ishape.ndim() == 1) {
+ auto s = ishape[0] + std::abs(k);
+ return TShape({s, s});
+ }
+
+ auto h = ishape[0];
+ auto w = ishape[1];
+
+ if (k > 0) {
+ w -= k;
+ } else if (k < 0) {
+ h += k;
+ }
+
+ auto s = std::min(h, w);
+ if (s < 0) {
+ s = 0;
+ }
+
+ return TShape({s});
+}
+
+inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
+ std::vector<TShape>* in_attrs,
+ std::vector<TShape>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1U);
+ CHECK_EQ(out_attrs->size(), 1U);
+
+ const TShape& ishape = (*in_attrs)[0];
+ if (ishape.ndim() == 0) return false;
+ if (ishape.ndim() > 2) LOG(FATAL) << "Input must be 1- or 2-d.";
+
+ const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
+
+ TShape oshape = DiagShapeImpl(ishape, param.k.value());
+ if (shape_is_none(oshape)) {
+ LOG(FATAL) << "Diagonal does not exist.";
+ }
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+
+ return out_attrs->at(0).ndim() != 0U;
+}
+
+inline bool DiagOpType(const nnvm::NodeAttrs& attrs,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1U);
+ CHECK_EQ(out_attrs->size(), 1U);
+
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
+ TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]);
+ return (*out_attrs)[0] != -1;
+}
+
+template<int req>
+struct diag {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
+ mshadow::Shape<2> ishape, int k) {
+ using namespace mxnet_op;
+ int j = 0;
+ if (k > 0) {
+ j = ravel(mshadow::Shape2(i, i + k), ishape);
+ } else if (k < 0) {
+ j = ravel(mshadow::Shape2(i - k, i), ishape);
+ } else {
+ j = ravel(mshadow::Shape2(i, i), ishape);
+ }
+
+ KERNEL_ASSIGN(out[i], req, a[j]);
+ }
+};
+
+template<int req>
+struct diag_gen {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
+ mshadow::Shape<2> oshape, int k) {
+ using namespace mxnet_op;
+
+ auto j = unravel(i, oshape);
+ if (j[1] == (j[0] + k)) {
+ auto l = j[0] < j[1] ? j[0] : j[1];
+ KERNEL_ASSIGN(out[i], req, a[l]);
+ } else {
+ KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
+ }
+ }
+};
+
+template<typename xpu>
+void DiagOpForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mxnet_op;
+ using namespace mshadow;
+ CHECK_EQ(inputs.size(), 1U);
+ CHECK_EQ(outputs.size(), 1U);
+ CHECK_EQ(req.size(), 1U);
+ CHECK_EQ(req[0], kWriteTo);
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+ const TBlob& in_data = inputs[0];
+ const TBlob& out_data = outputs[0];
+ const TShape& ishape = inputs[0].shape_;
+ const TShape& oshape = outputs[0].shape_;
+ const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
+
+ if (ishape.ndim() == 2) {
+ MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+ MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+ Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
+ in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
+ });
+ });
+ } else {
+ MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+ MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+ Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
+ in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
+ });
+ });
+ }
+}
+
+template<typename xpu>
+void DiagOpBackward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mxnet_op;
+ using namespace mshadow;
+ CHECK_EQ(inputs.size(), 1U);
+ CHECK_EQ(outputs.size(), 1U);
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+
+ const TBlob& in_data = inputs[0];
+ const TBlob& out_data = outputs[0];
+ const TShape& ishape = inputs[0].shape_;
+ const TShape& oshape = outputs[0].shape_;
+ const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
+
+ if (oshape.ndim() == 2) {
+ MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+ MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+ Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
+ in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
+ });
+ });
+ } else {
+ MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+ MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+ Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
+ in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
+ });
+ });
+ }
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
diff --git a/src/operator/tensor/diag_op.cc b/src/operator/tensor/diag_op.cc
new file mode 100644
index 0000000..1ad3b8a
--- /dev/null
+++ b/src/operator/tensor/diag_op.cc
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* Copyright (c) 2015 by Contributors
+* \file diag_op.cc
+* \brief
+* \author Istvan Fehervari
+*/
+
+#include "./diag_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(DiagParam);
+
+NNVM_REGISTER_OP(diag)
+.describe(R"code(Extracts a diagonal or constructs a diagonal array.
+
+``diag``'s behavior depends on the input array dimensions:
+
+- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero
+- 2-D arrays: returns elements in the diagonal as a new 1-D array
+- N-D arrays: not supported yet
+
+Examples::
+
+ x = [[1, 2, 3],
+ [4, 5, 6]]
+
+ diag(x) = [1, 5]
+
+ diag(x, k=1) = [2, 6]
+
+ diag(x, k=-1) = [4]
+
+ x = [1, 2, 3]
+
+ diag(x) = [[1, 0, 0],
+ [0, 2, 0],
+ [0, 0, 3]]
+
+ diag(x, k=1) = [[0, 1, 0],
+ [0, 0, 2],
+ [0, 0, 0]]
+
+ diag(x, k=-1) = [[0, 0, 0],
+ [1, 0, 0],
+ [0, 2, 0]]
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<DiagParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"data"};
+ })
+.set_attr<nnvm::FInferShape>("FInferShape", DiagOpShape)
+.set_attr<nnvm::FInferType>("FInferType", DiagOpType)
+.set_attr<FCompute>("FCompute<cpu>", DiagOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_diag"})
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(DiagParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(_backward_diag)
+.set_attr_parser(ParamParser<DiagParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", DiagOpBackward<cpu>);
+
+
+} // namespace op
+} // namespace mxnet
diff --git a/src/operator/tensor/diag_op.cu b/src/operator/tensor/diag_op.cu
new file mode 100644
index 0000000..a3928f7
--- /dev/null
+++ b/src/operator/tensor/diag_op.cu
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* Copyright (c) 2015 by Contributors
+* \file diag_op.cu
+* \brief GPU Implementation of the diag op
+* \author Istvan Fehervari
+*/
+
+#include "./diag_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(diag)
+.set_attr<FCompute>("FCompute<gpu>", DiagOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_diag)
+.set_attr<FCompute>("FCompute<gpu>", DiagOpBackward<gpu>);
+
+} // namespace op
+} // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index c870709..0592e5a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -27,7 +27,7 @@ from distutils.version import LooseVersion
from numpy.testing import assert_allclose, assert_array_equal
from mxnet.test_utils import *
from mxnet.base import py_str, MXNetError, _as_list
-from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled
+from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled, assertRaises
import unittest
def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req):
@@ -7045,6 +7045,79 @@ def test_op_roi_align():
test_roi_align_value(2)
test_roi_align_autograd()
+@with_seed()
+def test_diag():
+
+ # Test 2d input
+ h = np.random.randint(2,9)
+ w = np.random.randint(2,9)
+ a_np = np.random.random((h, w)).astype(np.float32)
+ a = mx.nd.array(a_np).astype('float32')
+
+ # k == 0
+ r = mx.nd.diag(a)
+ assert_almost_equal(r.asnumpy(), np.diag(a_np))
+
+ # k == 1
+ k = 1
+ r = mx.nd.diag(a, k=k)
+ assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+ # k == -1
+ k = -1
+ r = mx.nd.diag(a, k=k)
+ assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+ # random k
+ k = np.random.randint(-min(h,w) + 1, min(h,w))
+ r = mx.nd.diag(a, k=k)
+ assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+ # invalid k
+ k = max(h,w) + 1
+ assertRaises(MXNetError, mx.nd.diag, a, k=k)
+
+ # Test 2d backward, k=0
+ data = mx.sym.Variable('data')
+ diag_sym = mx.sym.diag(data=data)
+ check_numeric_gradient(diag_sym, [a_np])
+
+ # Test 2d backward, k=1
+ data = mx.sym.Variable('data')
+ diag_sym = mx.sym.diag(data=data, k=1)
+ check_numeric_gradient(diag_sym, [a_np])
+
+ # Test 2d backward, k=-1
+ data = mx.sym.Variable('data')
+ diag_sym = mx.sym.diag(data=data, k=-1)
+ check_numeric_gradient(diag_sym, [a_np])
+
+ # test 1d input
+ d = np.random.randint(2,9)
+ a_np = np.random.random((d))
+ a = mx.nd.array(a_np)
+
+ # k is random
+ k = np.random.randint(-d,d)
+ r = mx.nd.diag(a, k=k)
+
+ assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+ # Test 2d backward, k=0
+ data = mx.sym.Variable('data')
+ diag_sym = mx.sym.diag(data=data)
+ check_numeric_gradient(diag_sym, [a_np])
+
+ # Test 2d backward, k=1
+ data = mx.sym.Variable('data')
+ diag_sym = mx.sym.diag(data=data, k=1)
+ check_numeric_gradient(diag_sym, [a_np])
+
+ # Test 2d backward, k=-1
+ data = mx.sym.Variable('data')
+ diag_sym = mx.sym.diag(data=data, k=-1)
+ check_numeric_gradient(diag_sym, [a_np])
+
if __name__ == '__main__':
import nose