Repository: spark Updated Branches: refs/heads/master 8782eb992 -> 4554529dc
[SPARK-4406] [MLib] FIX: Validate k in SVD Raise exception when k is non-positive in SVD Author: MechCoder <[email protected]> Closes #3945 from MechCoder/spark-4406 and squashes the following commits: 64e6d2d [MechCoder] TST: Add better test errors and messages 12dae73 [MechCoder] [SPARK-4406] FIX: Validate k in SVD Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4554529d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4554529d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4554529d Branch: refs/heads/master Commit: 4554529dce8fe8ca937d887109ef072eef52bf51 Parents: 8782eb9 Author: MechCoder <[email protected]> Authored: Fri Jan 9 17:45:18 2015 -0800 Committer: Xiangrui Meng <[email protected]> Committed: Fri Jan 9 17:45:18 2015 -0800 ---------------------------------------------------------------------- .../spark/mllib/linalg/distributed/IndexedRowMatrix.scala | 3 +++ .../apache/spark/mllib/linalg/distributed/RowMatrix.scala | 2 +- .../mllib/linalg/distributed/IndexedRowMatrixSuite.scala | 7 +++++++ .../spark/mllib/linalg/distributed/RowMatrixSuite.scala | 8 ++++++++ 4 files changed, 19 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/4554529d/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 36d8cad..181f507 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -102,6 +102,9 @@ class IndexedRowMatrix( k: Int, computeU: Boolean = false, rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = { + + val n = numCols().toInt + require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.") val indices = rows.map(_.index) val svd = toRowMatrix().computeSVD(k, computeU, rCond) val U = if (computeU) { http://git-wip-us.apache.org/repos/asf/spark/blob/4554529d/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index fbd35e3..d5abba6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -212,7 +212,7 @@ class RowMatrix( tol: Double, mode: String): SingularValueDecomposition[RowMatrix, Matrix] = { val n = numCols().toInt - require(k > 0 && k <= n, s"Request up to n singular values but got k=$k and n=$n.") + require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.") object SVDMode extends Enumeration { val LocalARPACK, LocalLAPACK, DistARPACK = Value http://git-wip-us.apache.org/repos/asf/spark/blob/4554529d/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index e25bc02..741cd49 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -113,6 +113,13 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(closeToZero(U * brzDiag(s) * V.t - localA)) } + test("validate k in svd") { + val A = new IndexedRowMatrix(indexedRows) + intercept[IllegalArgumentException] { + A.computeSVD(-1) + } + } + def closeToZero(G: BDM[Double]): Boolean = { G.valuesIterator.map(math.abs).sum < 1e-6 } http://git-wip-us.apache.org/repos/asf/spark/blob/4554529d/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index dbf55ff..3309713 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -171,6 +171,14 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { } } + test("validate k in svd") { + for (mat <- Seq(denseMat, sparseMat)) { + intercept[IllegalArgumentException] { + mat.computeSVD(-1) + } + } + } + def closeToZero(G: BDM[Double]): Boolean = { G.valuesIterator.map(math.abs).sum < 1e-6 } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
