sandeep-krishnamurthy closed pull request #12430: [MXNET-882] Support for N-d arrays added to diag op.
URL: https://github.com/apache/incubator-mxnet/pull/12430
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 8d8aeaca73e..3ae61298de8 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -178,3 +178,5 @@ List of Contributors
 * [Aaron Markham](https://github.com/aaronmarkham)
 * [Sam Skalicky](https://github.com/samskalicky)
 * [Per Goncalves da Silva](https://github.com/perdasilva)
+* [Zhijingcheng Yu](https://github.com/jasonyu1996)
+* [Cheng-Che Lee](https://github.com/stu1130)
diff --git a/src/operator/tensor/diag_op-inl.h b/src/operator/tensor/diag_op-inl.h
index 3bc240f206b..deab2569e48 100644
--- a/src/operator/tensor/diag_op-inl.h
+++ b/src/operator/tensor/diag_op-inl.h
@@ -21,7 +21,7 @@
 * Copyright (c) 2015 by Contributors
 * \file diag_op-inl.h
 * \brief CPU Implementation of the diag op
-* \author Istvan Fehervari
+* \author Istvan Fehervari, Zhijingcheng Yu
 */
 
 #ifndef MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
@@ -30,33 +30,51 @@
 #include <dmlc/parameter.h>
 #include <vector>
 #include <algorithm>
+#include <utility>
 #include "../mxnet_op.h"
 #include "../operator_common.h"
 #include "../elemwise_op_common.h"
+#include "./broadcast_reduce_op.h"
 
 namespace mxnet {
 namespace op {
 
 struct DiagParam : public dmlc::Parameter<DiagParam> {
-  dmlc::optional<int> k;
+  int k;
+  int32_t axis1;
+  int32_t axis2;
   DMLC_DECLARE_PARAMETER(DiagParam) {
     DMLC_DECLARE_FIELD(k)
-    .set_default(dmlc::optional<int>(0))
-    .describe("Diagonal in question. The default is 0. "
-              "Use k>0 for diagonals above the main diagonal, "
-              "and k<0 for diagonals below the main diagonal. "
-              "If input has shape (S0 S1) k must be between -S0 and S1");
+      .set_default(0)
+      .describe("Diagonal in question. The default is 0. "
+                "Use k>0 for diagonals above the main diagonal, "
+                "and k<0 for diagonals below the main diagonal. "
+                "If input has shape (S0 S1) k must be between -S0 and S1");
+    DMLC_DECLARE_FIELD(axis1)
+      .set_default(0)
+      .describe("The first axis of the sub-arrays of interest. "
+                "Ignored when the input is a 1-D array.");
+    DMLC_DECLARE_FIELD(axis2)
+      .set_default(1)
+      .describe("The second axis of the sub-arrays of interest. "
+                "Ignored when the input is a 1-D array.");
   }
 };
 
-inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) {
+inline TShape DiagShapeImpl(const TShape& ishape, const int k,
+                            const int32_t axis1, const int32_t axis2) {
   if (ishape.ndim() == 1) {
     auto s = ishape[0] + std::abs(k);
     return TShape({s, s});
   }
 
-  auto h = ishape[0];
-  auto w = ishape[1];
+  int32_t x1 = CheckAxis(axis1, ishape.ndim());
+  int32_t x2 = CheckAxis(axis2, ishape.ndim());
+
+  CHECK_NE(x1, x2) << "axis1 and axis2 cannot refer to the same axis " << x1;
+
+  auto h = ishape[x1];
+  auto w = ishape[x2];
 
   if (k > 0) {
     w -= k;
@@ -69,7 +87,24 @@ inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) {
     s = 0;
   }
 
-  return TShape({s});
+  if (x1 > x2) {
+    std::swap(x1, x2);
+  }
+
+  int32_t n_dim = static_cast<int32_t>(ishape.ndim()) - 1;
+  TShape oshape(n_dim);
+
+  // remove axis1 and axis2 and append the new axis to the end
+  uint32_t idx = 0;
+  for (int32_t i = 0; i <= n_dim; ++i) {
+    if (i != x1 && i != x2) {
+      oshape[idx++] = ishape[i];
+    }
+  }
+
+  oshape[n_dim - 1] = s;
+
+  return oshape;
 }
 
 inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
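[Note: the shape rule implemented by DiagShapeImpl above is easy to restate in a few lines of
NumPy-flavored Python. A minimal sketch, assuming only NumPy; the helper name diag_out_shape
is hypothetical and not part of this PR::

    import numpy as np

    def diag_out_shape(ishape, k=0, axis1=0, axis2=1):
        # 1-D input: build an (s, s) matrix large enough to hold the k-th diagonal
        if len(ishape) == 1:
            s = ishape[0] + abs(k)
            return (s, s)
        x1, x2 = axis1 % len(ishape), axis2 % len(ishape)  # normalize negative axes
        assert x1 != x2, "axis1 and axis2 cannot refer to the same axis"
        h, w = ishape[x1], ishape[x2]
        if k > 0:
            w -= k              # diagonals above the main one are shorter in width
        elif k < 0:
            h += k              # diagonals below the main one are shorter in height
        s = max(min(h, w), 0)   # diagonal length; 0 when k is out of range
        # drop axis1/axis2, then append the diagonal axis at the end
        rest = [d for i, d in enumerate(ishape) if i not in (x1, x2)]
        return tuple(rest) + (s,)

    # matches the docstring example further below, and numpy.diagonal
    assert diag_out_shape((2, 3, 4, 5), k=0, axis1=0, axis2=2) == (3, 5, 2)
    assert diag_out_shape((2, 3, 4, 5), 0, 0, 2) == \
           np.diagonal(np.zeros((2, 3, 4, 5)), 0, 0, 2).shape
]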
@@ -79,12 +114,16 @@ inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
     CHECK_EQ(out_attrs->size(), 1U);
 
     const TShape& ishape = (*in_attrs)[0];
-    if (ishape.ndim() == 0) return false;
-    if (ishape.ndim() > 2) LOG(FATAL) << "Input must be 1- or 2-d.";
+    if (ishape.ndim() == 0) {
+      return false;
+    }
 
     const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
 
-    TShape oshape = DiagShapeImpl(ishape, param.k.value());
+    TShape oshape = DiagShapeImpl(ishape,
+                                  param.k,
+                                  param.axis1,
+                                  param.axis2);
     if (shape_is_none(oshape)) {
       LOG(FATAL) << "Diagonal does not exist.";
     }
@@ -104,42 +143,144 @@ inline bool DiagOpType(const nnvm::NodeAttrs& attrs,
   return (*out_attrs)[0] != -1;
 }
 
-template<int req>
+template<int ndim, int req, bool back>
 struct diag {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
-                                  mshadow::Shape<2> ishape, int k) {
+  MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
+                                  mshadow::Shape<ndim> oshape,
+                                  mshadow::Shape<ndim> ishape,
+                                  index_t stride, index_t offset,
+                                  index_t base) {
     using namespace mxnet_op;
-    int j = 0;
-    if (k > 0) {
-      j = ravel(mshadow::Shape2(i, i + k), ishape);
-    } else if (k < 0) {
-      j = ravel(mshadow::Shape2(i - k, i), ishape);
+    index_t idx = i / base;
+    index_t j = ravel(unravel(idx, oshape), ishape) + offset + stride * (i - idx * base);
+    if (back) {
+      KERNEL_ASSIGN(out[j], req, a[i]);
     } else {
-      j = ravel(mshadow::Shape2(i, i), ishape);
+      KERNEL_ASSIGN(out[i], req, a[j]);
     }
-
-    KERNEL_ASSIGN(out[i], req, a[j]);
   }
 };
 
-template<int req>
+template<int req, bool back>
 struct diag_gen {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
+  MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* a,
                                   mshadow::Shape<2> oshape, int k) {
     using namespace mxnet_op;
 
     auto j = unravel(i, oshape);
     if (j[1] == (j[0] + k)) {
       auto l = j[0] < j[1] ? j[0] : j[1];
-      KERNEL_ASSIGN(out[i], req, a[l]);
-    } else {
+      if (back) {
+        KERNEL_ASSIGN(out[l], req, a[i]);
+      } else {
+        KERNEL_ASSIGN(out[i], req, a[l]);
+      }
+    } else if (!back) {
       KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
     }
   }
 };
 
+template<typename xpu, bool back>
+void DiagOpProcess(const TBlob& in_data,
+                   const TBlob& out_data,
+                   const TShape& ishape,
+                   const TShape& oshape,
+                   index_t dsize,
+                   const DiagParam& param,
+                   mxnet_op::Stream<xpu> *s,
+                   const std::vector<OpReqType>& req) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  if (ishape.ndim() > 1) {
+    // input : (leading + i, body + i, trailing)
+    uint32_t x1 = CheckAxis(param.axis1, ishape.ndim());
+    uint32_t x2 = CheckAxis(param.axis2, ishape.ndim());
+
+    uint32_t idim = ishape.ndim(), odim = oshape.ndim();
+
+    uint32_t minx = x1, maxx = x2;
+    if (minx > maxx) {
+      std::swap(minx, maxx);
+    }
+
+    // merges contiguous axes that are not separated
+    // by axis1 or axis2 since they can be directly
+    // mapped to the output and there is no need
+    // to distinguish them
+    // (After this the input will have no more than
+    // three axes, hence improving the ravel and
+    // unravel efficiency)
+
+    index_t oleading = 1,
+           obody = 1,
+           otrailing = 1;
+
+    for (uint32_t i = 0; i < minx; ++i) {
+      oleading *= ishape[i];
+    }
+    for (uint32_t i = minx + 1; i < maxx; ++i) {
+      obody *= ishape[i];
+    }
+    for (uint32_t i = maxx + 1; i < idim; ++i) {
+      otrailing *= ishape[i];
+    }
+
+    index_t ileading = oleading,
+        ibody = obody * ishape[minx],
+        itrailing = otrailing * ishape[maxx];
+
+    index_t stride1 = itrailing * obody,
+        stride2 = otrailing;
+    // stride1 + stride2 is the stride for
+    // iterating over the diagonal in question
+
+    if (x1 == maxx) {
+      std::swap(stride1, stride2);
+    }
+
+    // the extra index offset introduced by k
+    index_t offset;
+    int k = param.k;
+    if (k > 0) {
+      offset = stride2 * k;
+    } else if (k < 0) {
+      offset = stride1 * -k;
+    } else {
+      offset = 0;
+    }
+
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        if (back && req[0] != kAddTo && req[0] != kNullOp) {
+          out_data.FlatTo1D<xpu, DType>(s) = 0;
+        }
+        if (ileading == 1) {
+          Kernel<diag<2, req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
+                              in_data.dptr<DType>(), Shape2(obody, otrailing),
+                              Shape2(ibody, itrailing),
+                              stride1 + stride2, offset, oshape[odim - 1]);
+        } else {
+          Kernel<diag<3, req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
+                              in_data.dptr<DType>(), Shape3(oleading, obody, otrailing),
+                              Shape3(ileading, ibody, itrailing),
+                              stride1 + stride2, offset, oshape[odim - 1]);
+        }
+      });
+    });
+  } else {
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<diag_gen<req_type, back>, xpu>::Launch(s, dsize, out_data.dptr<DType>(),
+                            in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]),
+                            param.k);
+      });
+    });
+  }
+}
+
 template<typename xpu>
 void DiagOpForward(const nnvm::NodeAttrs& attrs,
                    const OpContext& ctx,
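[Note: the core of DiagOpProcess above is the stride/offset arithmetic: the input is collapsed
to at most three merged axes (leading, body, trailing), and each diagonal is then walked with a
single flat stride. A NumPy sketch of the same arithmetic, with names mirroring the C++;
diag_forward is illustrative only, not the shipped kernel, and checks itself against
numpy.diagonal::

    import numpy as np

    def diag_forward(a, k=0, axis1=0, axis2=1):
        x1, x2 = axis1 % a.ndim, axis2 % a.ndim
        minx, maxx = min(x1, x2), max(x1, x2)
        ishape = a.shape
        oleading = int(np.prod(ishape[:minx], dtype=np.int64))
        obody = int(np.prod(ishape[minx + 1:maxx], dtype=np.int64))
        otrailing = int(np.prod(ishape[maxx + 1:], dtype=np.int64))
        ileading, ibody, itrailing = oleading, obody * ishape[minx], otrailing * ishape[maxx]
        stride1, stride2 = itrailing * obody, otrailing  # flat strides of axes minx / maxx
        if x1 == maxx:            # keep stride2 aligned with axis2 for the k-offset below
            stride1, stride2 = stride2, stride1
        offset = stride2 * k if k > 0 else stride1 * -k  # start of the k-th diagonal
        out = np.diagonal(a, offset=k, axis1=x1, axis2=x2)  # reference result
        base = out.shape[-1]                                # length of each diagonal
        flat_in, flat_out = a.ravel(), np.empty(out.size, a.dtype)
        for i in range(out.size):                           # mirrors the kernel's Map()
            idx = i // base
            j = np.ravel_multi_index(np.unravel_index(idx, (oleading, obody, otrailing)),
                                     (ileading, ibody, itrailing))
            flat_out[i] = flat_in[j + offset + (stride1 + stride2) * (i - idx * base)]
        assert np.array_equal(flat_out.reshape(out.shape), out)
        return flat_out.reshape(out.shape)

    diag_forward(np.arange(120).reshape(2, 3, 4, 5), k=1, axis1=1, axis2=3)
]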
@@ -159,21 +300,7 @@ void DiagOpForward(const nnvm::NodeAttrs& attrs,
   const TShape& oshape = outputs[0].shape_;
   const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
 
-  if (ishape.ndim() == 2) {
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
-                            in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
-      });
-    });
-  } else {
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
-                            in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
-      });
-    });
-  }
+  DiagOpProcess<xpu, false>(in_data, out_data, ishape, oshape, out_data.Size(), param, s, req);
 }
 
 template<typename xpu>
@@ -194,23 +321,10 @@ void DiagOpBackward(const nnvm::NodeAttrs& attrs,
   const TShape& oshape = outputs[0].shape_;
   const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
 
-  if (oshape.ndim() == 2) {
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
-                            in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
-      });
-    });
-  } else {
-    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-        Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
-                            in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
-      });
-    });
-  }
+  DiagOpProcess<xpu, true>(in_data, out_data, oshape, ishape, in_data.Size(), param, s, req);
 }
 
+
 }  // namespace op
 }  // namespace mxnet
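[Note: with back=true the same index mapping runs in reverse: the incoming gradient is scattered
onto the selected diagonal and every other input position gets zero (the output buffer is zeroed
first unless req is kAddTo, as in DiagOpProcess above). A minimal NumPy sketch of that backward
semantics; diag_backward is a hypothetical name used only for illustration::

    import numpy as np

    def diag_backward(ograd, ishape, k=0, axis1=0, axis2=1):
        igrad = np.zeros(ishape, dtype=ograd.dtype)
        # bring axis1/axis2 to the front so the diagonal is easy to index
        g = np.moveaxis(igrad, (axis1 % len(ishape), axis2 % len(ishape)), (0, 1))
        n = ograd.shape[-1]                  # diagonal length
        rows = np.arange(n) - min(k, 0)      # k < 0 starts |k| rows down
        cols = np.arange(n) + max(k, 0)      # k > 0 starts k columns right
        g[rows, cols, ...] = np.moveaxis(ograd, -1, 0)  # writes through the view
        return igrad

    a_shape = (2, 3, 4, 5)
    og = np.ones(np.diagonal(np.zeros(a_shape), 1, 1, 3).shape)
    ig = diag_backward(og, a_shape, k=1, axis1=1, axis2=3)
    assert ig.sum() == og.size   # only diagonal positions receive gradient
]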
 
diff --git a/src/operator/tensor/diag_op.cc b/src/operator/tensor/diag_op.cc
index 1ad3b8adc02..cd5be9d0fd5 100644
--- a/src/operator/tensor/diag_op.cc
+++ b/src/operator/tensor/diag_op.cc
@@ -21,7 +21,7 @@
 * Copyright (c) 2015 by Contributors
 * \file diag_op.cc
 * \brief
-* \author Istvan Fehervari
+* \author Istvan Fehervari, Zhijingcheng Yu
 */
 
 #include "./diag_op-inl.h"
@@ -36,9 +36,13 @@ NNVM_REGISTER_OP(diag)
 
 ``diag``'s behavior depends on the input array dimensions:
 
-- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero
-- 2-D arrays: returns elements in the diagonal as a new 1-D array
-- N-D arrays: not supported yet
+- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero.
+- N-D arrays: extracts the diagonals of the sub-arrays with axes specified by ``axis1`` and ``axis2``.
+  The output shape is determined by removing the axes numbered ``axis1`` and ``axis2`` from the
+  input shape and appending to the result a new axis with the size of the diagonals in question.
+
+  For example, when the input shape is `(2, 3, 4, 5)`, ``axis1`` and ``axis2`` are 0 and 2
+  respectively, and ``k`` is 0, the resulting shape is `(3, 5, 2)`.
 
 Examples::
 
@@ -65,6 +69,21 @@ Examples::
                    [1, 0, 0],
                    [0, 2, 0]]
 
+  x = [[[1, 2],
+        [3, 4]],
+
+       [[5, 6],
+        [7, 8]]]
+
+  diag(x) = [[1, 7],
+             [2, 8]]
+
+  diag(x, k=1) = [[3],
+                  [4]]
+
+  diag(x, axis1=-2, axis2=-1) = [[1, 4],
+                                 [5, 8]]
+
 )code" ADD_FILELINE)
 .set_attr_parser(ParamParser<DiagParam>)
 .set_num_inputs(1)
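[Note: as the new tests below verify, the N-d behavior matches numpy.diagonal. A quick check of
the docstring's shape example, assuming an MXNet build that includes this PR::

    import mxnet as mx
    import numpy as np

    x = np.arange(120, dtype=np.float32).reshape(2, 3, 4, 5)
    out = mx.nd.diag(mx.nd.array(x), k=0, axis1=0, axis2=2)
    np.testing.assert_allclose(out.asnumpy(), np.diagonal(x, offset=0, axis1=0, axis2=2))
    print(out.shape)   # (3, 5, 2), as stated in the docstring above
]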
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 2bf7e848850..3c052ed6608 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4339,7 +4339,7 @@ def test_invalid_shape():
         x = mx.sym.Variable('x')
         y = mx.sym.Variable('y')
         where_sym = mx.sym.where(condition, x, y)
-       
+
         assert_exception(lambda: where_sym.eval(x=mx.nd.array([[2,3],[4,5],[6,7]]),
                                                 y=mx.nd.array([[8,9],[10,11],[12,13]]),
                                                 condition=mx.nd.array([1,0])), MXNetError)
@@ -4982,7 +4982,7 @@ def _validate_sample_location(input_rois, input_offset, spatial_scale, pooled_w,
                     trans_x = input_offset[roi_idx, class_id * 2, part_h, part_w] * trans_std
                     trans_y = input_offset[roi_idx, class_id * 2 + 1, part_h, part_w] * trans_std
                     bin_h_start, bin_w_start = ph * bin_size_h + roi_start_h, pw * bin_size_w + roi_start_w
-                    
+
                     need_check = True
                     while need_check:
                         pass_check = True
@@ -6812,6 +6812,50 @@ def test_diag():
     diag_sym = mx.sym.diag(data=data, k=-1)
     check_numeric_gradient(diag_sym, [a_np])
 
+    # Test 4d input
+    x1 = np.random.randint(3,9)
+    x2 = np.random.randint(3,9)
+    x3 = np.random.randint(3,9)
+    x4 = np.random.randint(3,9)
+    a_np = np.random.random((x1, x2, x3, x4)).astype(np.float32)
+    a = mx.nd.array(a_np).astype('float32')
+
+    # k = 0, axis1=0, axis2=1
+    r = mx.nd.diag(data=a, k=0, axis1=0, axis2=1)
+    assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=0, axis1=0, axis2=1))
+
+    # k = 1, axis1=1, axis2=0
+    r = mx.nd.diag(data=a, k=1, axis1=1, axis2=0)
+    assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=1, axis1=1, axis2=0))
+
+    # k = -1, axis1=1, axis2=3
+    r = mx.nd.diag(data=a, k=-1, axis1=1, axis2=3)
+    assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=-1, axis1=1, axis2=3))
+
+    # k = 2, axis1=-2, axis2=0
+    r = mx.nd.diag(data=a, k=2, axis1=-2, axis2=0)
+    assert_almost_equal(r.asnumpy(), np.diagonal(a_np, offset=2, axis1=-2, axis2=0))
+
+    # Test 4d backward, k=0, axis1=3, axis2=0
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=0, axis1=3, axis2=0)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 4d backward, k=1, axis1=1, axis2=2
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=1, axis1=1, axis2=2)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 4d backward, k=-1, axis1=2, axis2=0
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=-1, axis1=2, axis2=0)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 4d backward, k=-2, axis1=1, axis2=-1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=-2, axis1=1, axis2=-1)
+    check_numeric_gradient(diag_sym, [a_np])
+
 @with_seed()
 def test_depthtospace():
     def f(x, blocksize):
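[Note: the 1-D and 2-D paths still round-trip as in the docstring examples. A short sanity
check in the same style as the tests above; illustrative, not part of the PR::

    import mxnet as mx
    import numpy as np

    v = np.array([1., 2., 3.], dtype=np.float32)
    m = mx.nd.diag(mx.nd.array(v), k=1)          # 1-D -> 2-D: v on the k=1 diagonal
    np.testing.assert_allclose(m.asnumpy(), np.diag(v, k=1))
    np.testing.assert_allclose(mx.nd.diag(m, k=1).asnumpy(), v)  # 2-D -> 1-D recovers v
]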


 
