chongruo opened a new issue #16101: [Bug,  Feature Request]  mx.nd.where()
URL: https://github.com/apache/incubator-mxnet/issues/16101
 
 
   
   ## Description
   ### Bug 
   
   [mx.nd.where()](https://beta.mxnet.io/api/ndarray/_autogen/mxnet.ndarray.where.html?highlight=where#mxnet.ndarray.where) behaves incorrectly when one of the inputs is an NDArray with zero size (in the example below, `y` has shape `()`).
   
   Here is a reproducible example:
   ```python
   import mxnet as mx

   cond = mx.nd.array([0])       # cond.shape: (1,)
   x = mx.nd.array([[10,10]])    #    x.shape: (1, 2)
   y = mx.nd.array(4)            #    y.shape: ()

   print(mx.nd.where(cond, x, y))
   # output: [[4.0000e+00 3.0773e-41]]
   ```
   The output is wrong: the second element (`3.0773e-41`) looks like garbage, which suggests that the zero-size NDArray is never checked. According to the [docs of mx.nd.where()](https://beta.mxnet.io/api/ndarray/_autogen/mxnet.ndarray.where.html?highlight=where#mxnet.ndarray.where), x and y must have the same shape, so we would expect an error reporting the shape mismatch. Broadcasting is not supported in the latest version, yet where() still returns an output.
   
   This is also somewhat dangerous: when a user forgets the brackets and writes ``mx.nd.array(4)`` instead of ``mx.nd.array([4])``, where() silently returns incorrect results instead of raising an error.
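The check we expected, based on the two documented restrictions listed under Feature Request, could look roughly like the pure-Python sketch below. `check_where_shapes` is a hypothetical helper, not part of MXNet; it assumes it is given plain shape tuples (such as `NDArray.shape`) and only illustrates which inputs ought to be rejected.

```python
# Hypothetical helper (not an MXNet API): checks shapes against the two
# rules the mx.nd.where() docs state, since broadcasting is unsupported.
def check_where_shapes(cond_shape, x_shape, y_shape):
    """Raise ValueError if (cond, x, y) violate the documented shape rules."""
    cond_shape, x_shape, y_shape = tuple(cond_shape), tuple(x_shape), tuple(y_shape)

    # Rule 1: x and y must have exactly the same shape.
    if x_shape != y_shape:
        raise ValueError(
            f"x and y must have the same shape, got {x_shape} and {y_shape}")

    # Rule 2a: cond may have the same shape as x.
    if cond_shape == x_shape:
        return

    # Rule 2b: otherwise cond must be 1-D with size equal to x.shape[0].
    if len(cond_shape) == 1 and len(x_shape) >= 1 and cond_shape[0] == x_shape[0]:
        return

    raise ValueError(
        f"cond shape {cond_shape} is incompatible with x shape {x_shape}")


# Shapes from the reproducible example: x is (1, 2) but y is (), so this fails.
try:
    check_where_shapes((1,), (1, 2), ())
except ValueError as err:
    print(err)  # x and y must have the same shape, got (1, 2) and ()
```

With a check like this, the example above would fail loudly instead of returning uninitialized-looking values.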
   
   
   <br>
   
   ### Feature Request
   #### 1. Broadcast
   
   Currently, mx.nd.where() has two limitations:
   - x and y must have the same shape.
   - If cond does not have the same shape as x, it must be a 1D array whose size equals the size of x’s first dimension.
   
   Similar to [np.where()](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.where.html), it would be great if mx.nd.where() supported broadcasting, so that cond, x and y are broadcast to a common shape even when their input shapes differ.
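For comparison, here is a small NumPy-only sketch (it does not exercise MXNet) of what broadcasting semantics give for the same inputs as the bug report. `np.broadcast(...).shape` reports the common shape that `np.where` broadcasts all three arguments to before selecting elements.

```python
import numpy as np

# The same inputs as the reproducible example, but as NumPy arrays.
cond = np.array([0])       # shape (1,)
x = np.array([[10, 10]])   # shape (1, 2)
y = np.array(4)            # shape (), a 0-d array

# np.where broadcasts cond, x and y to a common shape before selecting.
print(np.broadcast(cond, x, y).shape)  # (1, 2)
print(np.where(cond, x, y))            # [[4 4]]
```

Under these semantics, `cond` is falsy everywhere after broadcasting, so every element comes from `y`.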
   
   <br>
   
   #### 2. Scalar inputs (cond, x and y)
   In some situations, we want to use a constant value for the True or False branch.

   It would be more user-friendly if programmers could simply write `mx.nd.where(cond, x, 0)` instead of `mx.nd.where(cond, x, mx.nd.array([0]))`.
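As a reference point, NumPy's `np.where` already accepts Python scalars for either branch and broadcasts them. The minimal sketch below uses NumPy only and is meant to show the requested behavior, not current MXNet behavior.

```python
import numpy as np

cond = np.array([True, False, True])
x = np.array([1.0, 2.0, 3.0])

# A Python scalar is accepted for either branch and broadcast to cond's shape.
print(np.where(cond, x, 0))    # [1. 0. 3.]
print(np.where(cond, -1, x))   # picks -1 wherever cond is True
```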
   
   
   
   
   <br>
   <br> 
   <br> 
   <br> 
   <br>
   
   
   ---
   
   ## Environment info (Required)
   
   ```
   ----------Python Info----------
   Version      : 3.6.9
   Compiler     : GCC 7.3.0
   Build        : ('default', 'Jul 30 2019 19:07:31')
   Arch         : ('64bit', '')
   ------------Pip Info-----------
   Version      : 19.2.2
   Directory    : /home/ubuntu/anaconda3/envs/new/lib/python3.6/site-packages/pip
   ----------MXNet Info-----------
   Version      : 1.6.0
   Directory    : /home/ubuntu/new/my-mxnet/python/mxnet
   Commit hash file "/home/ubuntu/new/my-mxnet/python/mxnet/COMMIT_HASH" not found. Not installed from pre-built package or built from source.
   Library      : ['/home/ubuntu/new/my-mxnet/python/mxnet/../../build/libmxnet.so']
   Build features:
   No runtime build feature info available
   ----------System Info----------
   Platform     : Linux-4.4.0-1092-aws-x86_64-with-debian-stretch-sid
   system       : Linux
   node         : ip-172-31-14-150
   release      : 4.4.0-1092-aws
   version      : #103-Ubuntu SMP Tue Aug 27 10:21:48 UTC 2019
   ----------Hardware Info----------
   machine      : x86_64
   processor    : x86_64
   Architecture:          x86_64
   CPU op-mode(s):        32-bit, 64-bit
   Byte Order:            Little Endian
   CPU(s):                96
   On-line CPU(s) list:   0-95
   Thread(s) per core:    2
   Core(s) per socket:    24
   Socket(s):             2
   NUMA node(s):          2
   Vendor ID:             GenuineIntel
   CPU family:            6
   Model:                 85
   Model name:            Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz
   Stepping:              4
   CPU MHz:               2499.998
   BogoMIPS:              4999.99
   Hypervisor vendor:     KVM
   Virtualization type:   full
   L1d cache:             32K
   L1i cache:             32K
   L2 cache:              1024K
   L3 cache:              33792K
   NUMA node0 CPU(s):     0-23,48-71
   NUMA node1 CPU(s):     24-47,72-95
   Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f rdseed adx smap clflushopt clwb avx512cd xsaveopt xsavec xgetbv1 ida arat pku
   ----------Network Test----------
   Setting timeout: 10
   Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0014 sec, LOAD: 0.4787 sec.
   Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.1707 sec, LOAD: 0.2402 sec.
   Timing for Gluon Tutorial(cn): https://zh.gluon.ai, DNS: 0.0228 sec, LOAD: 0.3108 sec.
   Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0107 sec, LOAD: 0.1101 sec.
   Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0013 sec, LOAD: 0.3356 sec.
   Timing for Conda: https://repo.continuum.io/pkgs/free/, DNS: 0.0135 sec, LOAD: 0.0633 sec.
   ----------Environment----------
   ```
   
   
   
   ## Build info (Required if built from source)
   
   Compiler (gcc/clang/mingw/visual studio): gcc
   
   MXNet commit hash: 03f12f0fe706d35c93a2cf721b6101bcbffeb07d
   
   Build config: plain CMakeLists.txt with USE_NCCL=1
   
   
   
   
   
   
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.