This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch large_tensor_tests_batch5_clean
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 500fa3e16f934011ef812d6d7b5653a69f61b525
Author: Zhaoqi Zhu <[email protected]>
AuthorDate: Mon Nov 16 14:31:52 2020 -0500

    Update test_np_large_array.py
---
 tests/nightly/test_np_large_array.py | 271 +++++++++++++++++++++++++++++++++--
 1 file changed, 257 insertions(+), 14 deletions(-)

diff --git a/tests/nightly/test_np_large_array.py b/tests/nightly/test_np_large_array.py
index ed53fba..a6eefa9 100644
--- a/tests/nightly/test_np_large_array.py
+++ b/tests/nightly/test_np_large_array.py
@@ -1312,6 +1312,219 @@ def test_polyval():
     assert poly.grad.shape == poly.shape
     assert poly.grad[0] == 4
+
+@use_np
+def test_rot90():
+    inp = np.zeros((1, 2, INT_OVERFLOW))
+    inp[-1, -1, -1] = 1
+    inp.attach_grad()
+    with mx.autograd.record():
+        out = np.rot90(inp, axes=(1,2))
+        out.backward()
+    assert out.shape == (1, INT_OVERFLOW, 2)
+    assert out[0, 0, 1] == 1
+    assert inp.grad.shape == inp.shape
+    assert inp.grad[-1, -1, -1] == 1
+
+
+@use_np
+def test_squeeze():
+    inp = np.zeros((2, 1, INT_OVERFLOW))
+    inp[-1, -1, -1] = 1
+    inp.attach_grad()
+    with mx.autograd.record():
+        out = np.squeeze(inp, axis=1)
+        out.backward()
+    assert out.shape == (2, INT_OVERFLOW)
+    assert out[-1, -1] == 1
+    assert inp.grad.shape == inp.shape
+    assert inp.grad[-1, -1, -1] == 1
+
+
+@use_np
+def test_tile():
+    inp = np.array([[0, 1],[2, 3]])
+    inp.attach_grad()
+    with mx.autograd.record():
+        out = np.tile(inp, (1, HALF_INT_OVERFLOW))
+        out.backward()
+    assert out.shape == (2, INT_OVERFLOW)
+    assert out[-1, -1] == 3
+    assert inp.grad.shape == inp.shape
+    assert inp.grad[-1, -1] == HALF_INT_OVERFLOW
+
+
+@use_np
+def test_trace():
+    N = 2**16
+    inp1 = np.eye(N)
+    inp1.attach_grad()
+    with mx.autograd.record():
+        out1 = np.trace(inp1)
+        out1.backward()
+    assert out1 == N
+    assert inp1.grad.shape == inp1.shape
+    assert inp1.grad[0, 0] == 1 and inp1.grad[-1, -1] == 1
+    inp2 = np.zeros((2, INT_OVERFLOW))
+    inp2[-1, -1] = 1
+    inp2.attach_grad()
+    with mx.autograd.record():
+        out2 = np.trace(inp2, offset=INT_OVERFLOW-2)
+        out2.backward()
+    assert out2 == 1
+    assert inp2.grad.shape == inp2.shape
+    assert inp2.grad[0, -2] == 1 and inp2.grad[-1, -1] == 1
+
+
+@use_np
+def test_tri():
+    N = 2**16
+    data1 = np.tri(N)
+    assert data1.shape == (N, N)
+    assert data1[0, 0] == 1 and data1[-1, -1] == 1
+    assert data1[0, -1] == 0 and data1[-1, 0] == 1
+    data2 = np.tri(2, INT_OVERFLOW, INT_OVERFLOW-2)
+    assert data2.shape == (2, INT_OVERFLOW)
+    assert data2[0, -1] == 0 and data2[-1, -1] == 1
+
+
+@use_np
+def test_tril():
+    N = 2**16
+    inp1 = np.ones((N, N))
+    inp1.attach_grad()
+    with mx.autograd.record():
+        out1 = np.tril(inp1)
+        out1.backward()
+    assert out1.shape == (N, N)
+    assert out1[-1, -1] == 1 and out1[0, -1] == 0 and out1[-1, 0] == 1
+    assert inp1.grad.shape == inp1.shape
+    assert inp1.grad[-1, -1] == 1 and inp1.grad[0, -1] == 0 and \
+        inp1.grad[-1, 0] == 1
+    inp2 = np.ones((2, INT_OVERFLOW))
+    inp2[-1, -1] = 1
+    inp2.attach_grad()
+    with mx.autograd.record():
+        out2 = np.tril(inp2, k=INT_OVERFLOW-2)
+        out2.backward()
+    assert out2.shape == inp2.shape
+    assert out2[0, -1] == 0 and out2[-1, -1] == 1
+    assert inp2.grad.shape == inp2.shape
+    assert inp2.grad[0, -1] == 0 and inp2.grad[-1, -1] == 1
+
+
+@use_np
+def test_triu():
+    N = 2**16
+    inp1 = np.ones((N, N))
+    inp1.attach_grad()
+    with mx.autograd.record():
+        out1 = np.triu(inp1)
+        out1.backward()
+    assert out1.shape == (N, N)
+    assert out1[-1, -1] == 1 and out1[0, -1] == 1 and out1[-1, 0] == 0
+    assert inp1.grad.shape == inp1.shape
+    assert inp1.grad[-1, -1] == 1 and inp1.grad[0, -1] == 1 and \
+        inp1.grad[-1, 0] == 0
+    inp2 = np.ones((2, INT_OVERFLOW))
+    inp2[-1, -1] = 1
+    inp2.attach_grad()
+    with mx.autograd.record():
+        out2 = np.triu(inp2, k=INT_OVERFLOW-1)
+        out2.backward()
+    assert out2.shape == inp2.shape
+    assert out2[0, -1] == 1 and out2[-1, -1] == 0
+    assert inp2.grad.shape == inp2.shape
+    assert inp2.grad[0, -1] == 1 and inp2.grad[-1, -1] == 0
+
+
+@use_np
+def test_transpose():
+    inp = np.zeros((1, 2, INT_OVERFLOW))
+    inp[0, 0, -1] = 1
+    inp.attach_grad()
+    with mx.autograd.record():
+        out = np.transpose(inp, (2, 0, 1))
+        out.backward()
+    assert out.shape == (INT_OVERFLOW, 1, 2)
+    assert out[-1, 0, 0] == 1
+    assert inp.grad.shape == inp.shape
+    assert inp.grad[-1, -1, -1] == 1
+
+
+@use_np
+def test_trunc():
+    inp = np.zeros((2, INT_OVERFLOW))
+    inp[0, -1], inp[1, -1] = 1.9, -1.9
+    inp.attach_grad()
+    with mx.autograd.record():
+        out = np.trunc(inp)
+        out.backward()
+    assert out.shape == inp.shape
+    assert out[0, -1] == 1 and out[1, -1] == -1
+    assert inp.grad.shape == inp.shape
+    assert inp.grad[-1, -1] == 0
+
+
+@use_np
+def test_stack():
+    inp1 = np.zeros((INT_OVERFLOW))
+    inp2 = np.ones((INT_OVERFLOW))
+    inp1.attach_grad()
+    inp2.attach_grad()
+    with mx.autograd.record():
+        out1 = np.stack([inp1, inp2])
+        out1.backward()
+    assert out1.shape == (2, INT_OVERFLOW)
+    assert out1[0, -1] == 0 and out1[1, -1] == 1
+    assert inp1.grad.shape == inp1.shape
+    assert inp1.grad[-1] == 1
+    with mx.autograd.record():
+        out2 = np.stack([inp1, inp2], axis=1)
+        out2.backward()
+    assert out2.shape == (INT_OVERFLOW, 2)
+    assert out2[-1, 0] == 0 and out2[-1, 1] == 1
+    assert inp2.grad.shape == inp2.shape
+    assert inp2.grad[-1] == 1
+
+
+@use_np
+def test_dstack():
+    inp1 = np.zeros((INT_OVERFLOW, 1))
+    inp2 = np.ones((INT_OVERFLOW, 1))
+    inp1.attach_grad()
+    inp2.attach_grad()
+    with mx.autograd.record():
+        out = np.dstack((inp1, inp2))
+        out.backward()
+    assert out.shape == (INT_OVERFLOW, 1, 2)
+    assert out[0, -1, 0] == 0 and out[0, -1, 1] == 1
+    assert inp1.grad.shape == inp1.shape
+    assert inp1.grad[-1, -1] == 1
+
+
+@use_np
+def test_hstack():
+    inp1 = np.zeros((INT_OVERFLOW, 1))
+    inp2 = np.ones((INT_OVERFLOW, 1))
+    inp1.attach_grad()
+    inp2.attach_grad()
+    with mx.autograd.record():
+        out1 = np.hstack((inp1, inp2))
+        out1.backward()
+    assert out1.shape == (INT_OVERFLOW, 2)
+    assert out1[-1, 0] == 0 and out1[-1, 1] == 1
+    assert inp1.grad.shape == inp1.shape
+    assert inp1.grad[-1, -1] == 1
+    with mx.autograd.record():
+        out2 = np.hstack((inp1.flatten(), inp2.flatten()))
+        out2.backward()
+    assert out2.shape == (DOUBLE_INT_OVERFLOW, )
+    assert out2[INT_OVERFLOW-1] == 0 and out2[-1] == 1
+    assert inp2.grad.shape == inp2.shape
+    assert inp2.grad[-1, -1] == 1
+
+
 
 '''
 [ASCII-art section banner, garbled in transit; it divides the numpy-operator tests above from the npx extension-operator tests below]
@@ -1408,23 +1621,53 @@ def test_batch_flatten():
     assert A.grad.shape == (2, 1, INT_OVERFLOW)
     assert A.grad[0][0][0] == 1
 
-# broken
+
 @use_np
[email protected](reason='Does not support large tensor; to be fixed')
 def test_batch_norm():
-    A = np.ones((2, INT_OVERFLOW))
-    gamma = np.ones((2))
-    beta = np.zeros((2))
-    mov_mean = np.ones((2))
-    mov_var = np.ones((2))
-    A.attach_grad()
+    inp = np.zeros((2, INT_OVERFLOW))
+    gamma = np.array([1.5, 2.5])
+    beta = np.array([0.3, 0.6])
+    mov_mean = np.array([0.4, 0.8])
+    mov_var = np.array([0.6, 1.2])
+    eps = 1e-5
+    inp[0, -1], inp[1, -1] = 3, 6
+    inp.attach_grad()
     with mx.autograd.record():
-        B = npx.batch_norm(A, gamma, beta, mov_mean, mov_var)
-    assert B.shape == (2, INT_OVERFLOW)
-    assert B[0][0] == 0
-    B.backward()
-    assert A.grad.shape == (2, INT_OVERFLOW)
-    assert A.grad[0][0] == 0
+        out = npx.batch_norm(inp, gamma=gamma, beta=beta, moving_mean=mov_mean,\
+            moving_var=mov_var, axis=0, eps=eps, use_global_stats=True)
+        out.backward()
+    assert out.shape == inp.shape
+    ref0 = (inp[0, -1] - mov_mean[0]) / (mov_var[0] + eps)**0.5 * gamma[0] + beta[0]
+    ref1 = (inp[1, -1] - mov_mean[1]) / (mov_var[1] + eps)**0.5 * gamma[1] + beta[1]
+    assert_almost_equal(out[0, -1], ref0, rtol=1e-3, atol=1e-5)
+    assert_almost_equal(out[1, -1], ref1, rtol=1e-3, atol=1e-5)
+    assert inp.grad.shape == inp.shape
+    grad_ref0 = gamma[0] / (mov_var[0] + eps)**0.5
+    grad_ref1 = gamma[1] / (mov_var[1] + eps)**0.5
+    assert_almost_equal(inp.grad[0, -1], grad_ref0, rtol=1e-3, atol=1e-5)
+    assert_almost_equal(inp.grad[1, -1], grad_ref1, rtol=1e-3, atol=1e-5)
+
+
+@use_np
+def test_batch_norm_mean_var():
+    N = 2**20
+    inp = np.zeros((2, INT_OVERFLOW), dtype='float64')
+    gamma = np.array([1, 1], dtype='float64')
+    beta = np.array([0, 0], dtype='float64')
+    mov_mean = np.array([0, 0], dtype='float64')
+    mov_var = np.array([1, 1], dtype='float64')
+    eps = 0
+    inp[1, -1] = N
+    with mx.autograd.record():
+        out, mean, var = npx.batch_norm(inp, gamma=gamma, beta=beta, moving_mean=mov_mean,\
+            moving_var=mov_var, axis=0, eps=eps, output_mean_var=True)
+    assert out.shape == inp.shape
+    mean_ref = float(N) / INT_OVERFLOW
+    std_ref = ((INT_OVERFLOW-1) * (mean_ref-0)**2 + (mean_ref-N)**2) / INT_OVERFLOW
+    out_ref = (N - mean_ref) / (std_ref**0.5)
+    assert_almost_equal(mean[1], mean_ref, rtol=1e-3, atol=1e-5)
+    assert_almost_equal(var[1], 1 / std_ref**0.5, rtol=1e-3, atol=1e-5)
+    assert_almost_equal(out[1, -1], out_ref, rtol=1e-3, atol=1e-5)
 
 
 @use_np
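
All of the new tests above share one template: allocate an input with more than 2**31 elements, mark a single element, run the operator under mx.autograd.record(), and spot-check the output and the gradient at the marked position, so verification stays O(1) regardless of tensor size. For anyone trying one of these tests outside the nightly harness, below is a minimal sketch of the module-level scaffolding the diff takes for granted. The constant values can be read off the tests themselves (e.g. np.tile of a 2-wide array by HALF_INT_OVERFLOW yields width INT_OVERFLOW, and hstack of two flattened INT_OVERFLOW arrays yields DOUBLE_INT_OVERFLOW); the import paths are a best guess for MXNet 2.x and may differ between builds.

    # Sketch of the setup the tests above assume (not part of this commit;
    # import paths are an assumption and may vary across MXNet builds).
    import mxnet as mx
    from mxnet import np, npx               # MXNet's NumPy-compatible front end
    from mxnet.util import use_np           # decorator enabling NumPy semantics
    from mxnet.test_utils import assert_almost_equal

    # Sizes used throughout: one past the int32 index limit, half it, double it.
    INT_OVERFLOW = 2**31
    HALF_INT_OVERFLOW = 2**30
    DOUBLE_INT_OVERFLOW = 2**32

Note that a (2, INT_OVERFLOW) float32 input alone occupies 16 GiB, and attach_grad() roughly doubles that, so these tests are only practical on the large-memory machines used for the nightly runs.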
