This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1e95bcd  [MXNET-882] Support for N-d arrays added to diag op. (#12430)
1e95bcd is described below

commit 1e95bcd5841d50288c722bfe6cf69571acaffb9d
Author: Jason Yu <[email protected]>
AuthorDate: Wed Sep 19 01:58:26 2018 +0800

    [MXNET-882] Support for N-d arrays added to diag op. (#12430)
    
    * Support for N-d arrays added to diag op.
    
    * Doc for diag updated. Sanity fix. Unit test for N-d diag op added.
    
    * Unwanted print in diag test removed.
    
    * Bad negative axis support in diag fixed.
    
    * Bad negative axis support in diag fixed.
    
    * Index overflow in diag fixed. Exemplars for Nd diag added.
    
    * A bug in diag op fixed.
    
    * Types of axis number and dim size changed to dim_t in diag.
    
    * A bug in params of diag fixed.
    
    * Poisonous dim_t removed from diag parameters.
    
    * Diag impl details documented in comments.
---
 CONTRIBUTORS.md                        |   2 +
 src/operator/tensor/diag_op-inl.h      | 232 ++++++++++++++++++++++++---------
 src/operator/tensor/diag_op.cc         |  27 +++-
 tests/python/unittest/test_operator.py |  48 ++++++-
 4 files changed, 244 insertions(+), 65 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 8d8aeac..3ae6129 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -178,3 +178,5 @@ List of Contributors
 * [Aaron Markham](https://github.com/aaronmarkham)
 * [Sam Skalicky](https://github.com/samskalicky)
 * [Per Goncalves da Silva](https://github.com/perdasilva)
+* [Zhijingcheng Yu](https://github.com/jasonyu1996)
+* [Cheng-Che Lee](https://github.com/stu1130)
diff --git a/src/operator/tensor/diag_op-inl.h 
b/src/operator/tensor/diag_op-inl.h
index 3bc240f..deab256 100644
--- a/src/operator/tensor/diag_op-inl.h
+++ b/src/operator/tensor/diag_op-inl.h
@@ -21,7 +21,7 @@
 * Copyright (c) 2015 by Contributors
 * \file diag_op-inl.h
 * \brief CPU Implementation of the diag op
-* \author Istvan Fehervari
+* \author Istvan Fehervari, Zhijingcheng Yu
 */
 
 #ifndef MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
@@ -30,33 +30,51 @@
 #include <dmlc/parameter.h>
 #include <vector>
 #include <algorithm>
+#include <utility>
 #include "../mxnet_op.h"
 #include "../operator_common.h"
 #include "../elemwise_op_common.h"
+#include "./broadcast_reduce_op.h"
 
 namespace mxnet {
 namespace op {
 
 struct DiagParam : public dmlc::Parameter<DiagParam> {
-  dmlc::optional<int> k;
+  int k;
+  int32_t axis1;
+  int32_t axis2;
   DMLC_DECLARE_PARAMETER(DiagParam) {
     DMLC_DECLARE_FIELD(k)
-    .set_default(dmlc::optional<int>(0))
-    .describe("Diagonal in question. The default is 0. "
-              "Use k>0 for diagonals above the main diagonal, "
-              "and k<0 for diagonals below the main diagonal. "
-              "If input has shape (S0 S1) k must be between -S0 and S1");
+      .set_default(0)
+      .describe("Diagonal in question. The default is 0. "
+                "Use k>0 for diagonals above the main diagonal, "
+                "and k<0 for diagonals below the main diagonal. "
+                "If input has shape (S0 S1) k must be between -S0 and S1");
+    DMLC_DECLARE_FIELD(axis1)
+      .set_default(0)
+      .describe("The first axis of the sub-arrays of interest. "
+                "Ignored when the input is a 1-D array.");
+    DMLC_DECLARE_FIELD(axis2)
+      .set_default(1)
+      .describe("The second axis of the sub-arrays of interest. "
+                "Ignored when the input is a 1-D array.");
   }
 };
 
-inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) {
+inline TShape DiagShapeImpl(const TShape& ishape, const int k,
+                            const int32_t axis1, const int32_t axis2) {
   if (ishape.ndim() == 1) {
     auto s = ishape[0] + std::abs(k);
     return TShape({s, s});
   }
 
-  auto h = ishape[0];
-  auto w = ishape[1];
+  int32_t x1 = CheckAxis(axis1, ishape.ndim());
+  int32_t x2 = CheckAxis(axis2, ishape.ndim());
+
+  CHECK_NE(x1, x2) << "axis1 and axis2 cannot refer to the the same axis " << 
x1;
+
+  auto h = ishape[x1];
+  auto w = ishape[x2];
 
   if (k > 0) {
     w -= k;
@@ -69,7 +87,24 @@ inline TShape DiagShapeImpl(const TShape& ishape, const 
nnvm::dim_t k) {
     s = 0;
   }
 
-  return TShape({s});
+  if (x1 > x2) {
+    std::swap(x1, x2);
+  }
+
+  int32_t n_dim = static_cast<int32_t>(ishape.ndim()) - 1;
+  TShape oshape(n_dim);
+
+  // remove axis1 and axis2 and append the new axis to the end
+  uint32_t idx = 0;
+  for (int32_t i = 0; i <= n_dim; ++i) {
+    if (i != x1 && i != x2) {
+      oshape[idx++] = ishape[i];
+    }
+  }
+
+  oshape[n_dim - 1] = s;
+
+  return oshape;
 }
 
 inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
@@ -79,12 +114,16 @@ inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
     CHECK_EQ(out_attrs->size(), 1U);
 
     const TShape& ishape = (*in_attrs)[0];
-    if (ishape.ndim() == 0) return false;
-    if (ishape.ndim() > 2) LOG(FATAL) << "Input must be 1- or 2-d.";
+    if (ishape.ndim() == 0) {
+      return false;
+    }
 
     const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
 
-    TShape oshape = DiagShapeImpl(ishape, param.k.value());
+    TShape oshape = DiagShapeImpl(ishape,
+                                  param.k,
+                                  param.axis1,
+                                  param.axis2);
     if (shape_is_none(oshape)) {
       LOG(FATAL) << "Diagonal does not exist.";
     }
@@ -104,42 +143,144 @@ inline bool DiagOpType(const nnvm::NodeAttrs& attrs,
   return (*out_attrs)[0] != -1;
 }
 
-template<int req>
+template<int ndim, int req, bool back>
 struct diag {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
-                                  mshadow::Shape<2> ishape, int k) {
+  MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
+                                  mshadow::Shape<ndim> oshape,
+                                  mshadow::Shape<ndim> ishape,
+                                  index_t stride, index_t offset,
+                                  index_t base) {
     using namespace mxnet_op;
-    int j = 0;
-    if (k > 0) {
-      j = ravel(mshadow::Shape2(i, i + k), ishape);
-    } else if (k < 0) {
-      j = ravel(mshadow::Shape2(i - k, i), ishape);
+    index_t idx = i / base;
+    index_t j = ravel(unravel(idx, oshape), ishape) + offset + stride * (i - 
idx * base);
+    if (back) {
+      KERNEL_ASSIGN(out[j], req, a[i]);
     } else {
-      j = ravel(mshadow::Shape2(i, i), ishape);
+      KERNEL_ASSIGN(out[i], req, a[j]);
     }
-
-    KERNEL_ASSIGN(out[i], req, a[j]);
   }
 };
 
-template<int req>
+template<int req, bool back>
 struct diag_gen {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
+  MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
                                   mshadow::Shape<2> oshape, int k) {
     using namespace mxnet_op;
 
     auto j = unravel(i, oshape);
     if (j[1] == (j[0] + k)) {
       auto l = j[0] < j[1] ? j[0] : j[1];
-      KERNEL_ASSIGN(out[i], req, a[l]);
-    } else {
+      if (back) {
+        KERNEL_ASSIGN(out[l], req, a[i]);
+      } else {
+        KERNEL_ASSIGN(out[i], req, a[l]);
+      }
+    } else if (!back) {
       KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
     }
   }
 };
 
+template<typename xpu, bool back>
+void DiagOpProcess(const TBlob& in_data,
+                   const TBlob& out_data,
+                   const TShape& ishape,
+                   const TShape& oshape,
+                   index_t dsize,
+                   const DiagParam& param,
+                   mxnet_op::Stream<xpu> *s,
+                   const std::vector<OpReqType>& req) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  if (ishape.ndim() > 1) {
+    // input : (leading + i, body + i, trailing)
+    uint32_t x1 = CheckAxis(param.axis1, ishape.ndim());
+    uint32_t x2 = CheckAxis(param.axis2, ishape.ndim());
+
+    uint32_t idim = ishape.ndim(), odim = oshape.ndim();
+
+    uint32_t minx = x1, maxx = x2;
+    if (minx > maxx) {
+      std::swap(minx, maxx);
+    }
+
+    // merges contiguous axes that are not separated
+    // by axis1 or axis2 since they can be directly
+    // mapped to the output and there is no need
+    // to distinguish them
+    // (After this the input will have no more than
+    // three axes, hence improving the ravel and
+    // unravel efficiency)
+
+    index_t oleading = 1,
+           obody = 1,
+           otrailing = 1;
+
+    for (uint32_t i = 0; i < minx; ++i) {
+      oleading *= ishape[i];
+    }
+    for (uint32_t i = minx + 1; i < maxx; ++i) {
+      obody *= ishape[i];
+    }
+    for (uint32_t i = maxx + 1; i < idim; ++i) {
+      otrailing *= ishape[i];
+    }
+
+    index_t ileading = oleading,
+        ibody = obody * ishape[minx],
+        itrailing = otrailing * ishape[maxx];
+
+    index_t stride1 = itrailing * obody,
+        stride2 = otrailing;
+    // stride1 + stride2 is the stride for
+    // iterating over the diagonal in question
+
+    if (x1 == maxx) {
+      std::swap(stride1, stride2);
+    }
+
+    // the extra index offset introduced by k
+    index_t offset;
+    int k = param.k;
+    if (k > 0) {
+      offset = stride2 * k;
+    } else if (k < 0) {
+      offset = stride1 * -k;
+    } else {
+      offset = 0;
+    }
+
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        if (back && req[0] != kAddTo && req[0] != kNullOp) {
+          out_data.FlatTo1D<xpu, DType>(s) = 0;
+        }
+        if (ileading == 1) {
+          Kernel<diag<2, req_type, back>, xpu>::Launch(s, dsize, 
out_data.dptr<DType>(),
+                              in_data.dptr<DType>(), Shape2(obody, otrailing),
+                              Shape2(ibody, itrailing),
+                              stride1 + stride2, offset, oshape[odim - 1]);
+        } else {
+          Kernel<diag<3, req_type, back>, xpu>::Launch(s, dsize, 
out_data.dptr<DType>(),
+                              in_data.dptr<DType>(), Shape3(oleading, obody, 
otrailing),
+                              Shape3(ileading, ibody, itrailing),
+                              stride1 + stride2, offset, oshape[odim - 1]);
+        }
+      });
+    });
+  } else {
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<diag_gen<req_type, back>, xpu>::Launch(s, dsize, 
out_data.dptr<DType>(),
+                            in_data.dptr<DType>(), Shape2(oshape[0], 
oshape[1]),
+                            param.k);
+      });
+    });
+  }
+}
+
 template<typename xpu>
 void DiagOpForward(const nnvm::NodeAttrs& attrs,
                    const OpContext& ctx,
@@ -159,21 +300,7 @@ void DiagOpForward(const nnvm::NodeAttrs& attrs,
   const TShape& oshape = outputs[0].shape_;
   const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
 
-  if (ishape.ndim() == 2) {
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), 
out_data.dptr<DType>(),
-                            in_data.dptr<DType>(), Shape2(ishape[0], 
ishape[1]), param.k.value());
-      });
-    });
-  } else {
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), 
out_data.dptr<DType>(),
-                            in_data.dptr<DType>(), Shape2(oshape[0], 
oshape[1]), param.k.value());
-      });
-    });
-  }
+  DiagOpProcess<xpu, false>(in_data, out_data, ishape, oshape, 
out_data.Size(), param, s, req);
 }
 
 template<typename xpu>
