This is an automated email from the ASF dual-hosted git repository.
liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8572506 fix_np_where (#18451)
8572506 is described below
commit 85725066767255090b57aec7f3b03628656afbf0
Author: Minghao Liu <[email protected]>
AuthorDate: Wed Jun 3 15:16:24 2020 +0800
fix_np_where (#18451)
---
src/api/operator/numpy/np_where_op.cc | 2 +-
tests/python/unittest/test_numpy_op.py | 7 +++++++
2 files changed, 8 insertions(+), 1 deletion(-)
diff --git a/src/api/operator/numpy/np_where_op.cc b/src/api/operator/numpy/np_where_op.cc
index a2ed14b..aca4e07 100644
--- a/src/api/operator/numpy/np_where_op.cc
+++ b/src/api/operator/numpy/np_where_op.cc
@@ -76,7 +76,7 @@ inline static void _npi_where_scalar2(runtime::MXNetArgs args,
op::NumpyWhereScalar2Param param;
nnvm::NodeAttrs attrs;
param.x = args[1].operator double();
- param.x = args[2].operator double();
+ param.y = args[2].operator double();
attrs.op = op;
attrs.parsed = param;
SetAttrDict<op::NumpyWhereScalar2Param>(&attrs);
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 2247700..441c727 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -9156,6 +9156,13 @@ def test_np_where():
same(ret.asnumpy(), _np.where(cond.asnumpy(), x.asnumpy(), 1))
ret_rscalar.backward()
same(x.grad.asnumpy(),
collapse_sum_like(_np.broadcast_to(cond.asnumpy(), ret.shape), shape_pair[1]))
+
+ # check both scalar case
+ x = _np.random.randint(0, 100)
+ y = _np.random.randint(0, 100)
+ mx_out = np.where(cond, x, y)
+ np_out = _np.where(cond, x, y)
+ same(mx_out, np_out)
@with_seed()