GitHub user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/21195#discussion_r186555425
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
---
@@ -256,6 +257,22 @@ class GaussianMixtureSuite extends SparkFunSuite with
MLlibTestSparkContext
val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4,
triangularValues)
assert(symmetricMatrix === expectedMatrix)
}
+
+ test("GaussianMixture with Array input") {
+ def trainAndComputlogLikelihood(dataset: Dataset[_]): Double = {
+ val model = new
GaussianMixture().setK(k).setMaxIter(1).setSeed(1).fit(dataset)
+ model.summary.logLikelihood
+ }
+
+ val (newDataset, newDatasetD, newDatasetF) =
MLTestingUtils.generateArrayFeatureDataset(dataset)
+ val trueLikelihood = trainAndComputlogLikelihood(newDataset)
+ val doubleLikelihood = trainAndComputlogLikelihood(newDatasetD)
+ val floatLikelihood = trainAndComputlogLikelihood(newDatasetF)
+
+ // checking the cost is fine enough as a sanity check
+ assert(trueLikelihood == doubleLikelihood)
--- End diff --
Minor: use `===` instead of `==` in assertions. With `===`, a failing assertion reports both the expected and actual values, which gives a more informative error message. (No need to update this PR for it.)
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]