This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e79aef  Use argmax instead of argmax_channel in Accuracy to keep 
dimension (#8245)
3e79aef is described below

commit 3e79aefba36889d800d56c2048e6dd9ff0adbe54
Author: Benoît Quartier <benoit.quart...@a3.epfl.ch>
AuthorDate: Mon Feb 5 22:43:53 2018 +0100

    Use argmax instead of argmax_channel in Accuracy to keep dimension (#8245)
    
    Fix GitHub issue #8129
---
 scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
index 98a09d2..ed99a1f 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
@@ -107,7 +107,11 @@ class Accuracy extends EvalMetric("accuracy") {
       "labels and predictions should have the same length.")
 
     for ((pred, label) <- preds zip labels) {
-      val predLabel = NDArray.argmax_channel(pred)
+      val predLabel = if (pred.shape == label.shape) {
+        NDArray.argmax(Map("axis" -> 1, "keepdims" -> true))(pred)
+      } else {
+        NDArray.argmax_channel(pred)
+      }
       require(label.shape == predLabel.shape,
         s"label ${label.shape} and prediction ${predLabel.shape}" +
         s"should have the same length.")

-- 
To stop receiving notification emails like this one, please contact
liuyi...@apache.org.

Reply via email to