This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 20957ff Numpy Ops Large Tensor Tests Batch 2 (#18968)
20957ff is described below
commit 20957ffcf1d76495ffa0132930f1bbb4917afd31
Author: Zhaoqi Zhu <[email protected]>
AuthorDate: Thu Aug 20 16:04:36 2020 -0700
Numpy Ops Large Tensor Tests Batch 2 (#18968)
* remove asnumpy()
* add more tests
* add more tests
* add more tests
Co-authored-by: Ubuntu <[email protected]>
---
tests/nightly/test_np_large_array.py | 240 +++++++++++++++++++++++++++++++++--
1 file changed, 232 insertions(+), 8 deletions(-)
diff --git a/tests/nightly/test_np_large_array.py b/tests/nightly/test_np_large_array.py
index 28d8aeb..e6ef9cf 100644
--- a/tests/nightly/test_np_large_array.py
+++ b/tests/nightly/test_np_large_array.py
@@ -40,6 +40,7 @@ INT_OVERFLOW = 2**31
HALF_INT_OVERFLOW = 2**30
DOUBLE_INT_OVERFLOW = 2**32
+
@use_np
def test_gluon_embedding():
m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X)
@@ -157,7 +158,7 @@ def test_all():
A.attach_grad()
with mx.autograd.record():
B = np.all(A)
- assert B.asnumpy() == True
+ assert B == True
B.backward()
assert A.grad.shape == (INT_OVERFLOW, 2)
assert A.grad[0][0] == 0
@@ -169,7 +170,7 @@ def test_amin():
A.attach_grad()
with mx.autograd.record():
B = np.amin(A)
- assert B.asnumpy() == -1.0
+ assert B == -1.0
B.backward()
assert A.grad.shape == (INT_OVERFLOW, 2)
assert A.grad[0][0] == 0
@@ -182,7 +183,7 @@ def test_amax():
with mx.autograd.record():
B = np.amax(A)
print(B)
- assert B.asnumpy() == 1.0
+ assert B == 1.0
B.backward()
assert A.grad.shape == (INT_OVERFLOW, 2)
assert A.grad[0][0] == 0
@@ -195,7 +196,7 @@ def test_argmin():
with mx.autograd.record():
B = np.argmin(A)
print(B)
- assert B.asnumpy() == 21
+ assert B == 21
B.backward()
assert A.grad.shape == (INT_OVERFLOW, 2)
assert A.grad[0][0] == 0
@@ -208,7 +209,7 @@ def test_argmax():
with mx.autograd.record():
B = np.argmax(A)
print(B)
- assert B.asnumpy() == 21
+ assert B == 21
B.backward()
assert A.grad.shape == (INT_OVERFLOW, 2)
assert A.grad[0][0] == 0
@@ -240,7 +241,7 @@ def test_any():
A.attach_grad()
with mx.autograd.record():
B = np.any(A)
- assert B.asnumpy() == False
+ assert B == False
B.backward()
assert A.grad.shape == (INT_OVERFLOW, 2)
assert A.grad[0][0] == 0
@@ -319,7 +320,7 @@ def test_average():
A.attach_grad()
with mx.autograd.record():
B = np.average(A)
- assert B.asnumpy() == 1
+ assert B == 1
B.backward()
assert A.grad.shape == (INT_OVERFLOW, 2)
assert_almost_equal(A.grad[0][0], np.array([1.0 / DOUBLE_INT_OVERFLOW]), \
@@ -551,7 +552,7 @@ def test_constraint_check():
A = np.ones((2, INT_OVERFLOW))
constraint = (A > 0)
B = npx.constraint_check(constraint)
- assert B.asnumpy() == True
+ assert B == True
# broken
@use_np
@@ -735,6 +736,7 @@ def test_slice():
assert B.shape == (100, 2)
assert B[0][0] == 2
+@use_np
def test_smooth_l1():
A = np.arange((INT_OVERFLOW))
A.attach_grad()
@@ -744,4 +746,226 @@ def test_smooth_l1():
assert B[1] == 0.5
B.backward()
assert A.grad.shape == (INT_OVERFLOW, )
+ assert A.grad[0] == 0
+
+@use_np
+@pytest.mark.skip(reason='np.random broken on large tensor; npx.random \
+ to be re-examined after np.random is fixed')
+def test_random():
+ prob = np.random.uniform(size=(INT_OVERFLOW, 2))
+ A = npx.random.bernoulli(prob=prob, size=(INT_OVERFLOW, 2))
+ assert A.shape == (INT_OVERFLOW, 2)
+ assert int((A == 0).sum() + (A == 1).sum()) == A.size
+
+@use_np
+def test_gamma():
+ A = np.ones((2, INT_OVERFLOW))
+ A[0][0] = 5
+ A.attach_grad()
+ with mx.autograd.record():
+ B = npx.gamma(A)
+ assert B.shape == (2, INT_OVERFLOW)
+ assert B[0][0] == 24
+ B.backward()
+ assert A.grad.shape == (2, INT_OVERFLOW)
+ assert_almost_equal(A.grad[0][0], np.array([36.1428]), \
+ rtol=1e-3, atol=1e-5)
+
+@use_np
+def test_gammaln():
+ A = np.ones((2, INT_OVERFLOW))
+ A[0][0] = 5
+ A.attach_grad()
+ with mx.autograd.record():
+ B = npx.gammaln(A)
+ assert B.shape == (2, INT_OVERFLOW)
+ assert_almost_equal(B[0][0], np.array([np.log(24)]), \
+ rtol=1e-3, atol=1e-5)
+ B.backward()
+ assert A.grad.shape == (2, INT_OVERFLOW)
+ assert_almost_equal(A.grad[0][0], np.array([1.5061178]), \
+ rtol=1e-3, atol=1e-5)
+@use_np
+def test_digamma():
+ A = np.ones((2, INT_OVERFLOW))
+ A[0][0] = 5
+ A.attach_grad()
+ with mx.autograd.record():
+ B = npx.digamma(A)
+ assert B.shape == (2, INT_OVERFLOW)
+ assert_almost_equal(B[0][0], np.array([1.5061178]), \
+ rtol=1e-3, atol=1e-5)
+ B.backward()
+ assert A.grad.shape == (2, INT_OVERFLOW)
+ assert_almost_equal(A.grad[0][0], np.array([0.22132295]), \
+ rtol=1e-3, atol=1e-5)
+
+@use_np
+@pytest.mark.skip(reason='broken on large tensors; also backward errors out')
+def test_rnn():
+ def batch_check(x, modes, params):
+ for m, p in zip(modes, params):
+ #x.attach_grad()
+ #with mx.autograd.record():
+ y = npx.rnn(data=x, parameters=p, mode=m, \
+ state=np.random.normal(0, 1, (1, 4, 1)), \
+ state_size=1, num_layers=1)
+ assert y.shape == (INT_OVERFLOW, 4, 1)
+ assert type(y).__name__ == 'ndarray'
+ #y.backward()
+ #assert x.grad.shape == x.shape
+ #assert type(x.grad[0]).__name__ == 'ndarray'
+ data = np.random.normal(0, 1, (INT_OVERFLOW, 4, 4))
+ modes = ['rnn_relu', 'rnn_tanh', 'gru']
+ params = [np.random.normal(0, 1, (7,)), \
+ np.random.normal(0, 1, (7,)), \
+ np.random.normal(0, 1, (21,))]
+ batch_check(data, modes, params)
+ # check lstm seperately because it has an extra param
+ out = npx.rnn(data=data, parameters=np.random.normal(0, 1, (28,)), \
+ mode='lstm', \
+ state=np.random.normal(0, 1, (1, 4, 1)), \
+ state_cell=np.random.normal(0, 1, (1, 4, 1)), \
+ state_size=1, num_layers=1)
+ assert out.shape == (INT_OVERFLOW, 4, 1)
+ assert type(out[0]).__name__ == 'ndarray'
+
+@use_np
+@pytest.mark.skip(reason='backward errors out')
+def test_ctc_loss():
+ A = np.ones((2, INT_OVERFLOW, 4))
+ A.attach_grad()
+ with mx.autograd.record():
+ B = npx.ctc_loss(A, np.ones((INT_OVERFLOW, 2)))
+ assert B.shape == (INT_OVERFLOW, )
+ assert type(B).__name__ == 'ndarray'
+ B.backward()
+ assert A.grad.shape == (2, INT_OVERFLOW, 4)
+ assert A.grad[0][0][0] == 0
+
+@use_np
+def test_erf():
+ A = np.ones((2, INT_OVERFLOW))
+ A[0][0] = 10
+ A.attach_grad()
+ with mx.autograd.record():
+ B = npx.erf(A)
+ assert B.shape == (2, INT_OVERFLOW)
+ assert B[0][0] == 1
+ B.backward()
+ assert A.grad.shape == (2, INT_OVERFLOW)
+ assert_almost_equal(A.grad[0][0], np.array([4.2e-44]), \
+ rtol=1e-3, atol=1e-5)
+
+@use_np
+def test_erfinv():
+ A = np.ones((2, INT_OVERFLOW))
+ A[0][0] = 0.5
+ A.attach_grad()
+ with mx.autograd.record():
+ B = npx.erfinv(A)
+ assert B.shape == (2, INT_OVERFLOW)
+ assert_almost_equal(B[0][0], np.array([0.47693628]), \
+ rtol=1e-3, atol=1e-5)
+ B.backward()
+ assert A.grad.shape == (2, INT_OVERFLOW)
+ assert_almost_equal(A.grad[0][0], np.array([1.112585]), \
+ rtol=1e-3, atol=1e-5)
+
+@use_np
+def test_index_add():
+ A = np.zeros((2, INT_OVERFLOW))
+ ind = np.array([[0, 0], [0, 1]], dtype='int32')
+ val = np.array([100, 200])
+ A.attach_grad()
+ with mx.autograd.record():
+ B = npx.index_add(A, ind, val)
+ assert B.shape == (2, INT_OVERFLOW)
+ assert B[0][0] == 100 and B[0][1] == 200
+ B.backward()
+ assert A.grad.shape == (2, INT_OVERFLOW)
+ assert A.grad[0][0] == 1
+
+@use_np
+def test_index_update():
+ A = np.zeros((2, INT_OVERFLOW))
+ ind = np.array([[0, 0], [0, 1]], dtype='int32')
+ val = np.array([100, 200])
+ A.attach_grad()
+ with mx.autograd.record():
+ B = npx.index_update(A, ind, val)
+ assert B.shape == (2, INT_OVERFLOW)
+ assert B[0][0] == 100 and B[0][1] == 200
+ B.backward()
+ assert A.grad.shape == (2, INT_OVERFLOW)
assert A.grad[0][0] == 0
+
+@use_np
+def test_layer_norm():
+ A = np.ones((2, INT_OVERFLOW))
+ A.attach_grad()
+ with mx.autograd.record():
+ B = npx.layer_norm(A, gamma=np.ones((2)), beta=np.zeros((2)), axis=0)
+ assert B.shape == (2, INT_OVERFLOW)
+ assert B[0][0] == 0
+ B.backward()
+ assert A.grad.shape == (2, INT_OVERFLOW)
+ assert A.grad[0][0] == 0
+
+@use_np
+def test_dlpack():
+ A = np.ones((2, INT_OVERFLOW))
+ A[0][100] = 100
+ B = npx.to_dlpack_for_read(A)
+ assert type(B).__name__ == 'PyCapsule'
+ C = npx.from_dlpack(B)
+ assert type(C).__name__ == 'ndarray'
+ assert C.shape == (2, INT_OVERFLOW)
+ assert C[0][100] == 100
+ B = npx.to_dlpack_for_write(A)
+ assert type(B).__name__ == 'PyCapsule'
+ C = npx.from_dlpack(B)
+ C += 1
+ assert type(C).__name__ == 'ndarray'
+ assert C.shape == (2, INT_OVERFLOW)
+ assert C[0][100] == 101
+
+@use_np
+@pytest.mark.skip(reason='broken on large tensors')
+#TODO add 3d pooling test after large tensor is fixed
+def test_pooling():
+ A = np.ones((1, 2, INT_OVERFLOW))
+ A[0][0][2] = 100
+ A.attach_grad()
+ with mx.autograd.record():
+ B = npx.pooling(data=A, kernel=(2), stride=2, pool_type='max')
+ assert B.shape == (1, 2, HALF_INT_OVERFLOW)
+ assert B[0][0][1] == 100
+ B.backward()
+ assert A.grad.shape == (1, 2, INT_OVERFLOW)
+ assert A.grad[0][0][0] == 1
+
+@use_np
+@pytest.mark.skip(reason='forward gives wrong value on large tensor')
+def test_roi_pooling():
+ A = np.ones((1, 1, 5, INT_OVERFLOW))
+ A[0][0][0][2] = 100
+ roi = np.array([[0, 0, 0, 3, 3]])
+ A.attach_grad()
+ with mx.autograd.record():
+ B = npx.roi_pooling(A, roi, pooled_size=(2, 2), spatial_scale=1)
+ assert B.shape == (1, 1, 2, 2)
+ assert B[0][0][0][1] == 100
+ B.backward()
+ assert A.grad.shape == (1, 1, 5, INT_OVERFLOW)
+ assert A.grad[0][0][0][0] == 1
+
+@use_np
+@pytest.mark.skip(reason='times out on (generally speaking) large tensors')
+def test_save_load():
+ A = np.ones((2, INT_OVERFLOW), dtype='int8')
+ A[0][100] = 100
+ npx.save('my_tensor', A)
+ B = np.array(npx.load('my_tensor'))
+ assert B[0].shape == (2, INT_OVERFLOW)
+ assert B[0][0][100] == 100