haojin2 commented on a change in pull request #15861: Numpy det and slogdet operator
URL: https://github.com/apache/incubator-mxnet/pull/15861#discussion_r313719958
 
 

 ##########
 File path: tests/python/unittest/test_numpy_op.py
 ##########
 @@ -213,6 +213,94 @@ def test_np_dot():
         assert False
 
 
+def test_np_linalg_det():
+    class TestDet(HybridBlock):
+        def __init__(self):
+            super(TestDet, self).__init__()
+
+        def hybrid_forward(self, F, a):
+            return F.np.linalg.det(a)
+
+    # test non zero size input
+    tensor_shapes = [
+        (5, 5),
+        (3, 3, 3),
+        (2, 2, 2, 2, 2),
+        (1, 1)
+    ]
+
+    for hybridize in [True, False]:
+        for shape in tensor_shapes:
+            for dtype in [_np.float32, _np.float64]:
+                a_shape = (1,) + shape
+                test_det = TestDet()
+                if hybridize:
+                    test_det.hybridize()
+                a = rand_ndarray(shape = a_shape, dtype = 
dtype).as_np_ndarray()
+                a.attach_grad()
+
+                np_out = _np.linalg.det(a.asnumpy())
+                with mx.autograd.record():
+                    mx_out = test_det(a)
+                assert mx_out.shape == np_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol = 1e-1, 
atol = 1e-1)
+                mx_out.backward()
+
+                # Test imperative once again
+                mx_out = np.linalg.det(a)
+                np_out = _np.linalg.det(a.asnumpy())
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, 
atol=1e-1)
+
+                # test numeric gradient
+                a_sym = mx.sym.Variable("a").as_np_ndarray()
+                mx_sym = mx.sym.np.linalg.det(a_sym).as_nd_ndarray()
+                check_numeric_gradient(mx_sym, [a.as_nd_ndarray()],
+                    rtol=1e-1, atol=1e-1, dtype = dtype)
 
 Review comment:
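  Remove the spaces around `=` in keyword arguments, per PEP 8. For example, write `rand_ndarray(shape=a_shape, dtype=dtype)`, `rtol=1e-1, atol=1e-1`, and `dtype=dtype` rather than `shape = a_shape`, `rtol = 1e-1`, `dtype = dtype`.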

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to