This is an automated email from the ASF dual-hosted git repository.
liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3e79aef Use argmax instead of argmax_channel in Accuracy to keep
dimension (#8245)
3e79aef is described below
commit 3e79aefba36889d800d56c2048e6dd9ff0adbe54
Author: Benoît Quartier <[email protected]>
AuthorDate: Mon Feb 5 22:43:53 2018 +0100
Use argmax instead of argmax_channel in Accuracy to keep dimension (#8245)
Fix GitHub issue #8129
---
scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
index 98a09d2..ed99a1f 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
@@ -107,7 +107,11 @@ class Accuracy extends EvalMetric("accuracy") {
"labels and predictions should have the same length.")
for ((pred, label) <- preds zip labels) {
- val predLabel = NDArray.argmax_channel(pred)
+ val predLabel = if (pred.shape == label.shape) {
+ NDArray.argmax(Map("axis" -> 1, "keepdims" -> true))(pred)
+ } else {
+ NDArray.argmax_channel(pred)
+ }
require(label.shape == predLabel.shape,
s"label ${label.shape} and prediction ${predLabel.shape}" +
s"should have the same length.")
--
To stop receiving notification emails like this one, please contact
[email protected].