asitstands commented on a change in pull request #10048: [MXNET-68] Random 
shuffle implementation
URL: https://github.com/apache/incubator-mxnet/pull/10048#discussion_r173621923
 
 

 ##########
 File path: tests/python/unittest/test_random.py
 ##########
 @@ -552,6 +554,56 @@ def compute_expected_prob():
     mx.test_utils.assert_almost_equal(exp_cnt_sampled.asnumpy(), 
exp_cnt[sampled_classes].asnumpy(), rtol=1e-1, atol=1e-2)
     mx.test_utils.assert_almost_equal(exp_cnt_true.asnumpy(), 
exp_cnt[true_classes].asnumpy(), rtol=1e-1, atol=1e-2)
 
+@with_seed()
+def test_shuffle():
+    def hash(arr):
+        ret = 0
+        for i, n in enumerate(arr):
+            ret += int(n.asscalar()) * (arr.size ** i)
+        return ret
+
+    def check_first_axis_shuffle(arr):
+        stride = int(arr.size / arr.shape[0])
+        column0 = arr.reshape((arr.size,))[::stride].sort()
+        seq = mx.nd.arange(0, arr.size - stride + 1, stride, ctx=arr.context)
+        assert (column0 == seq).prod() == 1
+        for i in range(arr.shape[0]):
+            subarr = arr[i].reshape((arr[i].size,))
+            start = subarr[0].asscalar()
+            seq = mx.nd.arange(start, start + stride, ctx=arr.context)
+            assert (subarr == seq).prod() == 1
+
+    # `data` must be a consecutive sequence of integers starting from 0 if it 
is flattened.
+    def test(data, repeat1, repeat2):
+        stride = int(data.size / data.shape[0])
+        # Check that the shuffling is along the first axis
+        for i in range(repeat1):
+            ret = mx.nd.random.shuffle(data)
+            check_first_axis_shuffle(ret)
+        count = {}
+        # Count the number of each outcome
+        for i in range(repeat2):
+            ret = mx.nd.random.shuffle(data)
+            h = hash(ret.reshape((ret.size,))[::stride])
+            c = count.get(h, 0)
+            count[h] = c + 1
+        # Check the total number of possible outcomes
+        assert len(count) == math.factorial(data.shape[0])
+        # The outcomes must be uniformly distributed
+        for p in itertools.permutations(range(0, data.size - stride + 1, 
stride)):
+            assert abs(count[hash(mx.nd.array(p))] / repeat2 - 1 / 
math.factorial(data.shape[0])) < 0.01
+        # Check symbol interface
+        a = mx.sym.Variable('a')
+        b = mx.sym.random.shuffle(a)
+        c = mx.sym.random.shuffle(data=b, name='c')
+        d = mx.sym.sort(c, axis=0)
+        assert (d.eval(a=data, ctx=mx.current_context())[0] == data).prod() == 
1
+
+    test(mx.nd.arange(0, 3), 10, 20000)
 
 Review comment:
   Verifying the uniformity of the distribution of the outcomes needs a very 
large number of samples for larger arrays. The number of samples needed grows 
factorially with the length of the first axis, so I added a weaker test for 
larger arrays.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to