Repository: spark
Updated Branches:
  refs/heads/master 288ce583b -> 561d31d2f


[SPARK-4614][MLLIB] Slight API changes in Matrix and Matrices

Until we have a full picture of the operators we want to add, it is safer to
hide `Matrix.transposeMultiply` in 1.2.0. We also want to change
`Matrices.randn` and `Matrices.rand` so that both take a `Random`
implementation. Otherwise each call creates its own unseeded generator, which
is very likely to produce inconsistent RDDs. I also added some unit tests for
the matrix factory methods. All of these APIs are new in 1.2, so there are no
incompatible changes.

brkyvz

Author: Xiangrui Meng <[email protected]>

Closes #3468 from mengxr/SPARK-4614 and squashes the following commits:

3b0e4e2 [Xiangrui Meng] add mima excludes
6bfd8a4 [Xiangrui Meng] hide transposeMultiply; add rng to rand and randn; add 
unit tests


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/561d31d2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/561d31d2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/561d31d2

Branch: refs/heads/master
Commit: 561d31d2f13cc7b1112ba9f9aa8f08bcd032aebb
Parents: 288ce58
Author: Xiangrui Meng <[email protected]>
Authored: Wed Nov 26 08:22:50 2014 -0800
Committer: Xiangrui Meng <[email protected]>
Committed: Wed Nov 26 08:22:50 2014 -0800

----------------------------------------------------------------------
 .../apache/spark/mllib/linalg/Matrices.scala    | 20 ++++----
 .../spark/mllib/linalg/MatricesSuite.scala      | 50 ++++++++++++++++++++
 project/MimaExcludes.scala                      |  6 +++
 3 files changed, 65 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/561d31d2/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 2cc52e9..327366a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -17,12 +17,10 @@
 
 package org.apache.spark.mllib.linalg
 
-import java.util.Arrays
+import java.util.{Random, Arrays}
 
 import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM}
 
-import org.apache.spark.util.random.XORShiftRandom
-
 /**
  * Trait for a local matrix.
  */
@@ -67,14 +65,14 @@ sealed trait Matrix extends Serializable {
   }
 
   /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */
-  def transposeMultiply(y: DenseMatrix): DenseMatrix = {
+  private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = {
     val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix]
     BLAS.gemm(true, false, 1.0, this, y, 0.0, C)
     C
   }
 
   /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */
-  def transposeMultiply(y: DenseVector): DenseVector = {
+  private[mllib] def transposeMultiply(y: DenseVector): DenseVector = {
     val output = new DenseVector(new Array[Double](numCols))
     BLAS.gemv(true, 1.0, this, y, 0.0, output)
     output
@@ -291,22 +289,22 @@ object Matrices {
    * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers.
    * @param numRows number of rows of the matrix
    * @param numCols number of columns of the matrix
+   * @param rng a random number generator
    * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1)
    */
-  def rand(numRows: Int, numCols: Int): Matrix = {
-    val rand = new XORShiftRandom
-    new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextDouble()))
+  def rand(numRows: Int, numCols: Int, rng: Random): Matrix = {
+    new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble()))
   }
 
   /**
    * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers.
    * @param numRows number of rows of the matrix
    * @param numCols number of columns of the matrix
+   * @param rng a random number generator
    * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1)
    */
-  def randn(numRows: Int, numCols: Int): Matrix = {
-    val rand = new XORShiftRandom
-    new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextGaussian()))
+  def randn(numRows: Int, numCols: Int, rng: Random): Matrix = {
+    new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian()))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/561d31d2/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index 5f8b8c4..322a0e9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -17,7 +17,11 @@
 
 package org.apache.spark.mllib.linalg
 
+import java.util.Random
+
+import org.mockito.Mockito.when
 import org.scalatest.FunSuite
+import org.scalatest.mock.MockitoSugar._
 
 class MatricesSuite extends FunSuite {
   test("dense matrix construction") {
@@ -112,4 +116,50 @@ class MatricesSuite extends FunSuite {
     assert(sparseMat(0, 1) === 10.0)
     assert(sparseMat.values(2) === 10.0)
   }
+
+  test("zeros") {
+    val mat = Matrices.zeros(2, 3).asInstanceOf[DenseMatrix]
+    assert(mat.numRows === 2)
+    assert(mat.numCols === 3)
+    assert(mat.values.forall(_ == 0.0))
+  }
+
+  test("ones") {
+    val mat = Matrices.ones(2, 3).asInstanceOf[DenseMatrix]
+    assert(mat.numRows === 2)
+    assert(mat.numCols === 3)
+    assert(mat.values.forall(_ == 1.0))
+  }
+
+  test("eye") {
+    val mat = Matrices.eye(2).asInstanceOf[DenseMatrix]
+    assert(mat.numRows === 2)
+    assert(mat.numCols === 2)
+    assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 1.0))
+  }
+
+  test("rand") {
+    val rng = mock[Random]
+    when(rng.nextDouble()).thenReturn(1.0, 2.0, 3.0, 4.0)
+    val mat = Matrices.rand(2, 2, rng).asInstanceOf[DenseMatrix]
+    assert(mat.numRows === 2)
+    assert(mat.numCols === 2)
+    assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
+  }
+
+  test("randn") {
+    val rng = mock[Random]
+    when(rng.nextGaussian()).thenReturn(1.0, 2.0, 3.0, 4.0)
+    val mat = Matrices.randn(2, 2, rng).asInstanceOf[DenseMatrix]
+    assert(mat.numRows === 2)
+    assert(mat.numCols === 2)
+    assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
+  }
+
+  test("diag") {
+    val mat = Matrices.diag(Vectors.dense(1.0, 2.0)).asInstanceOf[DenseMatrix]
+    assert(mat.numRows === 2)
+    assert(mat.numCols === 2)
+    assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 2.0))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/561d31d2/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 94de14d..230239a 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -47,6 +47,12 @@ object MimaExcludes {
               "org.apache.spark.SparkStageInfoImpl.this"),
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.SparkStageInfo.submissionTime")
+          ) ++ Seq(
+            // SPARK-4614
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.mllib.linalg.Matrices.randn"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.mllib.linalg.Matrices.rand")
           )
 
         case v if v.startsWith("1.2") =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to