@@ -194,23 +321,10 @@ void DiagOpBackward(const nnvm::NodeAttrs& attrs,
   const TShape& oshape = outputs[0].shape_;
   const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
 
-  if (oshape.ndim() == 2) {
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), 
out_data.dptr<DType>(),
-                            in_data.dptr<DType>(), Shape2(oshape[0], 
oshape[1]), param.k.value());
-      });
-    });
-  } else {
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), 
out_data.dptr<DType>(),
-                            in_data.dptr<DType>(), Shape2(ishape[0], 
ishape[1]), param.k.value());
-      });
-    });
-  }
+  DiagOpProcess<xpu, true>(in_data, out_data, oshape, ishape, in_data.Size(), 
param, s, req);
 }
 
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/tensor/diag_op.cc b/src/operator/tensor/diag_op.cc
index 1ad3b8a..cd5be9d 100644
--- a/src/operator/tensor/diag_op.cc
+++ b/src/operator/tensor/diag_op.cc
@@ -21,7 +21,7 @@
 * Copyright (c) 2015 by Contributors
 * \file diag_op.cc
 * \brief
-* \author Istvan Fehervari
+* \author Istvan Fehervari, Zhijingcheng Yu
 */
 
 #include "./diag_op-inl.h"
@@ -36,9 +36,13 @@ NNVM_REGISTER_OP(diag)
 
 ``diag``'s behavior depends on the input array dimensions:
 
-- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other 
elements are zero
-- 2-D arrays: returns elements in the diagonal as a new 1-D array
-- N-D arrays: not supported yet
+- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other 
elements are zero.
+- N-D arrays: extracts the diagonals of the sub-arrays with axes specified by 
``axis1`` and ``axis2``.
+  The output shape would be decided by removing the axes numbered ``axis1`` 
and ``axis2`` from the
+  input shape and appending to the result a new axis with the size of the 
diagonals in question.
+
+  For example, when the input shape is `(2, 3, 4, 5)`, ``axis1`` and ``axis2`` 
are 0 and 2
+  respectively and ``k`` is 0, the resulting shape would be `(3, 5, 2)`.
 
 Examples::
 
@@ -65,6 +69,21 @@ Examples::
                    [1, 0, 0],
                    [0, 2, 0]]
 
+  x = [[[1, 2],
+        [3, 4]],
+
+       [[5, 6],
+        [7, 8]]]
+
+  diag(x) = [[1, 7],
+             [2, 8]]
+
+  diag(x, k=1) = [[3],
+                  [4]]
+
+  diag(x, axis1=-2, axis2=-1) = [[1, 4],
+                                 [5, 8]]
+
 )code" ADD_FILELINE)
 .set_attr_parser(ParamParser<DiagParam>)
 .set_num_inputs(1)
diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py
index 5937ecd..43c3578 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4339,7 +4339,7 @@ def test_where():
         x = mx.sym.Variable('x')
         y = mx.sym.Variable('y')
         where_sym = mx.sym.where(condition, x, y)
-       
+
         assert_exception(lambda: 
where_sym.eval(x=mx.nd.array([[2,3],[4,5],[6,7]]),
                                                 
y=mx.nd.array([[8,9],[10,11],[12,13]]),
                                                 condition=mx.nd.array([1,0])), 
MXNetError)
@@ -4982,7 +4982,7 @@ def _validate_sample_location(input_rois, input_offset, 
spatial_scale, pooled_w,
                     trans_x = input_offset[roi_idx, class_id * 2, part_h, 
part_w] * trans_std
                     trans_y = input_offset[roi_idx, class_id * 2 + 1, part_h, 
part_w] * trans_std
                     bin_h_start, bin_w_start = ph * bin_size_h + roi_start_h, 
pw * bin_size_w + roi_start_w
-                    
+
                     need_check = True
                     while need_check:
                         pass_check = True
@@ -6812,6 +6812,50 @@ def test_diag():
     diag_sym = mx.sym.diag(data=data, k=-1)
     check_numeric_gradient(diag_sym, [a_np])
 
+    # Test 4d input
+    x1 = np.random.randint(3,9)
+    x2 = np.random.randint(3,9)
+    x3 = np.random.randint(3,9)
+    x4 = np.random.randint(3,9)
+    a_np = np.random.random((x1, x2, x3, x4)).astype(np.float32)
+    a = mx.nd.array(a_np).astype('float32')
+
+    # k = 0, axis1=0, axis2=1
+    r = mx.nd.diag(data=a, k=0, axis1=0, axis2=1)
+    assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=0, axis1=0, 
axis2=1))
+
+    # k = 1, axis1=1, axis2=0
+    r = mx.nd.diag(data=a, k=1, axis1=1, axis2=0)
+    assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=1, axis1=1, 
axis2=0))
+
+    # k = -1, axis1=1, axis2=3
+    r = mx.nd.diag(data=a, k=-1, axis1=1, axis2=3)
+    assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=-1, axis1=1, 
axis2=3))
+
+    # k = 2, axis1=-2, axis2=0
+    r = mx.nd.diag(data=a, k=2, axis1=-2, axis2=0)
+    assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=2, axis1=-2, 
axis2=0))
+
+    # Test 4d backward, k=0, axis1=3, axis2=0
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=0, axis1=3, axis2=0)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 4d backward, k=1, axis1=1, axis2=2
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=1, axis1=1, axis2=2)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 4d backward, k=-1, axis1=2, axis2=0
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=-1, axis1=2, axis2=0)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 4d backward, k=-2, axis1=1, axis2=-1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=-2, axis1=1, axis2=-1)
+    check_numeric_gradient(diag_sym, [a_np])
+
 @with_seed()
 def test_depthtospace():
     def f(x, blocksize):

Reply via email to