Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21195#discussion_r186555425
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala 
---
    @@ -256,6 +257,22 @@ class GaussianMixtureSuite extends SparkFunSuite with 
MLlibTestSparkContext
         val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4, 
triangularValues)
         assert(symmetricMatrix === expectedMatrix)
       }
    +
    +  test("GaussianMixture with Array input") {
    +    def trainAndComputlogLikelihood(dataset: Dataset[_]): Double = {
    +      val model = new 
GaussianMixture().setK(k).setMaxIter(1).setSeed(1).fit(dataset)
    +      model.summary.logLikelihood
    +    }
    +
    +    val (newDataset, newDatasetD, newDatasetF) = 
MLTestingUtils.generateArrayFeatureDataset(dataset)
    +    val trueLikelihood = trainAndComputlogLikelihood(newDataset)
    +    val doubleLikelihood = trainAndComputlogLikelihood(newDatasetD)
    +    val floatLikelihood = trainAndComputlogLikelihood(newDatasetF)
    +
    +    // checking the cost is fine enough as a sanity check
    +    assert(trueLikelihood == doubleLikelihood)
    --- End diff --
    
    minor: assertions should use `===` instead of `==`, because the former gives 
a better error message on failure. (It is not necessary to update this PR.)


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to