Repository: spark
Updated Branches:
  refs/heads/branch-2.1 a1112c615 -> f7a91a17e


[SPARK-20615][ML][TEST] SparseVector.argmax throws IndexOutOfBoundsException

## What changes were proposed in this pull request?

Added a check for the number of defined (active) values. Previously, the argmax 
function assumed that at least one value was defined whenever the vector size was 
greater than zero, and so indexed into an empty `indices` array.

## How was this patch tested?

Tests were added to the existing VectorsSuite to cover this case.

Author: Jon McLean <jon.mcl...@atsid.com>

Closes #17877 from jonmclean/vectorArgmaxIndexBug.

(cherry picked from commit be53a78352ae7c70d8a07d0df24574b3e3129b4a)
Signed-off-by: Sean Owen <so...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f7a91a17
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f7a91a17
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f7a91a17

Branch: refs/heads/branch-2.1
Commit: f7a91a17e8e20965b3e634e611690a96f72cec6b
Parents: a1112c6
Author: Jon McLean <jon.mcl...@atsid.com>
Authored: Tue May 9 09:47:50 2017 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Tue May 9 09:48:09 2017 +0100

----------------------------------------------------------------------
 .../src/main/scala/org/apache/spark/ml/linalg/Vectors.scala   | 2 ++
 .../test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala  | 7 +++++++
 .../main/scala/org/apache/spark/mllib/linalg/Vectors.scala    | 2 ++
 .../scala/org/apache/spark/mllib/linalg/VectorsSuite.scala    | 7 +++++++
 4 files changed, 18 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f7a91a17/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git 
a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala 
b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
index 22e4ec6..7bc2cb1 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
@@ -657,6 +657,8 @@ class SparseVector @Since("2.0.0") (
   override def argmax: Int = {
     if (size == 0) {
       -1
+    } else if (numActives == 0) {
+      0
     } else {
       // Find the max active entry.
       var maxIdx = indices(0)

http://git-wip-us.apache.org/repos/asf/spark/blob/f7a91a17/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala 
b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
index ea22c27..bd71656 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
@@ -125,6 +125,13 @@ class VectorsSuite extends SparkMLFunSuite {
 
     val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
     assert(vec8.argmax === 0)
+
+    // Check for case when sparse vector is non-empty but the values are empty
+    val vec9 = Vectors.sparse(100, Array.empty[Int], 
Array.empty[Double]).asInstanceOf[SparseVector]
+    assert(vec9.argmax === 0)
+
+    val vec10 = Vectors.sparse(1, Array.empty[Int], 
Array.empty[Double]).asInstanceOf[SparseVector]
+    assert(vec10.argmax === 0)
   }
 
   test("vector equals") {

http://git-wip-us.apache.org/repos/asf/spark/blob/f7a91a17/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 63ea9d3..5282849 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -846,6 +846,8 @@ class SparseVector @Since("1.0.0") (
   override def argmax: Int = {
     if (size == 0) {
       -1
+    } else if (numActives == 0) {
+      0
     } else {
       // Find the max active entry.
       var maxIdx = indices(0)

http://git-wip-us.apache.org/repos/asf/spark/blob/f7a91a17/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 71a3cea..6172cff 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -122,6 +122,13 @@ class VectorsSuite extends SparkFunSuite with Logging {
 
     val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
     assert(vec8.argmax === 0)
+
+    // Check for case when sparse vector is non-empty but the values are empty
+    val vec9 = Vectors.sparse(100, Array.empty[Int], 
Array.empty[Double]).asInstanceOf[SparseVector]
+    assert(vec9.argmax === 0)
+
+    val vec10 = Vectors.sparse(1, Array.empty[Int], 
Array.empty[Double]).asInstanceOf[SparseVector]
+    assert(vec10.argmax === 0)
   }
 
   test("vector equals") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to