This is an automated email from the ASF dual-hosted git repository. liuyizhi pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 3e79aef Use argmax instead of argmax_channel in Accuracy to keep dimension (#8245) 3e79aef is described below commit 3e79aefba36889d800d56c2048e6dd9ff0adbe54 Author: Benoît Quartier <benoit.quart...@a3.epfl.ch> AuthorDate: Mon Feb 5 22:43:53 2018 +0100 Use argmax instead of argmax_channel in Accuracy to keep dimension (#8245) Fixes GitHub issue #8129 --- scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala index 98a09d2..ed99a1f 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala @@ -107,7 +107,11 @@ class Accuracy extends EvalMetric("accuracy") { "labels and predictions should have the same length.") for ((pred, label) <- preds zip labels) { - val predLabel = NDArray.argmax_channel(pred) + val predLabel = if (pred.shape == label.shape) { + NDArray.argmax(Map("axis" -> 1, "keepdims" -> true))(pred) + } else { + NDArray.argmax_channel(pred) + } require(label.shape == predLabel.shape, s"label ${label.shape} and prediction ${predLabel.shape}" + s"should have the same length.") -- To stop receiving notification emails like this one, please contact liuyi...@apache.org.