This is an automated email from the ASF dual-hosted git repository.
liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9d44086 [Numpy] Port nd.random.multinomial to npx.sample_categorical (#18272)
9d44086 is described below
commit 9d440868603ad26b702e12ddd2587e5c4b56e42b
Author: Xi Wang <[email protected]>
AuthorDate: Tue May 12 04:57:56 2020 +0800
[Numpy] Port nd.random.multinomial to npx.sample_categorical (#18272)
* port nd.multinomial to npx.sample_categorical
* move to npx.random
---
src/operator/random/sample_multinomial_op.cc | 1 +
tests/python/unittest/test_numpy_op.py | 27 +++++++++++++++++++++++++++
2 files changed, 28 insertions(+)
diff --git a/src/operator/random/sample_multinomial_op.cc b/src/operator/random/sample_multinomial_op.cc
index bba76ce..f0aa246 100644
--- a/src/operator/random/sample_multinomial_op.cc
+++ b/src/operator/random/sample_multinomial_op.cc
@@ -32,6 +32,7 @@ DMLC_REGISTER_PARAMETER(SampleMultinomialParam);
NNVM_REGISTER_OP(_sample_multinomial)
.add_alias("sample_multinomial")
+.add_alias("_npx__random_categorical")
.describe(R"code(Concurrent sampling from multiple multinomial distributions.
*data* is an *n* dimensional array whose last dimension has length *k*, where
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index bb07a57..3472481 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -4545,6 +4545,33 @@ def test_np_multivariate_normal():
@with_seed()
@use_np
+def test_npx_categorical():
+ class TestNumpyCategorical(HybridBlock):
+ def __init__(self, size=None):
+ super(TestNumpyCategorical, self).__init__()
+ self.size = size
+
+ def hybrid_forward(self, F, prob):
+ if self.size is None:
+ return F.npx.random.categorical(prob)
+ return F.npx.random.categorical(prob, shape=self.size)
+
+ batch_sizes = [(2,), (2, 3)]
+ event_shapes = [None, (10,), (10, 12)]
+ num_event = [2, 4, 10]
+ for batch_size, num_event, event_shape in itertools.product(batch_sizes, num_event, event_shapes):
+ for hybridize in [True, False]:
+ prob = np.ones(batch_size + (num_event,)) / num_event
+ net = TestNumpyCategorical(event_shape)
+ if hybridize:
+ net.hybridize()
+ mx_out = net(prob)
+ desired_shape = batch_size + event_shape if event_shape is not None else batch_size
+ assert mx_out.shape == desired_shape
+
+
+@with_seed()
+@use_np
def test_random_seed():
for seed in [234, 594, 7240, 20394]:
ret = []