stu1130 commented on a change in pull request #15219: Numpy compatible 
multinomial
URL: https://github.com/apache/incubator-mxnet/pull/15219#discussion_r294097945
 
 

 ##########
 File path: tests/python/unittest/test_numpy_ndarray.py
 ##########
 @@ -625,6 +625,52 @@ def convert(num):
             test_setitem_autograd(np_array, index)
 
 
+@retry(5)
+@with_seed()
+@mx.use_np_shape
+def test_np_multinomial():
+    pvals_list = [[0.0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0.0]]
+    sizes = [None, (), (3,), (2, 5, 7), (4, 9)]
+    experiements = 10000
+    for pvals_type in [list, _np.ndarray]:
+        for have_size in [False, True]:
+            for pvals in pvals_list:
+                if have_size:
+                    for size in sizes:
+                        if pvals_type == mx.nd.NDArray:
+                            pvals = mx.nd.array(pvals).as_np_ndarray()
+                        elif pvals_type == _np.ndarray:
+                            pvals = _np.array(pvals)
+                        freq = mx.np.random.multinomial(experiements, pvals, 
size=size).asnumpy() / _np.float32(experiements)
+                        # for those cases that didn't need reshape
+                        if size in [None, ()]:
+                            mx.test_utils.assert_almost_equal(freq, pvals, 
rtol=0.20, atol=1e-1)
+                        else:
+                            # check the shape
+                            assert freq.shape == size + (len(pvals),), 
'freq.shape={}, size + (len(pvals))={}'.format(freq.shape, size + (len(pvals)))
+                            freq = freq.reshape((-1, len(pvals)))
+                            # check the value for each row
+                            for i in range(freq.shape[0]):
+                                mx.test_utils.assert_almost_equal(freq[i, :], 
pvals, rtol=0.20, atol=1e-1)
+                else:
+                    freq = mx.np.random.multinomial(experiements, 
pvals).asnumpy() / _np.float32(experiements)
+                    mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, 
atol=1e-1)
+    # check the zero dimension
+    sizes = [(0), (0, 2), (4, 0, 2), (3, 0, 1, 2, 0)]
+    for pvals in pvals_list:
+        for size in sizes:
+            freq = mx.np.random.multinomial(experiements, pvals, 
size=size).asnumpy()
+            assert freq.size == 0
+    # check [] as pvals
+    for pvals in [[], ()]:
+        freq = mx.np.random.multinomial(experiements, pvals).asnumpy()
+        assert freq.size == 0
+        for size in sizes:
+            freq = mx.np.random.multinomial(experiements, pvals, 
size=size).asnumpy()
+            assert freq.size == 0
+
 
 Review comment:
   done!

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to