This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 5ba285b Fix sample_multinomial number of outputs bug (#14873)
5ba285b is described below
commit 5ba285bec12a6a9aed1e0f27e5c81f6e7f3b3540
Author: reminisce <[email protected]>
AuthorDate: Fri May 3 23:28:04 2019 -0700
Fix sample_multinomial number of outputs bug (#14873)
* Fix sample_multinomial number of outputs bug
* Fix lint
---
src/operator/random/sample_multinomial_op.h | 7 +++++--
tests/python/unittest/test_random.py | 12 ++++++++++++
2 files changed, 17 insertions(+), 2 deletions(-)
diff --git a/src/operator/random/sample_multinomial_op.h b/src/operator/random/sample_multinomial_op.h
index b38aefb..377df4f 100644
--- a/src/operator/random/sample_multinomial_op.h
+++ b/src/operator/random/sample_multinomial_op.h
@@ -68,7 +68,7 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U);
const mxnet::TShape& ishape = (*in_attrs)[0];
- if (!shape_is_known(ishape)) return false;
+ if (!ndim_is_known(ishape)) return false;
MSHADOW_TYPE_SWITCH(param.dtype, DType, {
CHECK_LE(ishape[ishape.ndim() - 1],
mxnet::common::MaxIntegerValue<DType>())
@@ -95,7 +95,10 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
if (param.get_prob) SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape);
- return shape_is_known(out_attrs->at(0)) && shape_is_known(out_attrs->at(1));
+ for (const auto& out_shape : *out_attrs) {
+ if (!shape_is_known(out_shape)) return false;
+ }
+ return true;
}
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index 8fbd97d..5e809d3 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -916,6 +916,18 @@ def test_randint_without_dtype():
a = mx.nd.random.randint(low=50000000, high=50000010,
ctx=mx.context.current_context())
assert a.dtype == np.int32
+
+@with_seed()
+def test_sample_multinomial_num_outputs():  # regression test for #14873: the number of outputs must follow get_prob
+ ctx = mx.context.current_context()
+ probs = [[0.125, 0.25, 0.25], [0.0625, 0.125, 0.1875]]  # a batch of 2 rows, 3 categories each
+ out = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx),
shape=10000, get_prob=False)
+ assert isinstance(out, mx.nd.NDArray)  # get_prob=False -> exactly one output, returned as a bare NDArray
+ out = mx.nd.random.multinomial(data=mx.nd.array(probs, ctx=ctx),
shape=10000, get_prob=True)
+ assert isinstance(out, list)  # get_prob=True -> more than one output, so the frontend returns a list
+ assert len(out) == 2  # the second output exists only when get_prob is set and shares the sample shape
+
+
if __name__ == '__main__':
import nose
nose.runmodule()