This is an automated email from the ASF dual-hosted git repository.
sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8039377 add op npx.index_update (#18545)
8039377 is described below
commit 8039377e6630bcb00c5a95abdaf0851803686bc6
Author: JiangZhaoh <[email protected]>
AuthorDate: Wed Jun 17 01:45:30 2020 +0800
add op npx.index_update (#18545)
* add op npx.index_update
* remove debug comment
* change eps
* fix stupid error
* add blank line in docs
* gpu temporary space request alignment
* fix test error
Co-authored-by: Ubuntu <[email protected]>
---
python/mxnet/_numpy_op_doc.py | 72 ++++++
src/operator/tensor/index_add-inl.h | 2 +-
src/operator/tensor/index_add_backward.cc | 18 +-
.../tensor/{index_add-inl.h => index_update-inl.h} | 175 ++++++++------
src/operator/tensor/index_update.cc | 261 +++++++++++++++++++++
src/operator/tensor/index_update.cu | 204 ++++++++++++++++
tests/python/unittest/test_numpy_op.py | 162 +++++++++++++
7 files changed, 813 insertions(+), 81 deletions(-)
diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index fecd0e6..b8f4a49 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -630,6 +630,7 @@ def _npx_index_add(a, ind, val):
"""
Add values to input according to given indexes.
If repeated positions are to be updated, the update values will be accumulated.
+
Parameters
----------
a : ndarray
@@ -643,10 +644,12 @@ def _npx_index_add(a, ind, val):
- ind.dtype should be 'int32' or 'int64'
val : ndarray
Input data. The array to update the input 'a'.
+
Returns
-------
out : ndarray
The output array.
+
Examples
--------
>>> a = np.zeros((2, 3, 4))
@@ -699,6 +702,75 @@ def _npx_index_add(a, ind, val):
pass
+def _npx_index_update(a, ind, val):
+ """
+ Update values of the input according to the given indexes.
+ If multiple indices refer to the same location, it is undefined which update is chosen;
+ the updates may be applied in an arbitrary and nondeterministic order (e.g., due to
+ concurrent updates on some hardware platforms). It is recommended not to use repeated positions.
+
+ Parameters
+ ----------
+ a : ndarray
+ Input data. The array to be updated.
+ Support dtype: 'float32', 'float64', 'int32', 'int64'.
+ ind : ndarray
+ Indexes for indicating update positions.
+ For example, array([[0, 1], [2, 3], [4, 5]]) indicates there are two positions
+ to be updated, which are (0, 2, 4) and (1, 3, 5).
+ Note: - 'ind' cannot be an empty array '[]'; in that case, please use the
+ operator '=' instead.
+ - 0 <= ind.ndim <= 2.
+ - ind.dtype should be 'int32' or 'int64'
+ val : ndarray
+ Input data. The array to update the input 'a'.
+ Support dtype: 'float32', 'float64', 'int32', 'int64'.
+
+ Returns
+ -------
+ out : ndarray
+ The output array.
+
+ Examples
+ --------
+ >>> a = np.zeros((2, 3, 4))
+ >>> ind = np.array([[0, 0], [0, 0], [0, 1]], dtype='int32')
+ >>> val = np.arange(2).reshape(2) + 1
+ >>> b = npx.index_update(a, ind, val)
+ >>> b
+ array([[[1., 2., 0., 0.],
+ [0., 0., 0., 0.],
+ [0., 0., 0., 0.]],
+
+ [[0., 0., 0., 0.],
+ [0., 0., 0., 0.],
+ [0., 0., 0., 0.]]])
+
+ >>> ind = np.array([[0, 0], [0, 1]], dtype='int32')
+ >>> val = np.arange(8).reshape(2, 4)
+ >>> b = npx.index_update(a, ind, val)
+ >>> b
+ array([[[0., 1., 2., 3.],
+ [4., 5., 6., 7.],
+ [0., 0., 0., 0.]],
+
+ [[0., 0., 0., 0.],
+ [0., 0., 0., 0.],
+ [0., 0., 0., 0.]]])
+
+ >>> val = np.arange(4).reshape(4) # broadcast 'val'
+ >>> b = npx.index_update(a, ind, val)
+ >>> b
+ array([[[0., 1., 2., 3.],
+ [0., 1., 2., 3.],
+ [0., 0., 0., 0.]],
+
+ [[0., 0., 0., 0.],
+ [0., 0., 0., 0.],
+ [0., 0., 0., 0.]]])
+ """
+ pass
+
+
def _np_diag(array, k=0):
"""
Extracts a diagonal or constructs a diagonal array.
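For reference, the docstring semantics above can be reproduced in plain NumPy. The sketch below is illustrative only and is not part of this commit; 'index_update_ref' is a hypothetical helper name. It treats each column of 'ind' as one coordinate into the leading axes of 'a', and matches all three docstring examples:

    import numpy as np

    def index_update_ref(a, ind, val):
        # Out-of-place reference: copy `a`, then overwrite the indexed
        # positions with (broadcast) values taken from `val`.
        out = a.copy()
        ind = np.atleast_2d(ind)
        ind_ndim, ind_num = ind.shape
        for k in range(ind_num):
            pos = tuple(ind[:, k].tolist())   # e.g. column [0, 0] -> a[0, 0, ...]
            if val.ndim + ind_ndim > a.ndim:  # one row of `val` per index column
                v = val[0 if val.shape[0] == 1 else k]
            else:                             # the same `val` broadcast to every position
                v = val
            out[pos] = v                      # NumPy broadcasting applies here
        return out

    a = np.zeros((2, 3, 4))
    ind = np.array([[0, 0], [0, 1]], dtype='int32')
    val = np.arange(8).reshape(2, 4)
    print(index_update_ref(a, ind, val))      # matches the second docstring example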
diff --git a/src/operator/tensor/index_add-inl.h b/src/operator/tensor/index_add-inl.h
index 83463da..122aa01 100644
--- a/src/operator/tensor/index_add-inl.h
+++ b/src/operator/tensor/index_add-inl.h
@@ -52,7 +52,7 @@ inline bool IndexModifyOpType(const nnvm::NodeAttrs& attrs,
CHECK_NE((*in_attrs)[1], -1);
CHECK_NE((*in_attrs)[2], -1);
CHECK_EQ((*in_attrs)[0], (*in_attrs)[2])
- << "index_add(a, ind, val) only support a.dtype == val.dtype";
+ << "index_add/index_update(a, ind, val) only support a.dtype == val.dtype";
CHECK((*in_attrs)[1] == mshadow::kInt64 ||
(*in_attrs)[1] == mshadow::kInt32)
<< "'ind' only support int dtype.";
diff --git a/src/operator/tensor/index_add_backward.cc b/src/operator/tensor/index_add_backward.cc
index 158695b..0fe1009 100644
--- a/src/operator/tensor/index_add_backward.cc
+++ b/src/operator/tensor/index_add_backward.cc
@@ -67,15 +67,15 @@ void IndexAddBackwardValCPUCompute(DType* grad_val,
template<>
void IndexAddOpBackwardValImpl<cpu>(const OpContext& ctx,
- const TBlob& grad_val,
- const TBlob& ograd,
- const TBlob& t_ind,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
- const int tail_size, const int ind_num, const int ind_ndim,
- const int ndim) {
+ const TBlob& grad_val,
+ const TBlob& ograd,
+ const TBlob& t_ind,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+ const int tail_size, const int ind_num, const int ind_ndim,
+ const int ndim) {
using namespace mshadow;
using namespace mxnet_op;
int seg = MXNET_SPECIAL_MAX_NDIM - ndim;
diff --git a/src/operator/tensor/index_add-inl.h b/src/operator/tensor/index_update-inl.h
similarity index 52%
copy from src/operator/tensor/index_add-inl.h
copy to src/operator/tensor/index_update-inl.h
index 83463da..8319647 100644
--- a/src/operator/tensor/index_add-inl.h
+++ b/src/operator/tensor/index_update-inl.h
@@ -18,15 +18,17 @@
*/
/*!
- * \file index_add-inl.h
- * \brief Function definition of index_add operator
+ * \file index_update-inl.h
+ * \brief Function definition of index_update operator
*/
-#ifndef MXNET_OPERATOR_TENSOR_INDEX_ADD_INL_H_
-#define MXNET_OPERATOR_TENSOR_INDEX_ADD_INL_H_
+#ifndef MXNET_OPERATOR_TENSOR_INDEX_UPDATE_INL_H_
+#define MXNET_OPERATOR_TENSOR_INDEX_UPDATE_INL_H_
#include <mxnet/operator_util.h>
#include <vector>
#include <algorithm>
+#include "./index_add-inl.h"
+#include "./sort_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
@@ -34,51 +36,25 @@
namespace mxnet {
namespace op {
-inline bool IndexModifyOpShape(const nnvm::NodeAttrs& attrs,
- mxnet::ShapeVector* in_attrs,
- mxnet::ShapeVector* out_attrs) {
- CHECK_EQ(in_attrs->size(), 3U);
- CHECK_EQ(out_attrs->size(), 1U);
- SHAPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
- return true;
-}
-
-inline bool IndexModifyOpType(const nnvm::NodeAttrs& attrs,
- std::vector<int>* in_attrs,
- std::vector<int>* out_attrs) {
- CHECK_EQ(in_attrs->size(), 3U);
- CHECK_EQ(out_attrs->size(), 1U);
- CHECK_NE((*in_attrs)[0], -1);
- CHECK_NE((*in_attrs)[1], -1);
- CHECK_NE((*in_attrs)[2], -1);
- CHECK_EQ((*in_attrs)[0], (*in_attrs)[2])
- << "index_add(a, ind, val) only support a.dtype == val.dtype";
- CHECK((*in_attrs)[1] == mshadow::kInt64 ||
- (*in_attrs)[1] == mshadow::kInt32)
- << "'ind' only support int dtype.";
- TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
- return (*out_attrs)[0] != -1;
-}
-
template<typename xpu, typename DType>
-void IndexAddForwardCalc(mshadow::Stream<xpu> *s,
- const int ind_num, DType* out,
- const DType* val,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_shape,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_stride,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_shape,
- const int a_tail_size,
- const int ind_ndim, const int* ind,
- const int a_ndim);
+void IndexUpdateForwardCalc(mshadow::Stream<xpu> *s,
+ const int ind_num, DType* out,
+ const DType* val,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_shape,
+ const int a_tail_size,
+ const int ind_ndim, const int* ind,
+ const int a_ndim);
template<typename xpu>
-void IndexAddOpForward(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs) {
+void IndexUpdateOpForward(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(inputs.size(), 3U);
@@ -88,7 +64,7 @@ void IndexAddOpForward(const nnvm::NodeAttrs& attrs,
TBlob ind = inputs[1];
TBlob val = inputs[2];
TBlob out = outputs[0];
- CHECK_GT(a.shape_.ndim(), 0) << "The first input is scalar, please use '+' instead.";
+ CHECK_GT(a.shape_.ndim(), 0) << "The first input is scalar, please use '=' instead.";
int a_ndim = a.shape_.ndim();
CHECK_LE(a_ndim, MXNET_SPECIAL_MAX_NDIM)
<< "ndim should less than "<< MXNET_SPECIAL_MAX_NDIM
@@ -152,33 +128,33 @@ void IndexAddOpForward(const nnvm::NodeAttrs& attrs,
(Shape1(ind.shape_.Size()), s));
mxnet_op::copy(s, t_ind, ind);
MSHADOW_TYPE_SWITCH(a.type_flag_, DType, {
- IndexAddForwardCalc<xpu, DType>(s, ind_num,
- out.dptr<DType>(), val.dptr<DType>(),
- a_tail_shape, a_pre_stride,
- val_stride, val_shape, a_shape,
- a_tail_size, ind_ndim,
- t_ind.dptr<int>(), a_ndim);
+ IndexUpdateForwardCalc<xpu, DType>(s, ind_num,
+ out.dptr<DType>(), val.dptr<DType>(),
+ a_tail_shape, a_pre_stride,
+ val_stride, val_shape, a_shape,
+ a_tail_size, ind_ndim,
+ t_ind.dptr<int>(), a_ndim);
});
}
template<typename xpu>
-void IndexAddOpBackwardValImpl(const OpContext& ctx,
- const TBlob& grad_val,
- const TBlob& ograd,
- const TBlob& t_ind,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
- const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
- const int tail_size, const int ind_num, const int ind_ndim,
- const int ndim);
+void IndexUpdateOpBackwardValImpl(const OpContext& ctx,
+ const TBlob& grad_val,
+ const TBlob& ograd,
+ const TBlob& t_ind,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+ const int tail_size, const int ind_num, const int ind_ndim,
+ const int ndim);
template<typename xpu>
-inline void IndexAddOpBackwardVal(const nnvm::NodeAttrs& attrs,
- const OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs) {
+inline void IndexUpdateOpBackwardVal(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
if (req[0] == kNullOp) {
@@ -221,11 +197,68 @@ inline void IndexAddOpBackwardVal(const nnvm::NodeAttrs& attrs,
}
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride = mxnet_op::calc_stride(ograd_pre_shape);
mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride = mxnet_op::calc_stride(val_shape);
- IndexAddOpBackwardValImpl<xpu>(ctx, grad_val, ograd, t_ind, ograd_tail_shape, ograd_pre_stride,
- val_stride, val_shape, tail_size, ind_num, ind_ndim, out_ndim);
+ IndexUpdateOpBackwardValImpl<xpu>(ctx, grad_val, ograd, t_ind, ograd_tail_shape, ograd_pre_stride,
+ val_stride, val_shape, tail_size, ind_num, ind_ndim, out_ndim);
+}
+
+template<typename DType>
+struct ReqCopy {
+ MSHADOW_XINLINE static void Map(size_t i, DType* dest, const DType* src, const int req) {
+ KERNEL_ASSIGN(dest[i], req, src[i]);
+ }
+};
+
+template<typename xpu>
+void IndexUpdateOpBackwardAImpl(const OpContext& ctx,
+ const TBlob& grad_a,
+ const TBlob& ograd,
+ const TBlob& t_ind,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride,
+ const int tail_size, const int ind_num, const int ind_ndim,
+ const int seg, const int req);
+
+template<typename xpu>
+inline void IndexUpdateOpBackwardA(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mshadow;
+ using namespace mxnet_op;
+ if (req[0] == kNullOp) {
+ return;
+ }
+ CHECK_EQ(inputs.size(), 2U);
+ CHECK_EQ(outputs.size(), 1U);
+ TBlob ograd = inputs[0];
+ TBlob ind = inputs[1];
+ const TBlob& grad_a = outputs[0];
+ mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+ // get the number of 'ind' index
+ if (ind.shape_.ndim() == 0) {
+ ind.shape_ = Shape2(1, 1);
+ } else if (ind.shape_.ndim() == 1) {
+ ind.shape_ = Shape2(1, ind.shape_[0]);
+ }
+ int ind_ndim = ind.shape_[0];
+ int ind_num = ind.shape_[1];
+ int out_ndim = ograd.shape_.ndim();
+ int tail_size = static_cast<int>(ograd.shape_.ProdShape(ind_ndim, out_ndim));
+ mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_shape;
+ for (int i = MXNET_SPECIAL_MAX_NDIM - 1, j = out_ndim - 1; i >= 0; --i, --j) {
+ grada_shape[i] = (j >= 0) ? grad_a.shape_[j] : 1;
+ }
+ mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_shape(grada_shape);
+ int seg = MXNET_SPECIAL_MAX_NDIM - out_ndim;
+ for (int i = seg + ind_ndim; i < seg + out_ndim; ++i) {
+ grada_pre_shape[i] = 1;
+ }
+ mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride = mxnet_op::calc_stride(grada_pre_shape);
+ IndexUpdateOpBackwardAImpl<xpu>(ctx, grad_a, ograd, ind, grada_pre_stride,
+ tail_size, ind_num, ind_ndim, seg, req[0]);
}
} // namespace op
} // namespace mxnet
-#endif // MXNET_OPERATOR_TENSOR_INDEX_ADD_INL_H_
+#endif // MXNET_OPERATOR_TENSOR_INDEX_UPDATE_INL_H_
diff --git a/src/operator/tensor/index_update.cc b/src/operator/tensor/index_update.cc
new file mode 100644
index 0000000..ffd5d29
--- /dev/null
+++ b/src/operator/tensor/index_update.cc
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file index_update.cc
+ * \brief implementation of index_update operator
+*/
+#include <vector>
+#include "./index_update-inl.h"
+
+namespace mxnet {
+namespace op {
+
+template<typename DType>
+struct IndexUpdateForwardCPUKernel {
+ MSHADOW_XINLINE static void Map(size_t i, DType* out,
+ const DType* val,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_shape,
+ const int a_tail_size, const int ind_num,
+ const int ind_ndim, const int* ind,
+ const int a_ndim, const int seg) {
+ index_t id = 0;
+ for (int dim = 0; dim < ind_ndim; ++dim) {
+ CHECK_LT(ind[dim * ind_num + i], a_shape[seg + dim])
+ << "IndexError: index " << ind[dim * ind_num + i]
+ << " is out of bounds for axis " << dim
+ << " with size " << a_shape[seg + dim];
+ CHECK_GE(ind[dim * ind_num + i], 0)
+ << "IndexError: index " << ind[dim * ind_num + i]
+ << " should be greater or equal to 0.";
+ id += a_pre_stride[seg + dim] * ind[dim * ind_num + i];
+ }
+ id *= a_tail_size;
+ for (int _i = 0; _i < a_tail_size; ++_i) {
+ mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_id = mxnet_op::unravel(_i, a_tail_shape);
+ mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
+ for (int _j = 0; _j < seg; ++_j) {
+ val_id[_j] = 0;
+ }
+ for (int _j = seg; _j < seg + a_ndim; ++_j) {
+ val_id[_j] = (val_shape[_j] == 1) ? 0 : a_tail_id[_j];
+ }
+ val_id[seg + ind_ndim - 1] = (val_shape[seg + ind_ndim - 1] == 1) ? 0 : i;
+ index_t val_dest = mxnet_op::dot(val_id, val_stride);
+ out[id + _i] = val[val_dest];
+ }
+ }
+};
+
+template<typename xpu, typename DType>
+void IndexUpdateForwardCalc(mshadow::Stream<xpu> *s,
+ const int ind_num, DType* out,
+ const DType* val,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_shape,
+ const int a_tail_size,
+ const int ind_ndim, const int* ind,
+ const int a_ndim) {
+ using namespace mxnet_op;
+ using namespace mshadow;
+ int seg = MXNET_SPECIAL_MAX_NDIM - a_ndim;
+ Kernel<IndexUpdateForwardCPUKernel<DType>, xpu>::Launch(
+ s, ind_num, out, val, a_tail_shape, a_pre_stride,
+ val_stride, val_shape, a_shape, a_tail_size, ind_num,
+ ind_ndim, ind, a_ndim, seg);
+}
+
+
+template<typename DType>
+void IndexUpdateBackwardValCPUCompute(DType* grad_val,
+ const DType* ograd,
+ const int* ind_vec,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+ const int ograd_tail_size, const int ind_num,
+ const int ind_ndim, const int out_ndim,
+ const int seg) {
+ #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+ for (index_t i = 0; i < static_cast<index_t>(ind_num); ++i) {
+ index_t id = 0;
+ for (int dim = 0; dim < ind_ndim; ++dim) {
+ id += ograd_pre_stride[seg + dim] * ind_vec[dim * ind_num + i];
+ }
+ id *= ograd_tail_size;
+ for (int _i = 0; _i < ograd_tail_size; ++_i) {
+ mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_id =
+ mxnet_op::unravel(_i, ograd_tail_shape);
+ mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
+ for (int _j = 0; _j < seg; ++_j) {
+ val_id[_j] = 0;
+ }
+ for (int _j = seg; _j < seg + out_ndim; ++_j) {
+ val_id[_j] = (val_shape[_j] == 1) ? 0 : ograd_tail_id[_j];
+ }
+ val_id[seg + ind_ndim - 1] = (val_shape[seg + ind_ndim - 1] == 1) ? 0 : i;
+ index_t val_dest = mxnet_op::dot(val_id, val_stride);
+ #pragma omp critical
+ {
+ grad_val[val_dest] += ograd[id + _i];
+ }
+ }
+ }
+}
+
+template<>
+void IndexUpdateOpBackwardValImpl<cpu>(const OpContext& ctx,
+ const TBlob& grad_val,
+ const TBlob& ograd,
+ const TBlob& t_ind,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+ const int tail_size, const int ind_num, const int ind_ndim,
+ const int ndim) {
+ using namespace mshadow;
+ using namespace mxnet_op;
+ int seg = MXNET_SPECIAL_MAX_NDIM - ndim;
+ MSHADOW_TYPE_SWITCH(grad_val.type_flag_, DType, {
+ IndexUpdateBackwardValCPUCompute<DType>(
+ grad_val.dptr<DType>(), ograd.dptr<DType>(), t_ind.dptr<int>(),
+ ograd_tail_shape, ograd_pre_stride, val_stride, val_shape, tail_size,
+ ind_num, ind_ndim, ndim, seg);
+ });
+}
+
+template<typename DType>
+void IndexUpdateBackwardACPUCompute(DType* out_grad,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride,
+ const int tail_size, const int ind_num, const int ind_ndim,
+ const int32_t* ind, const int seg,
+ const int req) {
+ #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+ for (index_t i = 0; i < static_cast<index_t>(ind_num); ++i) {
+ index_t id = 0;
+ for (int dim = 0; dim < ind_ndim; ++dim) {
+ id += grada_pre_stride[seg + dim] * ind[dim * ind_num + i];
+ }
+ id *= tail_size;
+ for (int _i = 0; _i < tail_size; ++_i) {
+ out_grad[id + _i] = 0;
+ }
+ }
+}
+
+template<>
+void IndexUpdateOpBackwardAImpl<cpu>(const OpContext& ctx,
+ const TBlob& grad_a,
+ const TBlob& ograd,
+ const TBlob& ind,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride,
+ const int tail_size, const int ind_num, const int ind_ndim,
+ const int seg, const int req) {
+ using namespace mxnet_op;
+ using namespace mshadow;
+ mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+ MSHADOW_TYPE_SWITCH(grad_a.type_flag_, DType, {
+ size_t temp_mem_size = ind.shape_.Size() * sizeof(int) + ograd.shape_.Size() * sizeof(DType);
+ Tensor<cpu, 1, char> temp_mem =
+ ctx.requested[0].get_space_typed<cpu, 1, char>(Shape1(temp_mem_size), s);
+ TBlob t_ograd = TBlob(temp_mem.dptr_, ograd.shape_, ograd.dev_mask(),
+ ograd.type_flag_, ograd.dev_id());
+ TBlob t_ind = TBlob(temp_mem.dptr_ + ograd.Size() * sizeof(DType), ind.shape_, ind.dev_mask(),
+ mshadow::kInt32, ind.dev_id());
+ mxnet_op::copy(s, t_ograd, ograd);
+ mxnet_op::copy(s, t_ind, ind);
+ IndexUpdateBackwardACPUCompute<DType>(t_ograd.dptr<DType>(),
+ grada_pre_stride, tail_size,
+ ind_num, ind_ndim,
+ t_ind.dptr<int32_t>(), seg, req);
+ Kernel<ReqCopy<DType>, cpu>::Launch(s, grad_a.shape_.Size(), grad_a.dptr<DType>(),
+ t_ograd.dptr<DType>(), req);
+ });
+}
+
+NNVM_REGISTER_OP(_npx_index_update)
+.describe(R"code(This operators implements the "=" mimic function.
+)code" ADD_FILELINE)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector<std::string>{"a", "ind", "val"};
+ })
+.set_attr<mxnet::FInferShape>("FInferShape", IndexModifyOpShape)
+.set_attr<nnvm::FInferType>("FInferType", IndexModifyOpType)
+.set_attr<FCompute>("FCompute<cpu>", IndexUpdateOpForward<cpu>)
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
+.set_attr<nnvm::FGradient>("FGradient",
+ [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+ auto a_grad = MakeNode("_backward_index_update_a", n->attrs.name + "_backward_a",
+ {ograds[0], n->inputs[1]}, nullptr, &n);
+ auto idx_grad = MakeNode("zeros_like", n->attrs.name + "_backward_indices",
+ {n->inputs[1]}, nullptr, &n);
+ auto val_grad = MakeNode("_backward_index_update_val", n->attrs.name + "_backward_val",
+ {ograds[0], n->inputs[1]}, nullptr, &n);
+ // auto val_grad = MakeNode("zeros_like", n->attrs.name + "_backward_val",
+ // {n->inputs[2]}, nullptr, &n);
+ std::vector<nnvm::NodeEntry> ret;
+ ret.emplace_back(a_grad);
+ ret.emplace_back(idx_grad);
+ ret.emplace_back(val_grad);
+ return ret;
+ })
+.add_argument("a", "NDArray-or-Symbol", "Input ndarray")
+.add_argument("ind", "NDArray-or-Symbol", "Index ndarray")
+.add_argument("val", "NDArray-or-Symbol", "Update ndarray");
+
+
+NNVM_REGISTER_OP(_backward_index_update_a)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
+.set_attr<FCompute>("FCompute<cpu>", IndexUpdateOpBackwardA<cpu>);
+
+
+NNVM_REGISTER_OP(_backward_index_update_val)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+ })
+.set_attr<FCompute>("FCompute<cpu>", IndexUpdateOpBackwardVal<cpu>);
+
+} // namespace op
+} // namespace mxnet
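The gradient wiring registered above has a simple closed form: the output overwrites the indexed slots of 'a', so the gradient w.r.t. 'a' is ograd with those slots zeroed (what _backward_index_update_a computes), while the gradient w.r.t. 'val' gathers ograd at those slots (what _backward_index_update_val computes, accumulating where 'val' was broadcast). A hedged NumPy sketch, assuming no repeated indices and no broadcasting of 'val'; 'index_update_backward_ref' is a hypothetical name, not the committed kernels:

    import numpy as np

    def index_update_backward_ref(ograd, ind):
        # grad_a: ograd with every updated position zeroed out, since those
        # elements of `a` never reach the output.
        # grad_val: the slices of ograd at the updated positions.
        ind = np.atleast_2d(ind)
        grad_a = ograd.copy()
        grad_val = []
        for k in range(ind.shape[1]):
            pos = tuple(ind[:, k].tolist())
            grad_val.append(grad_a[pos].copy())  # gather before zeroing
            grad_a[pos] = 0
        return grad_a, np.stack(grad_val)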
diff --git a/src/operator/tensor/index_update.cu b/src/operator/tensor/index_update.cu
new file mode 100644
index 0000000..3ed788f
--- /dev/null
+++ b/src/operator/tensor/index_update.cu
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file index_update.cu
+ * \brief GPU implementation of index_update operator
+ */
+
+#include <cub/cub.cuh>
+#include "./index_update-inl.h"
+#include "../tensor/util/tensor_util-inl.cuh"
+#include "../tensor/util/tensor_util-inl.h"
+
+
+namespace mxnet {
+namespace op {
+
+template<typename DType>
+struct IndexUpdateForwardGPUKernel {
+ MSHADOW_XINLINE static void Map(size_t i, DType* out,
+ const DType* val,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+ const int a_tail_size, const int ind_num,
+ const int ind_ndim, const int* ind,
+ const int a_ndim, const int seg) {
+ index_t id = 0;
+ for (int dim = 0; dim < ind_ndim; ++dim) {
+ id += a_pre_stride[seg + dim] * ind[dim * ind_num + i];
+ }
+ id *= a_tail_size;
+ for (int _i = 0; _i < a_tail_size; ++_i) {
+ mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_id = mxnet_op::unravel(_i, a_tail_shape);
+ mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
+ for (int _j = 0; _j < seg; ++_j) {
+ val_id[_j] = 0;
+ }
+ for (int _j = seg; _j < seg + a_ndim; ++_j) {
+ val_id[_j] = (val_shape[_j] == 1) ? 0 : a_tail_id[_j];
+ }
+ val_id[seg + ind_ndim - 1] = (val_shape[seg + ind_ndim - 1] == 1) ? 0 : i;
+ index_t val_dest = mxnet_op::dot(val_id, val_stride);
+ out[id + _i] = val[val_dest];
+ }
+ }
+};
+
+template<typename xpu, typename DType>
+void IndexUpdateForwardCalc(mshadow::Stream<xpu> *s,
+ const int ind_num, DType* out,
+ const DType* val,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_shape,
+ const int a_tail_size,
+ const int ind_ndim, const int* ind,
+ const int a_ndim) {
+ using namespace mxnet_op;
+ using namespace mshadow;
+ int seg = MXNET_SPECIAL_MAX_NDIM - a_ndim;
+ Kernel<IndexUpdateForwardGPUKernel<DType>, xpu>::Launch(
+ s, ind_num, out, val, a_tail_shape, a_pre_stride,
+ val_stride, val_shape, a_tail_size, ind_num,
+ ind_ndim, ind, a_ndim, seg);
+}
+
+
+struct IndexUpdateBackwardValGPUKernel {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(size_t i, DType* grad_val,
+ const DType* ograd, const int* ind_vec,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+ const int ograd_tail_size, const int ind_num,
+ const int ind_ndim, const int out_ndim, const int seg) {
+ index_t id = 0;
+ for (int dim = 0; dim < ind_ndim; ++dim) {
+ id += ograd_pre_stride[seg + dim] * ind_vec[dim * ind_num + i];
+ }
+ id *= ograd_tail_size;
+ for (int _i = 0; _i < ograd_tail_size; ++_i) {
+ mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_id =
+ mxnet_op::unravel(_i, ograd_tail_shape);
+ mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
+ for (int _j = 0; _j < seg; ++_j) {
+ val_id[_j] = 0;
+ }
+ for (int _j = seg; _j < seg + out_ndim; ++_j) {
+ val_id[_j] = (val_shape[_j] == 1) ? 0 : ograd_tail_id[_j];
+ }
+ val_id[seg + ind_ndim - 1] = (val_shape[seg + ind_ndim - 1] == 1) ? 0 : i;
+ index_t val_dest = mxnet_op::dot(val_id, val_stride);
+ atomicAdd(&grad_val[val_dest], ograd[id + _i]);
+ }
+ }
+};
+
+template<>
+void IndexUpdateOpBackwardValImpl<gpu>(const OpContext& ctx,
+ const TBlob& grad_val,
+ const TBlob& ograd,
+ const TBlob& t_ind,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+ const int tail_size, const int ind_num, const int ind_ndim,
+ const int ndim) {
+ using namespace mshadow;
+ using namespace mxnet_op;
+ mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+ int seg = MXNET_SPECIAL_MAX_NDIM - ndim;
+ MSHADOW_TYPE_SWITCH(grad_val.type_flag_, DType, {
+ Kernel<IndexUpdateBackwardValGPUKernel, gpu>::Launch(
+ s, ind_num, grad_val.dptr<DType>(), ograd.dptr<DType>(), t_ind.dptr<int>(),
+ ograd_tail_shape, ograd_pre_stride, val_stride, val_shape, tail_size,
+ ind_num, ind_ndim, ndim, seg);
+ });
+}
+
+template<typename DType>
+struct IndexUpdateBackwardAGPUKernel {
+ MSHADOW_XINLINE static void Map(size_t i, DType* out_grad,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride,
+ const int tail_size, const int ind_num, const int ind_ndim,
+ const int32_t* ind, const int seg,
+ const int req) {
+ index_t id = 0;
+ for (int dim = 0; dim < ind_ndim; ++dim) {
+ id += grada_pre_stride[seg + dim] * ind[dim * ind_num + i];
+ }
+ id *= tail_size;
+ for (int _i = 0; _i < tail_size; ++_i) {
+ out_grad[id + _i] = 0;
+ }
+ }
+};
+
+template<>
+void IndexUpdateOpBackwardAImpl<gpu>(const OpContext& ctx,
+ const TBlob& grad_a,
+ const TBlob& ograd,
+ const TBlob& ind,
+ const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride,
+ const int tail_size, const int ind_num, const int ind_ndim,
+ const int seg, const int req) {
+ using namespace mxnet_op;
+ using namespace mshadow;
+ mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+ MSHADOW_TYPE_SWITCH(grad_a.type_flag_, DType, {
+ size_t alignment = std::max(sizeof(DType), sizeof(int32_t));
+ size_t id_size = PadBytes(sizeof(int32_t) * ind.Size(), alignment);
+ size_t ograd_size = PadBytes(sizeof(DType) * ograd.Size(), alignment);
+ size_t temp_mem_size = id_size + ograd_size;
+ Tensor<gpu, 1, char> temp_mem =
+ ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_mem_size), s);
+ TBlob t_ograd = TBlob(temp_mem.dptr_, ograd.shape_, ograd.dev_mask(),
+ ograd.type_flag_, ograd.dev_id());
+ TBlob t_ind = TBlob(temp_mem.dptr_ + ograd_size, ind.shape_, ind.dev_mask(),
+ mshadow::kInt32, ind.dev_id());
+ mxnet_op::copy(s, t_ograd, ograd);
+ mxnet_op::copy(s, t_ind, ind);
+ Kernel<IndexUpdateBackwardAGPUKernel<DType>, gpu>::Launch(s, ind_num, t_ograd.dptr<DType>(),
+ grada_pre_stride, tail_size,
+ ind_num, ind_ndim,
+ t_ind.dptr<int32_t>(), seg, req);
+ Kernel<ReqCopy<DType>, gpu>::Launch(s, grad_a.shape_.Size(), grad_a.dptr<DType>(),
+ t_ograd.dptr<DType>(), req);
+ });
+}
+
+NNVM_REGISTER_OP(_npx_index_update)
+.set_attr<FCompute>("FCompute<gpu>", IndexUpdateOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_index_update_val)
+.set_attr<FCompute>("FCompute<gpu>", IndexUpdateOpBackwardVal<gpu>);
+
+NNVM_REGISTER_OP(_backward_index_update_a)
+.set_attr<FCompute>("FCompute<gpu>", IndexUpdateOpBackwardA<gpu>);
+
+} // namespace op
+} // namespace mxnet
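Worth noting: the GPU backward-A path above packs two sub-buffers (the ograd copy and the int32 index copy) into a single kTempSpace request, padding each to a common alignment; this is the "gpu temporary space request alignment" item from the commit message. A sketch of the layout arithmetic, with 'pad_bytes' as a hypothetical stand-in for the PadBytes helper:

    def pad_bytes(nbytes, alignment):
        # Hypothetical stand-in for PadBytes: round nbytes up to a
        # multiple of alignment.
        return ((nbytes + alignment - 1) // alignment) * alignment

    # Example: float64 ograd with 10 elements, int32 ind with 3 elements.
    alignment = max(8, 4)                      # max(sizeof(DType), sizeof(int32_t))
    ograd_size = pad_bytes(8 * 10, alignment)  # 80 bytes, already aligned
    id_size = pad_bytes(4 * 3, alignment)      # 12 -> 16 bytes
    temp_mem_size = id_size + ograd_size       # one request covers both buffers
    # t_ograd occupies bytes [0, ograd_size); t_ind starts at offset
    # ograd_size, which is correctly aligned for int32 as well.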
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 3ea6119..6440dec 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1479,6 +1479,168 @@ def test_npx_index_add():
@with_seed()
@use_np
+def test_npx_index_update():
+ class TestIndexUpdate(HybridBlock):
+ def __init__(self):
+ super(TestIndexUpdate, self).__init__()
+
+ def hybrid_forward(self, F, a, ind, val):
+ return F.npx.index_update(a, ind, val)
+
+ def check_index_update_forward(mx_ret, a, ind, val, ind_ndim, ind_num, eps):
+ if val.dtype != a.dtype:
+ val = val.astype(a.dtype)
+ ind_arr = ind.transpose()
+ if ind_arr.ndim == 0:
+ ind_arr = _np.array([ind_arr])
+ for i in range(ind_arr.shape[0]):
+ t_ind = ind_arr[i]
+ t_ind = tuple(t_ind.tolist()) if type(t_ind) is _np.ndarray else t_ind.tolist()
+ if val.ndim + ind_ndim > a.ndim:
+ t_val = val[tuple([0 if val.shape[0]==1 else i])]
+ if type(t_val) is _np.ndarray and t_val.shape[0] == 1:
+ expect_tmp = _np.squeeze(t_val, axis=0)
+ else:
+ expect_tmp = t_val
+ else:
+ expect_tmp = val
+ mx_tmp = mx_ret[t_ind]
+ close_pos = _np.where(_np.isclose(expect_tmp, mx_tmp, rtol=eps, atol=eps))
+ if a[t_ind].ndim == 0:
+ if close_pos[0].size == 1:
+ mx_ret[t_ind] = 0
+ a[t_ind] = 0
+ else:
+ mx_ret[t_ind][close_pos] = 0
+ a[t_ind][close_pos] = 0
+ assert_almost_equal(mx_ret, a, rtol=eps, atol=eps)
+
+ def index_update_bwd(out_grad, a_grad, ind, val_grad, ind_ndim, ind_num, grad_req_a, grad_req_val):
+ if grad_req_a == 'add':
+ init_a_grad = _np.array(a_grad)
+ if grad_req_val == 'add':
+ init_val_grad = _np.array(val_grad)
+ a_grad = _np.zeros(a_grad.shape) + out_grad
+ a_grad = a_grad.astype(a_grad.dtype)
+ val_grad = _np.zeros(val_grad.shape).astype(val_grad.dtype)
+
+ ind_arr = ind.transpose()
+ if ind_arr.ndim == 0:
+ ind_arr = _np.array([ind_arr])
+ for i in range(ind_arr.shape[0]):
+ t_ind = ind_arr[i]
+ t_ind = tuple(ind_arr[i].tolist()) if type(ind_arr[i]) is _np.ndarray else ind_arr[i].tolist()
+ a_grad[t_ind] = 0
+ if val_grad.ndim + ind_ndim > a_grad.ndim:
+ idx = 0 if val_grad.shape[0]==1 else i
+ t_grad = out_grad[t_ind]
+ t_grad_shape = _np.array(t_grad.shape)
+ val_grad_shape = _np.array(val_grad[idx].shape)
+ if type(val_grad[idx]) is not _np.ndarray:
+ t_grad = _np.sum(t_grad)
+ else:
+ is_not_equal = t_grad_shape - val_grad_shape
+ if _np.any(is_not_equal):
+ broadcast_dim = _np.nonzero(_np.where(is_not_equal, 1, 0))
+ t_grad = _np.sum(t_grad, axis=tuple(broadcast_dim[0].reshape(1, -1)[0]), keepdims=True)
+ val_grad[idx] += t_grad
+ else:
+ t_grad = out_grad[t_ind]
+ if type(val_grad) is not _np.ndarray or val_grad.shape == ():
+ t_grad = _np.sum(t_grad)
+ else:
+ if type(t_grad) is _np.ndarray:
+ ext_dim = t_grad.ndim - val_grad.ndim
+ if ext_dim:
+ t_grad = _np.sum(t_grad, axis=tuple(_np.arange(ext_dim)))
+ t_grad_shape = _np.array(t_grad.shape)
+ val_grad_shape = _np.array(val_grad.shape)
+ is_not_equal = t_grad_shape - val_grad_shape
+ if _np.any(is_not_equal):
+ broadcast_dim = _np.nonzero(_np.where(is_not_equal, 1, 0))
+ t_grad = _np.sum(t_grad, axis=tuple(broadcast_dim[0].reshape(1, -1)[0]), keepdims=True)
+ val_grad += t_grad
+ if grad_req_a == 'add':
+ a_grad += init_a_grad
+ if grad_req_val == 'add':
+ val_grad += init_val_grad
+ return a_grad, val_grad
+
+ # a.shape, ind.shape, val.shape, ind_ndim, ind_num
+ configs = [((2, ), np.array(1, dtype=_np.int32), (1, ), 1, 1)]
+ shape = tuple(_np.random.randint(1, 6, size=(4))) # a.shape
+ for ind_ndim in range(1, 5): # ind.shape: (ind_ndim, ind_num)
+ ind_num = _np.random.randint(1, 7)
+ ind = []
+ for ind_dim in range(ind_ndim):
+ ind.append(_np.random.randint(0, shape[ind_dim], size=(ind_num)))
+ ind = _np.array(ind).astype(_np.int32)
+ # case: val is scalar
+ configs.append(tuple([shape, ind, (), ind_ndim, ind_num]))
+ for val_ndim in range(1, 5 - ind_ndim):
+ val_shape = [1 if _np.random.randint(0, 5)==0 else ind_num]
+ for val_dim in range(ind_ndim, 4):
+ val_shape.append(1 if _np.random.randint(0, 5)==0 else shape[val_dim])
+ # case: val is tensor
+ configs.append(tuple([shape, ind, tuple(val_shape), ind_ndim, ind_num]))
+
+ dtypes = ['float32', 'float64', 'int32', 'int64']
+ grad_req = ['write', 'null', 'add']
+ for hybridize, grad_req_a, grad_req_val, dtype, indtype in \
+ itertools.product([True, False], grad_req, grad_req, dtypes, ['int32', 'int64']):
+ for a_shape, ind, val_shape, ind_ndim, ind_num in configs:
+ eps = 1e-3
+ atype = dtype
+ valtype = dtype
+ test_index_update = TestIndexUpdate()
+ if hybridize:
+ test_index_update.hybridize()
+ a = mx.nd.random.uniform(-10.0, 10.0, shape=a_shape).as_np_ndarray().astype(atype)
+ a.attach_grad(grad_req=grad_req_a)
+ val = mx.nd.random.uniform(-10.0, 10.0, shape=val_shape).as_np_ndarray().astype(valtype)
+ val.attach_grad(grad_req=grad_req_val)
+ with mx.autograd.record():
+ mx_ret = test_index_update(a, np.array(ind).astype(indtype), val)
+ assert mx_ret.shape == a.shape
+ assert mx_ret.dtype == a.dtype
+ check_index_update_forward(mx_ret.asnumpy(), a.asnumpy(), ind.astype(indtype), val.asnumpy(), ind_ndim, ind_num, eps)
+
+ if atype not in ['float16', 'float32', 'float64'] or valtype not in ['float16', 'float32', 'float64']:
+ continue
+ if grad_req_a != 'null' or grad_req_val != 'null':
+ init_a_grad = mx.nd.random.uniform(-10.0, 10.0, shape=a_shape).as_np_ndarray().astype(atype)
+ init_val_grad = mx.nd.random.uniform(-10.0, 10.0, shape=val_shape).as_np_ndarray().astype(valtype)
+ out_grad = mx.nd.random.uniform(-10.0, 10.0, shape=a_shape).as_np_ndarray().astype(atype)
+ if grad_req_a == 'add':
+ if init_a_grad.ndim == 0:
+ a.grad[()] = init_a_grad.item()
+ else:
+ a.grad[:] = init_a_grad
+ if grad_req_val == 'add':
+ if init_val_grad.ndim == 0:
+ val.grad[()] = init_val_grad.item()
+ else:
+ val.grad[:] = init_val_grad
+ mx_ret.backward(out_grad)
+ expected_bwd_a, expected_bwd_val = index_update_bwd(out_grad.asnumpy(), init_a_grad.asnumpy(), ind,
+ init_val_grad.asnumpy(), ind_ndim, ind_num,
+ grad_req_a, grad_req_val)
+
+ if grad_req_a == 'null':
+ assert a.grad is None
+ else:
+ assert_almost_equal(a.grad.asnumpy(), expected_bwd_a, rtol=eps, atol=eps)
+ if grad_req_val == 'null':
+ assert val.grad is None
+ else:
+ assert_almost_equal(val.grad.asnumpy(), expected_bwd_val, rtol=eps, atol=eps)
+
+ mx_out = npx.index_update(a, np.array(ind).astype(indtype), val)
+ check_index_update_forward(mx_out.asnumpy(), a.asnumpy(), ind.astype(indtype), val.asnumpy(), ind_ndim, ind_num, eps)
+
+
+@with_seed()
+@use_np
def test_npx_batch_dot():
ctx = mx.context.current_context()
dtypes = ['float32', 'float64']