leezu commented on a change in pull request #18545:
URL: https://github.com/apache/incubator-mxnet/pull/18545#discussion_r440470722
##########
File path: tests/python/unittest/test_numpy_op.py
##########
@@ -1477,6 +1477,165 @@ def index_add_bwd(out_grad, a_grad, ind, val_grad, ind_ndim, ind_num, grad_req_a
             assert_almost_equal(mx_out.asnumpy(), expected_ret, rtol=eps, atol=eps)
+@with_seed()
+@use_np
+def test_npx_index_update():
+    class TestIndexUpdate(HybridBlock):
+        def __init__(self):
+            super(TestIndexUpdate, self).__init__()
+
+        def hybrid_forward(self, F, a, ind, val):
+            return F.npx.index_update(a, ind, val)
+
+    def check_index_update_forward(mx_ret, a, ind, val, ind_ndim, ind_num, eps):
+        if val.dtype != a.dtype:
+            val = val.astype(a.dtype)
+        ind_arr = ind.transpose()
+        if ind_arr.ndim == 0:
+            ind_arr = _np.array([ind_arr])
+        for i in range(ind_arr.shape[0]):
+            t_ind = ind_arr[i]
+            t_ind = tuple(t_ind.tolist()) if type(t_ind) is _np.ndarray else t_ind.tolist()
+            if val.ndim + ind_ndim > a.ndim:
+                t_val = val[tuple([0 if val.shape[0]==1 else i])]
+                if type(t_val) is _np.ndarray and t_val.shape[0] == 1:
+                    expect_tmp = _np.squeeze(t_val, axis=0)
+                else:
+                    expect_tmp = t_val
+            else:
+                expect_tmp = val
+            mx_tmp = mx_ret[t_ind]
+            if _np.allclose(expect_tmp, mx_tmp, rtol=eps, atol=eps):
+                mx_ret[t_ind] = 0
+                a[t_ind] = 0
+        assert_almost_equal(mx_ret, a, rtol=eps, atol=eps)
+
+    def index_update_bwd(out_grad, a_grad, ind, val_grad, ind_ndim, ind_num,
+                         grad_req_a, grad_req_val):
+        if grad_req_a == 'add':
+            init_a_grad = _np.array(a_grad)
+        if grad_req_val == 'add':
+            init_val_grad = _np.array(val_grad)
+        a_grad = _np.zeros(a_grad.shape) + out_grad
+        a_grad = a_grad.astype(a_grad.dtype)
+        val_grad = _np.zeros(val_grad.shape).astype(val_grad.dtype)
+
+        ind_arr = ind.transpose()
+        if ind_arr.ndim == 0:
+            ind_arr = _np.array([ind_arr])
+        for i in range(ind_arr.shape[0]):
+            t_ind = ind_arr[i]
+            t_ind = tuple(ind_arr[i].tolist()) if type(ind_arr[i]) is _np.ndarray else ind_arr[i].tolist()
+            a_grad[t_ind] = 0
+            if val_grad.ndim + ind_ndim > a_grad.ndim:
+                idx = 0 if val_grad.shape[0]==1 else i
+                t_grad = out_grad[t_ind]
+                t_grad_shape = _np.array(t_grad.shape)
+                val_grad_shape = _np.array(val_grad[idx].shape)
+                if type(val_grad[idx]) is not _np.ndarray:
+                    t_grad = _np.sum(t_grad)
+                else:
+                    is_not_equal = t_grad_shape - val_grad_shape
+                    if _np.any(is_not_equal):
+                        broadcast_dim = _np.nonzero(_np.where(is_not_equal, 1, 0))
+                        t_grad = _np.sum(t_grad, axis=tuple(broadcast_dim[0].reshape(1, -1)[0]), keepdims=True)
+                val_grad[idx] += t_grad
+            else:
+                t_grad = out_grad[t_ind]
+                if type(val_grad) is not _np.ndarray or val_grad.shape == ():
+                    t_grad = _np.sum(t_grad)
+                else:
+                    if type(t_grad) is _np.ndarray:
+                        ext_dim = t_grad.ndim - val_grad.ndim
+                        if ext_dim:
+                            t_grad = _np.sum(t_grad, axis=tuple(_np.arange(ext_dim)))
+                        t_grad_shape = _np.array(t_grad.shape)
+                        val_grad_shape = _np.array(val_grad.shape)
+                        is_not_equal = t_grad_shape - val_grad_shape
+                        if _np.any(is_not_equal):
+                            broadcast_dim = _np.nonzero(_np.where(is_not_equal, 1, 0))
+                            t_grad = _np.sum(t_grad, axis=tuple(broadcast_dim[0].reshape(1, -1)[0]), keepdims=True)
+                val_grad += t_grad
+        if grad_req_a == 'add':
+            a_grad += init_a_grad
+        if grad_req_val == 'add':
+            val_grad += init_val_grad
+        return a_grad, val_grad
+
+    # a.shape, ind.shape, val.shape, ind_ndim, ind_num
+    configs = [((2, ), np.array(1, dtype=_np.int32), (1, ), 1, 1)]
+    shape = tuple(_np.random.randint(1, 6, size=(4)))  # a.shape
+    for ind_ndim in range(1, 5):  # ind.shape: (ind_ndim, ind_num)
+        ind_num = _np.random.randint(1, 7)
+        ind = []
+        for ind_dim in range(ind_ndim):
+            ind.append(_np.random.randint(0, shape[ind_dim], size=(ind_num)))
+        ind = _np.array(ind).astype(_np.int32)
+        # case: val is scalar
+        configs.append(tuple([shape, ind, (), ind_ndim, ind_num]))
+        for val_ndim in range(1, 5 - ind_ndim):
+            val_shape = [1 if _np.random.randint(0, 5)==0 else ind_num]
+            for val_dim in range(ind_ndim, 4):
+                val_shape.append(1 if _np.random.randint(0, 5)==0 else shape[val_dim])
+            # case: val is tensor
+            configs.append(tuple([shape, ind, tuple(val_shape), ind_ndim, ind_num]))
+
+    dtypes = ['float32', 'float64', 'int32', 'int64']
+    grad_req = ['write', 'null', 'add']
+    for hybridize, grad_req_a, grad_req_val, dtype, indtype in \
+        itertools.product([True, False], grad_req, grad_req, dtypes, ['int32', 'int64']):
+        for a_shape, ind, val_shape, ind_ndim, ind_num in configs:
+            eps = 1e-3
+            if sys.platform.startswith('linux'):
+                eps = 1e-2
Review comment:
eps should not depend on the platform. What's the root cause here?
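
For illustration only (not part of this PR), a platform-independent alternative would be to derive the tolerance from the dtype under test rather than from the host OS; the helper name `tolerance_for` and the threshold values below are hypothetical:

    import numpy as np

    def tolerance_for(dtype):
        # Hypothetical mapping: the tolerance follows the precision of the
        # dtype being tested, not the platform; integer dtypes compare exactly.
        return {np.float32: 1e-3, np.float64: 1e-6}.get(np.dtype(dtype).type, 0)

    eps = tolerance_for('float32')  # same value on Linux, macOS, and Windows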
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]