This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8572506  fix_np_where (#18451)
8572506 is described below

commit 85725066767255090b57aec7f3b03628656afbf0
Author: Minghao Liu <[email protected]>
AuthorDate: Wed Jun 3 15:16:24 2020 +0800

    fix_np_where (#18451)
---
 src/api/operator/numpy/np_where_op.cc  | 2 +-
 tests/python/unittest/test_numpy_op.py | 7 +++++++
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/api/operator/numpy/np_where_op.cc b/src/api/operator/numpy/np_where_op.cc
index a2ed14b..aca4e07 100644
--- a/src/api/operator/numpy/np_where_op.cc
+++ b/src/api/operator/numpy/np_where_op.cc
@@ -76,7 +76,7 @@ inline static void _npi_where_scalar2(runtime::MXNetArgs args,
   op::NumpyWhereScalar2Param param;
   nnvm::NodeAttrs attrs;
   param.x = args[1].operator double();
-  param.x = args[2].operator double();
+  param.y = args[2].operator double();
   attrs.op = op;
   attrs.parsed = param;
   SetAttrDict<op::NumpyWhereScalar2Param>(&attrs);
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 2247700..441c727 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -9156,6 +9156,13 @@ def test_np_where():
             same(ret.asnumpy(), _np.where(cond.asnumpy(), x.asnumpy(), 1))
             ret_rscalar.backward()
             same(x.grad.asnumpy(), 
collapse_sum_like(_np.broadcast_to(cond.asnumpy(), ret.shape), shape_pair[1]))
+        
+        # check both scalar case
+        x = _np.random.randint(0, 100)
+        y = _np.random.randint(0, 100)
+        mx_out = np.where(cond, x, y)
+        np_out = _np.where(cond, x, y)
+        same(mx_out, np_out)
 
 
 @with_seed()

Reply via email to