Repository: spark
Updated Branches:
  refs/heads/master 454fe129e -> 7e8e62aec
[SPARK-5015] [mllib] Random seed for GMM + make test suite deterministic Issues: * From JIRA: GaussianMixtureEM uses randomness but does not take a random seed. It should take one as a parameter. * This also makes the test suite flaky since initialization can fail due to stochasticity. Fix: * Add random seed * Use it in test suite CC: mengxr tgaloppo Author: Joseph K. Bradley <[email protected]> Closes #3981 from jkbradley/gmm-seed and squashes the following commits: f0df4fd [Joseph K. Bradley] Added seed parameter to GMM. Updated test suite to use seed to prevent flakiness Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7e8e62ae Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7e8e62ae Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7e8e62ae Branch: refs/heads/master Commit: 7e8e62aec11c43c983055adc475b96006412199a Parents: 454fe12 Author: Joseph K. Bradley <[email protected]> Authored: Fri Jan 9 13:00:15 2015 -0800 Committer: Xiangrui Meng <[email protected]> Committed: Fri Jan 9 13:00:15 2015 -0800 ---------------------------------------------------------------------- .../mllib/clustering/GaussianMixtureEM.scala | 26 ++++++++++++++------ .../GMMExpectationMaximizationSuite.scala | 14 ++++++----- 2 files changed, 27 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/7e8e62ae/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala index 3a6c0e6..b3c5631 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS} import org.apache.spark.mllib.stat.impl.MultivariateGaussian import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.util.Utils /** * This class performs expectation maximization for multivariate Gaussian @@ -45,10 +46,11 @@ import org.apache.spark.mllib.util.MLUtils class GaussianMixtureEM private ( private var k: Int, private var convergenceTol: Double, - private var maxIterations: Int) extends Serializable { + private var maxIterations: Int, + private var seed: Long) extends Serializable { /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */ - def this() = this(2, 0.01, 100) + def this() = this(2, 0.01, 100, Utils.random.nextLong()) // number of samples per cluster to use when initializing Gaussians private val nSamples = 5 @@ -100,11 +102,21 @@ class GaussianMixtureEM private ( this } - /** Return the largest change in log-likelihood at which convergence is - * considered to have occurred. + /** + * Return the largest change in log-likelihood at which convergence is + * considered to have occurred. */ def getConvergenceTol: Double = convergenceTol - + + /** Set the random seed */ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + + /** Return the random seed */ + def getSeed: Long = seed + /** Perform expectation maximization */ def run(data: RDD[Vector]): GaussianMixtureModel = { val sc = data.sparkContext @@ -113,7 +125,7 @@ class GaussianMixtureEM private ( val breezeData = data.map(u => u.toBreeze.toDenseVector).cache() // Get length of the input vectors - val d = breezeData.first.length + val d = breezeData.first().length // Determine initial weights and corresponding Gaussians. 
// If the user supplied an initial GMM, we use those values, otherwise @@ -126,7 +138,7 @@ class GaussianMixtureEM private ( }) case None => { - val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt) + val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) http://git-wip-us.apache.org/repos/asf/spark/blob/7e8e62ae/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala index 23feb82..9da5495 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala @@ -35,12 +35,14 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex val Ew = 1.0 val Emu = Vectors.dense(5.0, 10.0) val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0)) - - val gmm = new GaussianMixtureEM().setK(1).run(data) - - assert(gmm.weight(0) ~== Ew absTol 1E-5) - assert(gmm.mu(0) ~== Emu absTol 1E-5) - assert(gmm.sigma(0) ~== Esigma absTol 1E-5) + + val seeds = Array(314589, 29032897, 50181, 494821, 4660) + seeds.foreach { seed => + val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data) + assert(gmm.weight(0) ~== Ew absTol 1E-5) + assert(gmm.mu(0) ~== Emu absTol 1E-5) + assert(gmm.sigma(0) ~== Esigma absTol 1E-5) + } } test("two clusters") { --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For 
additional commands, e-mail: [email protected]
