asitstands commented on a change in pull request #10048: [MXNET-68] Random 
shuffle implementation
URL: https://github.com/apache/incubator-mxnet/pull/10048#discussion_r173656027
 
 

 ##########
 File path: tests/python/unittest/test_random.py
 ##########
 @@ -552,6 +554,79 @@ def compute_expected_prob():
     mx.test_utils.assert_almost_equal(exp_cnt_sampled.asnumpy(), 
exp_cnt[sampled_classes].asnumpy(), rtol=1e-1, atol=1e-2)
     mx.test_utils.assert_almost_equal(exp_cnt_true.asnumpy(), 
exp_cnt[true_classes].asnumpy(), rtol=1e-1, atol=1e-2)
 
+@with_seed()
+def test_shuffle():
+    def hash(arr):
+        ret = 0
+        for i, n in enumerate(arr):
+            ret += int(n.asscalar()) * (arr.size ** i)
+        return ret
+
+    def check_first_axis_shuffle(arr):
+        stride = int(arr.size / arr.shape[0])
+        column0 = arr.reshape((arr.size,))[::stride].sort()
+        seq = mx.nd.arange(0, arr.size - stride + 1, stride, ctx=arr.context)
+        assert (column0 == seq).prod() == 1
+        for i in range(arr.shape[0]):
+            subarr = arr[i].reshape((arr[i].size,))
+            start = subarr[0].asscalar()
+            seq = mx.nd.arange(start, start + stride, ctx=arr.context)
+            assert (subarr == seq).prod() == 1
+
+    # `data` must be a consecutive sequence of integers starting from 0 if it 
is flattened.
+    def testSmall(data, repeat1, repeat2):
+        # Check that the shuffling is along the first axis.
+        # The order of the elements in each subarray must not change.
+        # This takes long time so only a small number of samples (`repeat1`) 
are checked.
+        for i in range(repeat1):
+            ret = mx.nd.random.shuffle(data)
+            check_first_axis_shuffle(ret)
+        # Count the number of each outcome.
+        # The sequence composed of the first elements of the subarrays is 
enough to discriminate
+        # the outcomes as long as the order of the elements in each subarray 
does not change.
+        count = {}
+        stride = int(data.size / data.shape[0])
+        for i in range(repeat2):
+            ret = mx.nd.random.shuffle(data)
+            h = hash(ret.reshape((ret.size,))[::stride])
+            c = count.get(h, 0)
+            count[h] = c + 1
+        # Check the total number of possible outcomes.
+        assert len(count) == math.factorial(data.shape[0])
+        # The outcomes must be uniformly distributed.
+        for p in itertools.permutations(range(0, data.size - stride + 1, 
stride)):
+            assert abs(count[hash(mx.nd.array(p))] / repeat2 - 1 / 
math.factorial(data.shape[0])) < 0.01
+        # Check symbol interface
+        a = mx.sym.Variable('a')
+        b = mx.sym.random.shuffle(a)
+        c = mx.sym.random.shuffle(data=b, name='c')
+        d = mx.sym.sort(c, axis=0)
+        assert (d.eval(a=data, ctx=mx.current_context())[0] == data).prod() == 
1
+
+    # `data` must be a consecutive sequence of integers starting from 0 if it 
is flattened.
+    # This does not verify the uniformity of the distribution of the outcomes.
+    def testLarge(data, repeat):
+        # Check that the shuffling is along the first axis
+        # and count the number of different outcomes.
+        stride = int(data.size / data.shape[0])
+        count = {}
+        for i in range(repeat):
+            ret = mx.nd.random.shuffle(data)
+            check_first_axis_shuffle(ret)
+            h = hash(ret.reshape((ret.size,))[::stride])
+            c = count.get(h, 0)
+            count[h] = c + 1
+        # The probability of duplicated outcomes is very low for large arrays.
+        assert len(count) == repeat
+
+    # Test small arrays with different shapes
+    testSmall(mx.nd.arange(0, 3), 100, 20000)
+    testSmall(mx.nd.arange(0, 9).reshape((3, 3)), 100, 20000)
+    testSmall(mx.nd.arange(0, 12).reshape((2, 2, 3)), 100, 20000)
 
 Review comment:
   Actually any number >= 2 is equivalent. Anyway, I changed it to 3 :)

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to