haojin2 commented on a change in pull request #15893: Numpy kron operator
URL: https://github.com/apache/incubator-mxnet/pull/15893#discussion_r314178597
 
 

 ##########
 File path: tests/python/unittest/test_numpy_op.py
 ##########
 @@ -829,6 +829,57 @@ def get_indices(axis_size):
                     assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, 
atol=1e-5)
 
 
+@with_seed()
+@use_np
+def test_np_kron():
+    class TestKron(HybridBlock):
+        def __init__(self):
+            super(TestKron, self).__init__()
+
+        def hybrid_forward(self, F, a, b):
+            return F.np.kron(a, b)
+
+    # test input
+    tensor_shapes = [
+        ((3,), (3,)),
+        ((2, 3), (3,)),
+        ((2, 3, 4), (2,)),
+        ((3, 2), ())
+    ]
+
+    for hybridize in [True, False]:
+        for a_shape, b_shape in tensor_shapes:
+            for dtype in [_np.float32, _np.float64]:
+                test_kron = TestKron()
+                if hybridize:
+                    test_kron.hybridize()
+                a = rand_ndarray(shape=a_shape, dtype=dtype).as_np_ndarray()
+                b = rand_ndarray(shape=b_shape, dtype=dtype).as_np_ndarray()
+                a.attach_grad()
+                b.attach_grad()
+
+                np_out = _np.kron(a.asnumpy(), b.asnumpy())
+                with mx.autograd.record():
+                    mx_out = test_kron(a, b)
+                assert mx_out.shape == np_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, 
atol=1e-5)
+                mx_out.backward()
+
+                # Test imperative once again
+                mx_out = np.kron(a, b)
+                np_out = _np.kron(a.asnumpy(), b.asnumpy())
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, 
atol=1e-5)
+
+                # test numeric gradient
+                if (len(a_shape) > 0 and len(b_shape) > 0 and 
_np.prod(a_shape) > 0 
+                  and _np.prod(b_shape) > 0):
+                    a_sym = mx.sym.Variable("a").as_np_ndarray()
+                    b_sym = mx.sym.Variable("b").as_np_ndarray()
+                    mx_sym = mx.sym.np.kron(a_sym, b_sym).as_nd_ndarray()
+                    check_numeric_gradient(mx_sym, [a.as_nd_ndarray(), 
b.as_nd_ndarray()],
+                      rtol=1e-1, atol=1e-1, dtype=dtype)
 
 Review comment:
   Have you tried smaller values for `atol` and `rtol`? Something like 
`rtol=1e-2` and `atol=1e-4`?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to