asitstands commented on a change in pull request #10048: [MXNET-68] Random 
shuffle implementation
URL: https://github.com/apache/incubator-mxnet/pull/10048#discussion_r173655956
 
 

 ##########
 File path: tests/python/unittest/test_random.py
 ##########
 @@ -552,6 +554,79 @@ def compute_expected_prob():
     mx.test_utils.assert_almost_equal(exp_cnt_sampled.asnumpy(), 
exp_cnt[sampled_classes].asnumpy(), rtol=1e-1, atol=1e-2)
     mx.test_utils.assert_almost_equal(exp_cnt_true.asnumpy(), 
exp_cnt[true_classes].asnumpy(), rtol=1e-1, atol=1e-2)
 
+@with_seed()
+def test_shuffle():
+    def hash(arr):
+        ret = 0
+        for i, n in enumerate(arr):
+            ret += int(n.asscalar()) * (arr.size ** i)
+        return ret
+
+    def check_first_axis_shuffle(arr):
+        stride = int(arr.size / arr.shape[0])
+        column0 = arr.reshape((arr.size,))[::stride].sort()
+        seq = mx.nd.arange(0, arr.size - stride + 1, stride, ctx=arr.context)
+        assert (column0 == seq).prod() == 1
+        for i in range(arr.shape[0]):
+            subarr = arr[i].reshape((arr[i].size,))
+            start = subarr[0].asscalar()
+            seq = mx.nd.arange(start, start + stride, ctx=arr.context)
+            assert (subarr == seq).prod() == 1
+
+    # `data` must be a consecutive sequence of integers starting from 0 if it 
is flattened.
+    def testSmall(data, repeat1, repeat2):
+        # Check that the shuffling is along the first axis.
+        # The order of the elements in each subarray must not change.
+        # This takes long time so only a small number of samples (`repeat1`) 
are checked.
+        for i in range(repeat1):
+            ret = mx.nd.random.shuffle(data)
+            check_first_axis_shuffle(ret)
+        # Count the number of each outcome.
+        # The sequence composed of the first elements of the subarrays is 
enough to discriminate
+        # the outcomes as long as the order of the elements in each subarray 
does not change.
+        count = {}
+        stride = int(data.size / data.shape[0])
+        for i in range(repeat2):
+            ret = mx.nd.random.shuffle(data)
+            h = hash(ret.reshape((ret.size,))[::stride])
+            c = count.get(h, 0)
+            count[h] = c + 1
+        # Check the total number of possible outcomes.
+        assert len(count) == math.factorial(data.shape[0])
 
 Review comment:
   I revised the comments in the test.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to