This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5ba285b  Fix sample_multinomial number of outputs bug (#14873)
5ba285b is described below

commit 5ba285bec12a6a9aed1e0f27e5c81f6e7f3b3540
Author: reminisce <[email protected]>
AuthorDate: Fri May 3 23:28:04 2019 -0700

    Fix sample_multinomial number of outputs bug (#14873)
    
    * Fix sample_multinomial number of outputs bug
    
    * Fix lint
---
 src/operator/random/sample_multinomial_op.h |  7 +++++--
 tests/python/unittest/test_random.py        | 12 ++++++++++++
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h
index b38aefb..377df4f 100644
--- a/src/operator/random/sample_multinomial_op.h
+++ b/src/operator/random/sample_multinomial_op.h
@@ -68,7 +68,7 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U);
   const mxnet::TShape& ishape = (*in_attrs)[0];
-  if (!shape_is_known(ishape)) return false;
+  if (!ndim_is_known(ishape)) return false;
 
   MSHADOW_TYPE_SWITCH(param.dtype, DType, {
     CHECK_LE(ishape[ishape.ndim() - 1], 
mxnet::common::MaxIntegerValue<DType>())
@@ -95,7 +95,10 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
   }
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
   if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape);
-  return shape_is_known(out_attrs->at(0)) && shape_is_known(out_attrs->at(1));
+  for (const auto& out_shape : *out_attrs) {
+    if (!shape_is_known(out_shape)) return false;
+  }
+  return true;
 }
 
 
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index 8fbd97d..5e809d3 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -916,6 +916,18 @@ def test_randint_without_dtype():
     a = mx.nd.random.randint(low=50000000, high=50000010, 
ctx=mx.context.current_context())
     assert a.dtype == np.int32
 
+
+@with_seed()
+def test_sample_multinomial_num_outputs():
+    ctx = mx.context.current_context()
+    probs = [[0.125, 0.25, 0.25], [0.0625, 0.125, 0.1875]]
+    out = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx), 
shape=10000, get_prob=False)
+    assert isinstance(out, mx.nd.NDArray)
+    out = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx), 
shape=10000, get_prob=True)
+    assert isinstance(out, list)
+    assert len(out) == 2
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

Reply via email to