This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c2c5d4  [numpy] Fix mean, prod with input of empty array  (#18286)
9c2c5d4 is described below

commit 9c2c5d45a6c838bd8348ce43a6168cf77d5c7125
Author: Yiyan66 <[email protected]>
AuthorDate: Mon May 25 00:14:58 2020 +0800

    [numpy] Fix mean, prod with input of empty array  (#18286)
    
    * prod
    
    * mean
    
    * nan
    
    * sanity
    
    * change kernel
    
    * include
    
    Co-authored-by: Ubuntu <[email protected]>
---
 src/operator/numpy/np_broadcast_reduce_op.h        | 39 ++++++++++++++++++++--
 .../python/unittest/test_numpy_interoperability.py |  2 ++
 2 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/src/operator/numpy/np_broadcast_reduce_op.h 
b/src/operator/numpy/np_broadcast_reduce_op.h
index d10e32a..6b59ac0 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -275,6 +275,16 @@ inline bool NeedSafeAcc(int itype, int otype) {
   return safe_acc_hint && rule;
 }
 
+namespace mxnet_op {
+struct set_to_nan {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(index_t i, DType *out) {
+    out[i] = DType(nanf(""));
+  }
+};
+
+}  // namespace mxnet_op
+
 void TVMOpReduce(const OpContext& ctx, const TBlob& input,
                  const dmlc::optional<mxnet::Tuple<int>>& axis,
                  const TBlob& output, const OpReqType req, const std::string& 
reducer_name);
@@ -296,9 +306,32 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
   if (outputs[0].shape_.Size() == 0) return;
   if (inputs[0].shape_.Size() == 0 && outputs[0].shape_.Size() != 0) {
     using namespace mxnet_op;
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      Kernel<set_zero, xpu>::Launch(s, outputs[0].shape_.Size(), 
outputs[0].dptr<DType>());
-    });
+    if (normalize) {
+      LOG(WARNING) << "WARNING: Mean of empty slice.";
+      if (mxnet::common::is_float(outputs[0].type_flag_)) {
+        MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+          Kernel<set_to_nan, xpu>::Launch(s, outputs[0].shape_.Size(),
+                                          outputs[0].dptr<DType>());
+        });
+      } else {
+        LOG(WARNING) << "WARNING: nan is outside the range of " <<
+                        "representable values of type 'int'";
+        MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+          Kernel<set_zero, xpu>::Launch(s, outputs[0].shape_.Size(),
+                                        outputs[0].dptr<DType>());
+        });
+      }
+    } else if (std::is_same<reducer, mshadow_op::sum>::value) {
+      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+        Kernel<set_zero, xpu>::Launch(s, outputs[0].shape_.Size(),
+                                      outputs[0].dptr<DType>());
+      });
+    } else {
+      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+        Kernel<set_one, xpu>::Launch(s, outputs[0].shape_.Size(),
+                                     outputs[0].dptr<DType>());
+      });
+    }
     return;
   }
   CHECK_NE(req[0], kWriteInplace) << "Reduce does not support write in-place";
diff --git a/tests/python/unittest/test_numpy_interoperability.py 
b/tests/python/unittest/test_numpy_interoperability.py
index 342372c..0060b73 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -1105,6 +1105,7 @@ def _add_workload_mean(array_pool):
     OpArgMngr.add_workload('mean', array_pool['4x1'])
     OpArgMngr.add_workload('mean', array_pool['4x1'], axis=0, keepdims=True)
     OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]))
+    OpArgMngr.add_workload('mean', np.array([]).reshape(2,0,0))
     OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]), axis=0)
     OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]), axis=1)
 
@@ -1139,6 +1140,7 @@ def _add_workload_atleast_nd():
 
 def _add_workload_prod(array_pool):
     OpArgMngr.add_workload('prod', array_pool['4x1'])
+    OpArgMngr.add_workload('prod', np.array([]).reshape(2,0,0))
 
 
 def _add_workload_product(array_pool):

Reply via email to