asitstands commented on a change in pull request #10048: [MXNET-68] Random 
shuffle implementation
URL: https://github.com/apache/incubator-mxnet/pull/10048#discussion_r173656043
 
 

 ##########
 File path: tests/python/unittest/test_random.py
 ##########
 @@ -552,6 +554,79 @@ def compute_expected_prob():
     mx.test_utils.assert_almost_equal(exp_cnt_sampled.asnumpy(), 
exp_cnt[sampled_classes].asnumpy(), rtol=1e-1, atol=1e-2)
     mx.test_utils.assert_almost_equal(exp_cnt_true.asnumpy(), 
exp_cnt[true_classes].asnumpy(), rtol=1e-1, atol=1e-2)
 
+@with_seed()
+def test_shuffle():
+    def hash(arr):
+        ret = 0
+        for i, n in enumerate(arr):
+            ret += int(n.asscalar()) * (arr.size ** i)
+        return ret
+
+    def check_first_axis_shuffle(arr):
+        stride = int(arr.size / arr.shape[0])
+        column0 = arr.reshape((arr.size,))[::stride].sort()
+        seq = mx.nd.arange(0, arr.size - stride + 1, stride, ctx=arr.context)
+        assert (column0 == seq).prod() == 1
+        for i in range(arr.shape[0]):
+            subarr = arr[i].reshape((arr[i].size,))
+            start = subarr[0].asscalar()
+            seq = mx.nd.arange(start, start + stride, ctx=arr.context)
+            assert (subarr == seq).prod() == 1
+
+    # `data` must be a consecutive sequence of integers starting from 0 if it 
is flattened.
+    def testSmall(data, repeat1, repeat2):
+        # Check that the shuffling is along the first axis.
+        # The order of the elements in each subarray must not change.
+        # This takes long time so only a small number of samples (`repeat1`) 
are checked.
+        for i in range(repeat1):
+            ret = mx.nd.random.shuffle(data)
+            check_first_axis_shuffle(ret)
+        # Count the number of each outcome.
+        # The sequence composed of the first elements of the subarrays is 
enough to discriminate
+        # the outcomes as long as the order of the elements in each subarray 
does not change.
+        count = {}
+        stride = int(data.size / data.shape[0])
+        for i in range(repeat2):
+            ret = mx.nd.random.shuffle(data)
+            h = hash(ret.reshape((ret.size,))[::stride])
+            c = count.get(h, 0)
+            count[h] = c + 1
+        # Check the total number of possible outcomes.
+        assert len(count) == math.factorial(data.shape[0])
+        # The outcomes must be uniformly distributed.
+        for p in itertools.permutations(range(0, data.size - stride + 1, 
stride)):
+            assert abs(count[hash(mx.nd.array(p))] / repeat2 - 1 / 
math.factorial(data.shape[0])) < 0.01
 
 Review comment:
   I forgot about the difference between Python 2 and Python 3. Fixed.
   The hash has also been changed. The test is actually faster now! Thanks.
   (There seems to be no way to reply directly to the comment about the hash?)

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to