This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8039377  add op npx.index_update (#18545)
8039377 is described below

commit 8039377e6630bcb00c5a95abdaf0851803686bc6
Author: JiangZhaoh <[email protected]>
AuthorDate: Wed Jun 17 01:45:30 2020 +0800

    add op npx.index_update (#18545)
    
    * add op npx.index_update
    
    * remove debug comment
    
    * change eps
    
    * fix stupid error
    
    * add blank line in docs
    
    * gpu temporary space request alignment
    
    * fix test error
    
    Co-authored-by: Ubuntu <[email protected]>
---
 python/mxnet/_numpy_op_doc.py                      |  72 ++++++
 src/operator/tensor/index_add-inl.h                |   2 +-
 src/operator/tensor/index_add_backward.cc          |  18 +-
 .../tensor/{index_add-inl.h => index_update-inl.h} | 175 ++++++++------
 src/operator/tensor/index_update.cc                | 261 +++++++++++++++++++++
 src/operator/tensor/index_update.cu                | 204 ++++++++++++++++
 tests/python/unittest/test_numpy_op.py             | 162 +++++++++++++
 7 files changed, 813 insertions(+), 81 deletions(-)
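Editorial note: for readers skimming the diff, here is a minimal usage sketch of the new operator,
built only from the docstring examples added below (illustrative only, not part of the commit;
assumes the standard mxnet.numpy setup):

    from mxnet import np, npx
    npx.set_np()

    a = np.zeros((2, 3, 4))
    # 'ind' has shape (ind_ndim, ind_num): each column is one position to update
    ind = np.array([[0, 0], [0, 0], [0, 1]], dtype='int32')
    val = np.arange(2).reshape(2) + 1
    b = npx.index_update(a, ind, val)   # b[0, 0, 0] == 1, b[0, 0, 1] == 2; 'a' itself is unchanged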

diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index fecd0e6..b8f4a49 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -630,6 +630,7 @@ def _npx_index_add(a, ind, val):
     """
     Add values to input according to given indexes.
     If there are repeated positions to be updated, the update values will be accumulated.
+
     Parameters
     ----------
     a : ndarray
@@ -643,10 +644,12 @@ def _npx_index_add(a, ind, val):
               - ind.dtype should be 'int32' or 'int64'
     val : ndarray
         Input data. The array to update the input 'a'.
+
     Returns
     -------
     out : ndarray
         The output array.
+
     Examples
     --------
     >>> a = np.zeros((2, 3, 4))
@@ -699,6 +702,75 @@ def _npx_index_add(a, ind, val):
     pass
 
 
+def _npx_index_update(a, ind, val):
+    """
+    Update values of the input according to the given indexes.
+    If multiple indices refer to the same location, it is undefined which update is
+    chosen; the order of updates may be arbitrary and nondeterministic (e.g., due to
+    concurrent updates on some hardware platforms). Repeated positions are therefore
+    not recommended.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input data. The array to be updated.
+        Support dtype: 'float32', 'float64', 'int32', 'int64'.
+    ind : ndarray
+        Indexes for indicating update positions.
+        For example, array([[0, 1], [2, 3], [4, 5]]) indicates there are two positions to
+        be updated, which are (0, 2, 4) and (1, 3, 5).
+        Note: - 'ind' cannot be an empty array '[]'; in that case, please use
+                assignment ('=') instead.
+              - 0 <= ind.ndim <= 2.
+              - ind.dtype should be 'int32' or 'int64'
+    val : ndarray
+        Input data. The array to update the input 'a'.
+        Support dtype: 'float32', 'float64', 'int32', 'int64'.
+
+    Returns
+    -------
+    out : ndarray
+        The output array.
+
+    Examples
+    --------
+    >>> a = np.zeros((2, 3, 4))
+    >>> ind = np.array([[0, 0], [0, 0], [0, 1]], dtype='int32')
+    >>> val = np.arange(2).reshape(2) + 1
+    >>> b = npx.index_update(a, ind, val)
+    >>> b
+    array([[[1., 2., 0., 0.],
+            [0., 0., 0., 0.],
+            [0., 0., 0., 0.]],
+
+           [[0., 0., 0., 0.],
+            [0., 0., 0., 0.],
+            [0., 0., 0., 0.]]])
+
+    >>> ind = np.array([[0, 0], [0, 1]], dtype='int32')
+    >>> val = np.arange(8).reshape(2, 4)
+    >>> b = npx.index_update(a, ind, val)
+    >>> b
+    array([[[0., 1., 2., 3.],
+            [4., 5., 6., 7.],
+            [0., 0., 0., 0.]],
+
+           [[0., 0., 0., 0.],
+            [0., 0., 0., 0.],
+            [0., 0., 0., 0.]]])
+
+    >>> val = np.arange(4).reshape(4)  # broadcast 'val'
+    >>> b = npx.index_update(a, ind, val)
+    >>> b
+    array([[[0., 1., 2., 3.],
+            [0., 1., 2., 3.],
+            [0., 0., 0., 0.]],
+
+           [[0., 0., 0., 0.],
+            [0., 0., 0., 0.],
+            [0., 0., 0., 0.]]])
+    """
+    pass
+
+
 def _np_diag(array, k=0):
     """
     Extracts a diagonal or constructs a diagonal array.
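Editorial note: as a mental model for what the forward kernels in the C++/CUDA files below compute,
here is a rough pure-NumPy reference (a sketch only, assuming 'ind' is already a 2-D int array of
shape (ind_ndim, ind_num) and 'val' broadcasts against the indexed slice as the docstring requires;
the real operator additionally validates dtypes and index bounds):

    import numpy as _np

    def index_update_ref(a, ind, val):
        out = a.copy()                        # the operator writes to a separate output
        val = _np.asarray(val, dtype=a.dtype)
        for k in range(ind.shape[1]):         # one column of 'ind' = one position to update
            pos = tuple(ind[:, k].tolist())
            if val.ndim + ind.shape[0] > a.ndim:
                # 'val' carries a per-position leading axis (size ind_num or 1)
                v = val[0 if val.shape[0] == 1 else k]
            else:
                v = val
            out[pos] = v                      # broadcasting fills the remaining tail axes
        return out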
diff --git a/src/operator/tensor/index_add-inl.h b/src/operator/tensor/index_add-inl.h
index 83463da..122aa01 100644
--- a/src/operator/tensor/index_add-inl.h
+++ b/src/operator/tensor/index_add-inl.h
@@ -52,7 +52,7 @@ inline bool IndexModifyOpType(const nnvm::NodeAttrs& attrs,
   CHECK_NE((*in_attrs)[1], -1);
   CHECK_NE((*in_attrs)[2], -1);
   CHECK_EQ((*in_attrs)[0], (*in_attrs)[2])
-    << "index_add(a, ind, val) only support a.dtype == val.dtype";
+    << "index_add/index_update(a, ind, val) only support a.dtype == val.dtype";
   CHECK((*in_attrs)[1] == mshadow::kInt64 ||
         (*in_attrs)[1] == mshadow::kInt32)
     << "'ind' only support int dtype.";
diff --git a/src/operator/tensor/index_add_backward.cc b/src/operator/tensor/index_add_backward.cc
index 158695b..0fe1009 100644
--- a/src/operator/tensor/index_add_backward.cc
+++ b/src/operator/tensor/index_add_backward.cc
@@ -67,15 +67,15 @@ void IndexAddBackwardValCPUCompute(DType* grad_val,
 
 template<>
 void IndexAddOpBackwardValImpl<cpu>(const OpContext& ctx,
-                               const TBlob& grad_val,
-                               const TBlob& ograd,
-                               const TBlob& t_ind,
-                               const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
ograd_tail_shape,
-                               const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
ograd_pre_stride,
-                               const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_stride,
-                               const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape,
-                               const int tail_size, const int ind_num, const 
int ind_ndim,
-                               const int ndim) {
+                                    const TBlob& grad_val,
+                                    const TBlob& ograd,
+                                    const TBlob& t_ind,
+                                    const 
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
+                                    const 
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
+                                    const 
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+                                    const 
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+                                    const int tail_size, const int ind_num, 
const int ind_ndim,
+                                    const int ndim) {
   using namespace mshadow;
   using namespace mxnet_op;
   int seg = MXNET_SPECIAL_MAX_NDIM - ndim;
diff --git a/src/operator/tensor/index_add-inl.h b/src/operator/tensor/index_update-inl.h
similarity index 52%
copy from src/operator/tensor/index_add-inl.h
copy to src/operator/tensor/index_update-inl.h
index 83463da..8319647 100644
--- a/src/operator/tensor/index_add-inl.h
+++ b/src/operator/tensor/index_update-inl.h
@@ -18,15 +18,17 @@
  */
 
 /*!
- * \file index_add-inl.h
- * \brief Function definition of index_add operator
+ * \file index_update-inl.h
+ * \brief Function definition of index_update operator
 */
-#ifndef MXNET_OPERATOR_TENSOR_INDEX_ADD_INL_H_
-#define MXNET_OPERATOR_TENSOR_INDEX_ADD_INL_H_
+#ifndef MXNET_OPERATOR_TENSOR_INDEX_UPDATE_INL_H_
+#define MXNET_OPERATOR_TENSOR_INDEX_UPDATE_INL_H_
 
 #include <mxnet/operator_util.h>
 #include <vector>
 #include <algorithm>
+#include "./index_add-inl.h"
+#include "./sort_op.h"
 #include "../mxnet_op.h"
 #include "../operator_common.h"
 #include "../elemwise_op_common.h"
@@ -34,51 +36,25 @@
 namespace mxnet {
 namespace op {
 
-inline bool IndexModifyOpShape(const nnvm::NodeAttrs& attrs,
-                               mxnet::ShapeVector* in_attrs,
-                               mxnet::ShapeVector* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 3U);
-  CHECK_EQ(out_attrs->size(), 1U);
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
-  return true;
-}
-
-inline bool IndexModifyOpType(const nnvm::NodeAttrs& attrs,
-                              std::vector<int>* in_attrs,
-                              std::vector<int>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 3U);
-  CHECK_EQ(out_attrs->size(), 1U);
-  CHECK_NE((*in_attrs)[0], -1);
-  CHECK_NE((*in_attrs)[1], -1);
-  CHECK_NE((*in_attrs)[2], -1);
-  CHECK_EQ((*in_attrs)[0], (*in_attrs)[2])
-    << "index_add(a, ind, val) only support a.dtype == val.dtype";
-  CHECK((*in_attrs)[1] == mshadow::kInt64 ||
-        (*in_attrs)[1] == mshadow::kInt32)
-    << "'ind' only support int dtype.";
-  TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
-  return (*out_attrs)[0] != -1;
-}
-
 template<typename xpu, typename DType>
-void IndexAddForwardCalc(mshadow::Stream<xpu> *s,
-                         const int ind_num, DType* out,
-                         const DType* val,
-                         const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_tail_shape,
-                         const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_pre_stride,
-                         const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_stride,
-                         const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape,
-                         const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_shape,
-                         const int a_tail_size,
-                         const int ind_ndim, const int* ind,
-                         const int a_ndim);
+void IndexUpdateForwardCalc(mshadow::Stream<xpu> *s,
+                            const int ind_num, DType* out,
+                            const DType* val,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_tail_shape,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_pre_stride,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_stride,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_shape,
+                            const int a_tail_size,
+                            const int ind_ndim, const int* ind,
+                            const int a_ndim);
 
 template<typename xpu>
-void IndexAddOpForward(const nnvm::NodeAttrs& attrs,
-                       const OpContext& ctx,
-                       const std::vector<TBlob>& inputs,
-                       const std::vector<OpReqType>& req,
-                       const std::vector<TBlob>& outputs) {
+void IndexUpdateOpForward(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
   using namespace mxnet_op;
   using namespace mshadow;
   CHECK_EQ(inputs.size(), 3U);
@@ -88,7 +64,7 @@ void IndexAddOpForward(const nnvm::NodeAttrs& attrs,
   TBlob ind = inputs[1];
   TBlob val = inputs[2];
   TBlob out = outputs[0];
-  CHECK_GT(a.shape_.ndim(), 0) << "The first input is saclar, please use '+' instead.";
+  CHECK_GT(a.shape_.ndim(), 0) << "The first input is a scalar, please use '=' instead.";
   int a_ndim = a.shape_.ndim();
   CHECK_LE(a_ndim, MXNET_SPECIAL_MAX_NDIM)
     << "ndim should less than "<< MXNET_SPECIAL_MAX_NDIM
@@ -152,33 +128,33 @@ void IndexAddOpForward(const nnvm::NodeAttrs& attrs,
                 (Shape1(ind.shape_.Size()), s));
   mxnet_op::copy(s, t_ind, ind);
   MSHADOW_TYPE_SWITCH(a.type_flag_, DType, {
-    IndexAddForwardCalc<xpu, DType>(s, ind_num,
-                                    out.dptr<DType>(), val.dptr<DType>(),
-                                    a_tail_shape, a_pre_stride,
-                                    val_stride, val_shape, a_shape,
-                                    a_tail_size, ind_ndim,
-                                    t_ind.dptr<int>(), a_ndim);
+    IndexUpdateForwardCalc<xpu, DType>(s, ind_num,
+                                       out.dptr<DType>(), val.dptr<DType>(),
+                                       a_tail_shape, a_pre_stride,
+                                       val_stride, val_shape, a_shape,
+                                       a_tail_size, ind_ndim,
+                                       t_ind.dptr<int>(), a_ndim);
   });
 }
 
 template<typename xpu>
-void IndexAddOpBackwardValImpl(const OpContext& ctx,
-                               const TBlob& grad_val,
-                               const TBlob& ograd,
-                               const TBlob& t_ind,
-                               const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
ograd_tail_shape,
-                               const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
ograd_pre_stride,
-                               const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_stride,
-                               const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape,
-                               const int tail_size, const int ind_num, const 
int ind_ndim,
-                               const int ndim);
+void IndexUpdateOpBackwardValImpl(const OpContext& ctx,
+                                  const TBlob& grad_val,
+                                  const TBlob& ograd,
+                                  const TBlob& t_ind,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
ograd_tail_shape,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
ograd_pre_stride,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_stride,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape,
+                                  const int tail_size, const int ind_num, 
const int ind_ndim,
+                                  const int ndim);
 
 template<typename xpu>
-inline void IndexAddOpBackwardVal(const nnvm::NodeAttrs& attrs,
-                                  const OpContext& ctx,
-                                  const std::vector<TBlob>& inputs,
-                                  const std::vector<OpReqType>& req,
-                                  const std::vector<TBlob>& outputs) {
+inline void IndexUpdateOpBackwardVal(const nnvm::NodeAttrs& attrs,
+                                     const OpContext& ctx,
+                                     const std::vector<TBlob>& inputs,
+                                     const std::vector<OpReqType>& req,
+                                     const std::vector<TBlob>& outputs) {
   using namespace mshadow;
   using namespace mxnet_op;
   if (req[0] == kNullOp) {
@@ -221,11 +197,68 @@ inline void IndexAddOpBackwardVal(const nnvm::NodeAttrs& attrs,
   }
   mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride = 
mxnet_op::calc_stride(ograd_pre_shape);
   mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride = 
mxnet_op::calc_stride(val_shape);
-  IndexAddOpBackwardValImpl<xpu>(ctx, grad_val, ograd, t_ind, 
ograd_tail_shape, ograd_pre_stride,
-                                 val_stride, val_shape, tail_size, ind_num, 
ind_ndim, out_ndim);
+  IndexUpdateOpBackwardValImpl<xpu>(ctx, grad_val, ograd, t_ind, 
ograd_tail_shape, ograd_pre_stride,
+                                    val_stride, val_shape, tail_size, ind_num, 
ind_ndim, out_ndim);
+}
+
+template<typename DType>
+struct ReqCopy {
+  MSHADOW_XINLINE static void Map(size_t i, DType* dest, const DType* src, 
const int req) {
+    KERNEL_ASSIGN(dest[i], req, src[i]);
+  }
+};
+
+template<typename xpu>
+void IndexUpdateOpBackwardAImpl(const OpContext& ctx,
+                                const TBlob& grad_a,
+                                const TBlob& ograd,
+                                const TBlob& t_ind,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
grada_pre_stride,
+                                const int tail_size, const int ind_num, const 
int ind_ndim,
+                                const int seg, const int req);
+
+template<typename xpu>
+inline void IndexUpdateOpBackwardA(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<TBlob>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  if (req[0] == kNullOp) {
+    return;
+  }
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  TBlob ograd = inputs[0];
+  TBlob ind = inputs[1];
+  const TBlob& grad_a = outputs[0];
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  // get the number of 'ind' indices
+  if (ind.shape_.ndim() == 0) {
+    ind.shape_ = Shape2(1, 1);
+  } else if (ind.shape_.ndim() == 1) {
+    ind.shape_ = Shape2(1, ind.shape_[0]);
+  }
+  int ind_ndim = ind.shape_[0];
+  int ind_num = ind.shape_[1];
+  int out_ndim = ograd.shape_.ndim();
+  int tail_size = static_cast<int>(ograd.shape_.ProdShape(ind_ndim, out_ndim));
+  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_shape;
+  for (int i = MXNET_SPECIAL_MAX_NDIM - 1, j = out_ndim - 1; i >= 0; --i, --j) 
{
+    grada_shape[i] = (j >= 0) ? grad_a.shape_[j] : 1;
+  }
+  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_shape(grada_shape);
+  int seg = MXNET_SPECIAL_MAX_NDIM - out_ndim;
+  for (int i = seg + ind_ndim; i < seg + out_ndim; ++i) {
+    grada_pre_shape[i] = 1;
+  }
+  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride = 
mxnet_op::calc_stride(grada_pre_shape);
+  IndexUpdateOpBackwardAImpl<xpu>(ctx, grad_a, ograd, ind, grada_pre_stride,
+                                  tail_size, ind_num, ind_ndim, seg, req[0]);
 }
 
 }   // namespace op
 }   // namespace mxnet
 
-#endif  // MXNET_OPERATOR_TENSOR_INDEX_ADD_INL_H_
+#endif  // MXNET_OPERATOR_TENSOR_INDEX_UPDATE_INL_H_
diff --git a/src/operator/tensor/index_update.cc b/src/operator/tensor/index_update.cc
new file mode 100644
index 0000000..ffd5d29
--- /dev/null
+++ b/src/operator/tensor/index_update.cc
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file index_update.cc
+ * \brief CPU implementation of index_update operator
+*/
+#include <vector>
+#include "./index_update-inl.h"
+
+namespace mxnet {
+namespace op {
+
+template<typename DType>
+struct IndexUpdateForwardCPUKernel {
+  MSHADOW_XINLINE static void Map(size_t i, DType* out,
+                                  const DType* val,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_tail_shape,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_pre_stride,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_stride,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_shape,
+                                  const int a_tail_size, const int ind_num,
+                                  const int ind_ndim, const int* ind,
+                                  const int a_ndim, const int seg) {
+    index_t id = 0;
+    for (int dim = 0; dim < ind_ndim; ++dim) {
+      CHECK_LT(ind[dim * ind_num + i], a_shape[seg + dim])
+        << "IndexError: index " << ind[dim * ind_num + i]
+        << " is out of bounds for axis " << dim
+        << " with size " << a_shape[seg + dim];
+      CHECK_GE(ind[dim * ind_num + i], 0)
+        << "IndexError: index " << ind[dim * ind_num + i]
+        << " should be greater or equal to 0.";
+      id += a_pre_stride[seg + dim] * ind[dim * ind_num + i];
+    }
+    id *= a_tail_size;
+    for (int _i = 0; _i < a_tail_size; ++_i) {
+      mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_id = mxnet_op::unravel(_i, 
a_tail_shape);
+      mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
+      for (int _j = 0; _j < seg; ++_j) {
+        val_id[_j] = 0;
+      }
+      for (int _j = seg; _j < seg + a_ndim; ++_j) {
+        val_id[_j] = (val_shape[_j] == 1) ? 0 : a_tail_id[_j];
+      }
+      val_id[seg + ind_ndim - 1] = (val_shape[seg + ind_ndim - 1] == 1) ? 0 : 
i;
+      index_t val_dest = mxnet_op::dot(val_id, val_stride);
+      out[id + _i] = val[val_dest];
+    }
+  }
+};
+
+template<typename xpu, typename DType>
+void IndexUpdateForwardCalc(mshadow::Stream<xpu> *s,
+                            const int ind_num, DType* out,
+                            const DType* val,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_tail_shape,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_pre_stride,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_stride,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_shape,
+                            const int a_tail_size,
+                            const int ind_ndim, const int* ind,
+                            const int a_ndim) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  int seg = MXNET_SPECIAL_MAX_NDIM - a_ndim;
+  Kernel<IndexUpdateForwardCPUKernel<DType>, xpu>::Launch(
+    s, ind_num, out, val, a_tail_shape, a_pre_stride,
+    val_stride, val_shape, a_shape, a_tail_size, ind_num,
+    ind_ndim, ind, a_ndim, seg);
+}
+
+
+template<typename DType>
+void IndexUpdateBackwardValCPUCompute(DType* grad_val,
+                                      const DType* ograd,
+                                      const int* ind_vec,
+                                      const 
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
+                                      const 
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
+                                      const 
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+                                      const 
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+                                      const int ograd_tail_size, const int 
ind_num,
+                                      const int ind_ndim, const int out_ndim,
+                                      const int seg) {
+  #pragma omp parallel for 
num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+  for (index_t i = 0; i < static_cast<index_t>(ind_num); ++i) {
+    index_t id = 0;
+    for (int dim = 0; dim < ind_ndim; ++dim) {
+      id += ograd_pre_stride[seg + dim] * ind_vec[dim * ind_num + i];
+    }
+    id *= ograd_tail_size;
+    for (int _i = 0; _i < ograd_tail_size; ++_i) {
+      mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_id =
+        mxnet_op::unravel(_i, ograd_tail_shape);
+      mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
+      for (int _j = 0; _j < seg; ++_j) {
+        val_id[_j] = 0;
+      }
+      for (int _j = seg; _j < seg + out_ndim; ++_j) {
+        val_id[_j] = (val_shape[_j] == 1) ? 0 : ograd_tail_id[_j];
+      }
+      val_id[seg + ind_ndim - 1] = (val_shape[seg + ind_ndim - 1] == 1) ? 0 : 
i;
+      index_t val_dest = mxnet_op::dot(val_id, val_stride);
+      #pragma omp critical
+      {
+        grad_val[val_dest] += ograd[id + _i];
+      }
+    }
+  }
+}
+
+template<>
+void IndexUpdateOpBackwardValImpl<cpu>(const OpContext& ctx,
+                                const TBlob& grad_val,
+                                const TBlob& ograd,
+                                const TBlob& t_ind,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
ograd_tail_shape,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
ograd_pre_stride,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_stride,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape,
+                                const int tail_size, const int ind_num, const 
int ind_ndim,
+                                const int ndim) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  int seg = MXNET_SPECIAL_MAX_NDIM - ndim;
+  MSHADOW_TYPE_SWITCH(grad_val.type_flag_, DType, {
+    IndexUpdateBackwardValCPUCompute<DType>(
+      grad_val.dptr<DType>(), ograd.dptr<DType>(), t_ind.dptr<int>(),
+      ograd_tail_shape, ograd_pre_stride, val_stride, val_shape, tail_size,
+      ind_num, ind_ndim, ndim, seg);
+  });
+}
+
+template<typename DType>
+void IndexUpdateBackwardACPUCompute(DType* out_grad,
+                                    const 
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride,
+                                    const int tail_size, const int ind_num, 
const int ind_ndim,
+                                    const int32_t* ind, const int seg,
+                                    const int req) {
+  #pragma omp parallel for 
num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+  for (index_t i = 0; i < static_cast<index_t>(ind_num); ++i) {
+    index_t id = 0;
+    for (int dim = 0; dim < ind_ndim; ++dim) {
+      id += grada_pre_stride[seg + dim] * ind[dim * ind_num + i];
+    }
+    id *= tail_size;
+    for (int _i = 0; _i < tail_size; ++_i) {
+      out_grad[id + _i] = 0;
+    }
+  }
+}
+
+template<>
+void IndexUpdateOpBackwardAImpl<cpu>(const OpContext& ctx,
+                                     const TBlob& grad_a,
+                                     const TBlob& ograd,
+                                     const TBlob& ind,
+                                     const 
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride,
+                                     const int tail_size, const int ind_num, 
const int ind_ndim,
+                                     const int seg, const int req) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+  MSHADOW_TYPE_SWITCH(grad_a.type_flag_, DType, {
+    size_t temp_mem_size = ind.shape_.Size() * sizeof(int) +
+                           ograd.shape_.Size() * sizeof(DType);
+    Tensor<cpu, 1, char> temp_mem =
+      ctx.requested[0].get_space_typed<cpu, 1, char>(Shape1(temp_mem_size), s);
+    TBlob t_ograd = TBlob(temp_mem.dptr_, ograd.shape_, ograd.dev_mask(),
+                          ograd.type_flag_, ograd.dev_id());
+    TBlob t_ind = TBlob(temp_mem.dptr_ + ograd.Size() * sizeof(DType), 
ind.shape_, ind.dev_mask(),
+                        mshadow::kInt32, ind.dev_id());
+    mxnet_op::copy(s, t_ograd, ograd);
+    mxnet_op::copy(s, t_ind, ind);
+    IndexUpdateBackwardACPUCompute<DType>(t_ograd.dptr<DType>(),
+                                          grada_pre_stride, tail_size,
+                                          ind_num, ind_ndim,
+                                          t_ind.dptr<int32_t>(), seg, req);
+    Kernel<ReqCopy<DType>, cpu>::Launch(s, grad_a.shape_.Size(), 
grad_a.dptr<DType>(),
+                                        t_ograd.dptr<DType>(), req);
+  });
+}
+
+NNVM_REGISTER_OP(_npx_index_update)
+.describe(R"code(This operators implements the "=" mimic function.
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a", "ind", "val"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", IndexModifyOpShape)
+.set_attr<nnvm::FInferType>("FInferType", IndexModifyOpType)
+.set_attr<FCompute>("FCompute<cpu>", IndexUpdateOpForward<cpu>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      auto a_grad = MakeNode("_backward_index_update_a", n->attrs.name + 
"_backward_a",
+                              {ograds[0], n->inputs[1]}, nullptr, &n);
+      auto idx_grad = MakeNode("zeros_like", n->attrs.name + 
"_backward_indices",
+                              {n->inputs[1]}, nullptr, &n);
+      auto val_grad = MakeNode("_backward_index_update_val", n->attrs.name + 
"_backward_val",
+                              {ograds[0], n->inputs[1]}, nullptr, &n);
+      // auto val_grad = MakeNode("zeros_like", n->attrs.name + 
"_backward_val",
+      //                         {n->inputs[2]}, nullptr, &n);
+      std::vector<nnvm::NodeEntry> ret;
+      ret.emplace_back(a_grad);
+      ret.emplace_back(idx_grad);
+      ret.emplace_back(val_grad);
+      return ret;
+  })
+.add_argument("a", "NDArray-or-Symbol", "Input ndarray")
+.add_argument("ind", "NDArray-or-Symbol", "Index ndarray")
+.add_argument("val", "NDArray-or-Symbol", "Update ndarray");
+
+
+NNVM_REGISTER_OP(_backward_index_update_a)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", IndexUpdateOpBackwardA<cpu>);
+
+
+NNVM_REGISTER_OP(_backward_index_update_val)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", IndexUpdateOpBackwardVal<cpu>);
+
+}  // namespace op
+}  // namespace mxnet
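Editorial note: the gradient rule registered above can be summarised in plain NumPy roughly as
follows (a sketch assuming val.shape == (ind_num,) + a.shape[ind_ndim:], i.e. no broadcasting of
'val'; the real backward kernels additionally reduce the gradient over any broadcast axes and
honour the write/add/null request types):

    import numpy as _np

    def index_update_grad_ref(ograd, ind, val_shape):
        # d(out)/d(a): updated positions are overwritten by 'val', so their gradient is cut to zero
        grad_a = ograd.copy()
        # d(out)/d(val): each updated position forwards its incoming gradient to 'val'
        grad_val = _np.zeros(val_shape, dtype=ograd.dtype)
        for k in range(ind.shape[1]):
            pos = tuple(ind[:, k].tolist())
            grad_a[pos] = 0
            grad_val[k] += ograd[pos]   # accumulate, matching the += / atomicAdd in the kernels
        return grad_a, grad_val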
diff --git a/src/operator/tensor/index_update.cu b/src/operator/tensor/index_update.cu
new file mode 100644
index 0000000..3ed788f
--- /dev/null
+++ b/src/operator/tensor/index_update.cu
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file index_update.cu
+ * \brief GPU implementation of index_update operator
+ */
+
+#include <cub/cub.cuh>
+#include "./index_update-inl.h"
+#include "../tensor/util/tensor_util-inl.cuh"
+#include "../tensor/util/tensor_util-inl.h"
+
+
+namespace mxnet {
+namespace op {
+
+template<typename DType>
+struct IndexUpdateForwardGPUKernel {
+  MSHADOW_XINLINE static void Map(size_t i, DType* out,
+                                  const DType* val,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_tail_shape,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_pre_stride,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_stride,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape,
+                                  const int a_tail_size, const int ind_num,
+                                  const int ind_ndim, const int* ind,
+                                  const int a_ndim, const int seg) {
+    index_t id = 0;
+    for (int dim = 0; dim < ind_ndim; ++dim) {
+      id += a_pre_stride[seg + dim] * ind[dim * ind_num + i];
+    }
+    id *= a_tail_size;
+    for (int _i = 0; _i < a_tail_size; ++_i) {
+      mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_id = mxnet_op::unravel(_i, 
a_tail_shape);
+      mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
+      for (int _j = 0; _j < seg; ++_j) {
+        val_id[_j] = 0;
+      }
+      for (int _j = seg; _j < seg + a_ndim; ++_j) {
+        val_id[_j] = (val_shape[_j] == 1) ? 0 : a_tail_id[_j];
+      }
+      val_id[seg + ind_ndim - 1] = (val_shape[seg + ind_ndim - 1] == 1) ? 0 : 
i;
+      index_t val_dest = mxnet_op::dot(val_id, val_stride);
+      out[id + _i] = val[val_dest];
+    }
+  }
+};
+
+template<typename xpu, typename DType>
+void IndexUpdateForwardCalc(mshadow::Stream<xpu> *s,
+                            const int ind_num, DType* out,
+                            const DType* val,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_tail_shape,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_pre_stride,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_stride,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
a_shape,
+                            const int a_tail_size,
+                            const int ind_ndim, const int* ind,
+                            const int a_ndim) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  int seg = MXNET_SPECIAL_MAX_NDIM - a_ndim;
+  Kernel<IndexUpdateForwardGPUKernel<DType>, xpu>::Launch(
+    s, ind_num, out, val, a_tail_shape, a_pre_stride,
+    val_stride, val_shape, a_tail_size, ind_num,
+    ind_ndim, ind, a_ndim, seg);
+}
+
+
+struct IndexUpdateBackwardValGPUKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(size_t i, DType* grad_val,
+                                  const DType* ograd, const int* ind_vec,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
ograd_tail_shape,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
ograd_pre_stride,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_stride,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape,
+                                  const int ograd_tail_size, const int ind_num,
+                                  const int ind_ndim, const int out_ndim, 
const int seg) {
+    index_t id = 0;
+    for (int dim = 0; dim < ind_ndim; ++dim) {
+      id += ograd_pre_stride[seg + dim] * ind_vec[dim * ind_num + i];
+    }
+    id *= ograd_tail_size;
+    for (int _i = 0; _i < ograd_tail_size; ++_i) {
+      mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_id =
+        mxnet_op::unravel(_i, ograd_tail_shape);
+      mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
+      for (int _j = 0; _j < seg; ++_j) {
+        val_id[_j] = 0;
+      }
+      for (int _j = seg; _j < seg + out_ndim; ++_j) {
+        val_id[_j] = (val_shape[_j] == 1) ? 0 : ograd_tail_id[_j];
+      }
+      val_id[seg + ind_ndim - 1] = (val_shape[seg + ind_ndim - 1] == 1) ? 0 : 
i;
+      index_t val_dest = mxnet_op::dot(val_id, val_stride);
+      atomicAdd(&grad_val[val_dest], ograd[id + _i]);
+    }
+  }
+};
+
+template<>
+void IndexUpdateOpBackwardValImpl<gpu>(const OpContext& ctx,
+                                const TBlob& grad_val,
+                                const TBlob& ograd,
+                                const TBlob& t_ind,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
ograd_tail_shape,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
ograd_pre_stride,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_stride,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
val_shape,
+                                const int tail_size, const int ind_num, const 
int ind_ndim,
+                                const int ndim) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  int seg = MXNET_SPECIAL_MAX_NDIM - ndim;
+  MSHADOW_TYPE_SWITCH(grad_val.type_flag_, DType, {
+    Kernel<IndexUpdateBackwardValGPUKernel, gpu>::Launch(
+        s, ind_num, grad_val.dptr<DType>(), ograd.dptr<DType>(), 
t_ind.dptr<int>(),
+        ograd_tail_shape, ograd_pre_stride, val_stride, val_shape, tail_size,
+        ind_num, ind_ndim, ndim, seg);
+  });
+}
+
+template<typename DType>
+struct IndexUpdateBackwardAGPUKernel {
+  MSHADOW_XINLINE static void Map(size_t i, DType* out_grad,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> 
grada_pre_stride,
+                                  const int tail_size, const int ind_num, 
const int ind_ndim,
+                                  const int32_t* ind, const int seg,
+                                  const int req) {
+    index_t id = 0;
+    for (int dim = 0; dim < ind_ndim; ++dim) {
+      id += grada_pre_stride[seg + dim] * ind[dim * ind_num + i];
+    }
+    id *= tail_size;
+    for (int _i = 0; _i < tail_size; ++_i) {
+      out_grad[id + _i] = 0;
+    }
+  }
+};
+
+template<>
+void IndexUpdateOpBackwardAImpl<gpu>(const OpContext& ctx,
+                                     const TBlob& grad_a,
+                                     const TBlob& ograd,
+                                     const TBlob& ind,
+                                     const 
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride,
+                                     const int tail_size, const int ind_num, 
const int ind_ndim,
+                                     const int seg, const int req) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  MSHADOW_TYPE_SWITCH(grad_a.type_flag_, DType, {
+    size_t alignment = std::max(sizeof(DType), sizeof(int32_t));
+    size_t id_size = PadBytes(sizeof(int32_t) * ind.Size(), alignment);
+    size_t ograd_size = PadBytes(sizeof(DType) * ograd.Size(), alignment);
+    size_t temp_mem_size = id_size + ograd_size;
+    Tensor<gpu, 1, char> temp_mem =
+      ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_mem_size), s);
+    TBlob t_ograd = TBlob(temp_mem.dptr_, ograd.shape_, ograd.dev_mask(),
+                          ograd.type_flag_, ograd.dev_id());
+    TBlob t_ind = TBlob(temp_mem.dptr_ + ograd_size, ind.shape_, 
ind.dev_mask(),
+                        mshadow::kInt32, ind.dev_id());
+    mxnet_op::copy(s, t_ograd, ograd);
+    mxnet_op::copy(s, t_ind, ind);
+    Kernel<IndexUpdateBackwardAGPUKernel<DType>, gpu>::Launch(s, ind_num, 
t_ograd.dptr<DType>(),
+                                                           grada_pre_stride, 
tail_size,
+                                                           ind_num, ind_ndim,
+                                                           
t_ind.dptr<int32_t>(), seg, req);
+    Kernel<ReqCopy<DType>, gpu>::Launch(s, grad_a.shape_.Size(), 
grad_a.dptr<DType>(),
+                                        t_ograd.dptr<DType>(), req);
+  });
+}
+
+NNVM_REGISTER_OP(_npx_index_update)
+.set_attr<FCompute>("FCompute<gpu>", IndexUpdateOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_index_update_val)
+.set_attr<FCompute>("FCompute<gpu>", IndexUpdateOpBackwardVal<gpu>);
+
+NNVM_REGISTER_OP(_backward_index_update_a)
+.set_attr<FCompute>("FCompute<gpu>", IndexUpdateOpBackwardA<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
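Editorial note: one of the follow-up fixes in this commit ("gpu temporary space request alignment")
packs the two scratch buffers used by the backward-of-'a' GPU kernel (an ograd copy and an int32
copy of 'ind') into a single temp-space request, padding each segment so the second buffer starts
on an aligned offset. The padding arithmetic is the usual round-up-to-multiple, sketched below
(illustrative only; the actual code uses the PadBytes helper seen above):

    def pad_bytes(nbytes, alignment):
        # round nbytes up to the next multiple of 'alignment'
        return (nbytes + alignment - 1) // alignment * alignment

    # e.g. pad_bytes(10, 8) == 16, so an 8-byte-aligned buffer can follow a 10-byte one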
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 3ea6119..6440dec 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1479,6 +1479,168 @@ def test_npx_index_add():
 
 @with_seed()
 @use_np
+def test_npx_index_update():
+    class TestIndexUpdate(HybridBlock):
+        def __init__(self):
+            super(TestIndexUpdate, self).__init__()
+
+        def hybrid_forward(self, F, a, ind, val):
+            return F.npx.index_update(a, ind, val)
+
+    def check_index_update_forward(mx_ret, a, ind, val, ind_ndim, ind_num, 
eps):
+        if val.dtype != a.dtype:
+            val = val.astype(a.dtype)
+        ind_arr = ind.transpose()
+        if ind_arr.ndim == 0:
+            ind_arr = _np.array([ind_arr])
+        for i in range(ind_arr.shape[0]):
+            t_ind = ind_arr[i]
+            t_ind = tuple(t_ind.tolist()) if type(t_ind) is _np.ndarray else 
t_ind.tolist()
+            if val.ndim + ind_ndim > a.ndim:
+                t_val = val[tuple([0 if val.shape[0]==1 else i])]
+                if type(t_val) is _np.ndarray and t_val.shape[0] == 1:
+                    expect_tmp = _np.squeeze(t_val, axis=0)
+                else:
+                    expect_tmp = t_val
+            else:
+                expect_tmp = val
+            mx_tmp = mx_ret[t_ind]
+            close_pos = _np.where(_np.isclose(expect_tmp, mx_tmp, rtol=eps, 
atol=eps))
+            if a[t_ind].ndim == 0:
+                if close_pos[0].size == 1:
+                    mx_ret[t_ind] = 0
+                    a[t_ind] = 0
+            else:
+                mx_ret[t_ind][close_pos] = 0
+                a[t_ind][close_pos] = 0
+        assert_almost_equal(mx_ret, a, rtol=eps, atol=eps)
+
+    def index_update_bwd(out_grad, a_grad, ind, val_grad, ind_ndim, ind_num, 
grad_req_a, grad_req_val):
+        if grad_req_a == 'add':
+            init_a_grad = _np.array(a_grad)
+        if grad_req_val == 'add':
+            init_val_grad = _np.array(val_grad)
+        a_grad = _np.zeros(a_grad.shape) + out_grad
+        a_grad = a_grad.astype(a_grad.dtype)
+        val_grad = _np.zeros(val_grad.shape).astype(val_grad.dtype)
+
+        ind_arr = ind.transpose()
+        if ind_arr.ndim == 0:
+            ind_arr = _np.array([ind_arr])
+        for i in range(ind_arr.shape[0]):
+            t_ind = ind_arr[i]
+            t_ind = tuple(ind_arr[i].tolist()) if type(ind_arr[i]) is 
_np.ndarray else ind_arr[i].tolist()
+            a_grad[t_ind] = 0
+            if val_grad.ndim + ind_ndim > a_grad.ndim:
+                idx = 0 if val_grad.shape[0]==1 else i
+                t_grad = out_grad[t_ind]
+                t_grad_shape = _np.array(t_grad.shape)
+                val_grad_shape = _np.array(val_grad[idx].shape)
+                if type(val_grad[idx]) is not _np.ndarray:
+                    t_grad = _np.sum(t_grad)
+                else:
+                    is_not_equal = t_grad_shape - val_grad_shape
+                    if _np.any(is_not_equal):
+                        broadcast_dim = _np.nonzero(_np.where(is_not_equal, 1, 
0))
+                        t_grad = _np.sum(t_grad, 
axis=tuple(broadcast_dim[0].reshape(1, -1)[0]), keepdims=True)
+                val_grad[idx] += t_grad
+            else:
+                t_grad = out_grad[t_ind]
+                if type(val_grad) is not _np.ndarray or val_grad.shape == ():
+                    t_grad = _np.sum(t_grad)
+                else:
+                    if type(t_grad) is _np.ndarray:
+                        ext_dim = t_grad.ndim - val_grad.ndim
+                        if ext_dim:
+                            t_grad = _np.sum(t_grad, 
axis=tuple(_np.arange(ext_dim)))
+                        t_grad_shape = _np.array(t_grad.shape)
+                        val_grad_shape = _np.array(val_grad.shape)
+                        is_not_equal = t_grad_shape - val_grad_shape
+                        if _np.any(is_not_equal):
+                            broadcast_dim = 
_np.nonzero(_np.where(is_not_equal, 1, 0))
+                            t_grad = _np.sum(t_grad, 
axis=tuple(broadcast_dim.reshape(1, -1)[0]), keepdims=True)
+                val_grad += t_grad
+        if grad_req_a == 'add':
+            a_grad += init_a_grad
+        if grad_req_val == 'add':
+            val_grad += init_val_grad
+        return a_grad, val_grad
+
+    # a.shape, ind.shape, val.shape, ind_ndim, ind_num
+    configs = [((2, ), np.array(1, dtype=_np.int32), (1, ), 1, 1)]
+    shape = tuple(_np.random.randint(1, 6, size=(4))) # a.shape
+    for ind_ndim in range(1, 5): # ind.shape: (ind_ndim, ind_num)
+        ind_num = _np.random.randint(1, 7)
+        ind = []
+        for ind_dim in range(ind_ndim):
+            ind.append(_np.random.randint(0, shape[ind_dim], size=(ind_num)))
+        ind = _np.array(ind).astype(_np.int32)
+        # case: val is scalar
+        configs.append(tuple([shape, ind, (), ind_ndim, ind_num]))
+        for val_ndim in range(1, 5 - ind_ndim):
+            val_shape = [1 if _np.random.randint(0, 5)==0 else ind_num]
+            for val_dim in range(ind_ndim, 4):
+                val_shape.append(1 if _np.random.randint(0, 5)==0 else 
shape[val_dim])
+            # case: val is tensor
+            configs.append(tuple([shape, ind, tuple(val_shape), ind_ndim, 
ind_num]))
+
+    dtypes = ['float32', 'float64', 'int32', 'int64']
+    grad_req = ['write', 'null', 'add']
+    for hybridize, grad_req_a, grad_req_val, dtype, indtype in \
+        itertools.product([True, False], grad_req, grad_req, dtypes, ['int32', 
'int64']):
+        for a_shape, ind, val_shape, ind_ndim, ind_num in configs:
+            eps = 1e-3
+            atype = dtype
+            valtype = dtype
+            test_index_update = TestIndexUpdate()
+            if hybridize:
+                test_index_update.hybridize()
+            a = mx.nd.random.uniform(-10.0, 10.0, 
shape=a_shape).as_np_ndarray().astype(atype)
+            a.attach_grad(grad_req=grad_req_a)
+            val = mx.nd.random.uniform(-10.0, 10.0, 
shape=val_shape).as_np_ndarray().astype(valtype)
+            val.attach_grad(grad_req=grad_req_val)
+            with mx.autograd.record():
+                mx_ret = test_index_update(a, np.array(ind).astype(indtype), 
val)
+            assert mx_ret.shape == a.shape
+            assert mx_ret.dtype == a.dtype
+            check_index_update_forward(mx_ret.asnumpy(), a.asnumpy(), 
ind.astype(indtype), val.asnumpy(), ind_ndim, ind_num, eps)
+
+            if atype not in ['float16', 'float32', 'float64'] or valtype not 
in ['float16', 'float32', 'float64']:
+                continue
+            if grad_req_a != 'null' or grad_req_val != 'null':
+                init_a_grad = mx.nd.random.uniform(-10.0, 10.0, 
shape=a_shape).as_np_ndarray().astype(atype)
+                init_val_grad = mx.nd.random.uniform(-10.0, 10.0, 
shape=val_shape).as_np_ndarray().astype(valtype)
+                out_grad = mx.nd.random.uniform(-10.0, 10.0, 
shape=a_shape).as_np_ndarray().astype(atype)
+                if grad_req_a == 'add':
+                    if init_a_grad.ndim == 0:
+                        a.grad[()] = init_a_grad.item()
+                    else:
+                        a.grad[:] = init_a_grad
+                if grad_req_val == 'add':
+                    if init_val_grad.ndim == 0:
+                        val.grad[()] = init_val_grad.item()
+                    else:
+                        val.grad[:] = init_val_grad
+                mx_ret.backward(out_grad)
+                expected_bwd_a, expected_bwd_val = 
index_update_bwd(out_grad.asnumpy(), init_a_grad.asnumpy(), ind,
+                                                                    
init_val_grad.asnumpy(), ind_ndim, ind_num,
+                                                                    
grad_req_a, grad_req_val)
+
+                if grad_req_a == 'null':
+                    assert a.grad is None
+                else:
+                    assert_almost_equal(a.grad.asnumpy(), expected_bwd_a, rtol=eps, atol=eps)
+                if grad_req_val == 'null':
+                    assert val.grad is None
+                else:
+                    assert_almost_equal(val.grad.asnumpy(), expected_bwd_val, rtol=eps, atol=eps)
+
+            mx_out = npx.index_update(a, np.array(ind).astype(indtype), val)
+            check_index_update_forward(mx_out.asnumpy(), a.asnumpy(), 
ind.astype(indtype), val.asnumpy(), ind_ndim, ind_num, eps)
+
+
+@with_seed()
+@use_np
 def test_npx_batch_dot():
     ctx = mx.context.current_context()
     dtypes = ['float32', 'float64']
