GitHub user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20319#discussion_r183566557
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
---
@@ -102,17 +105,14 @@ class BisectingKMeansSuite
val model = bkm.fit(dataset)
assert(model.clusterCenters.length === k)
- val transformed = model.transform(dataset)
- val expectedColumns = Array("features", predictionColName)
- expectedColumns.foreach { column =>
- assert(transformed.columns.contains(column))
+ testTransformerByGlobalCheckFunc[Vector](dataset.toDF(), model,
+ "features", predictionColName) { rows =>
+ val clusters = rows.map(_.getAs[Int](predictionColName)).toSet
+ assert(clusters.size === k)
+ assert(clusters === Set(0, 1, 2, 3, 4))
+ assert(model.computeCost(dataset) < 0.1)
--- End diff --
Checks that do not use "rows", such as the model.computeCost(dataset)
assertion, should go outside of testTransformerByGlobalCheckFunc.
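
A minimal, self-contained sketch of the suggested restructuring follows. It does not use Spark: `runGlobalCheckTwice` is a hypothetical stand-in for a helper that may invoke the check function more than once (that behavior is an assumption here, not something stated in the review), and `cost` is a placeholder value standing in for model.computeCost(dataset). The point is only that an assertion which ignores `rows` can sit outside the closure and be evaluated once.

```scala
// Hypothetical stand-in, not the Spark test API.
object CheckPlacementSketch {
  // Counts how many times the row-based check function is invoked.
  var checkCalls = 0

  // Assumed behavior: the helper applies the same check to the rows twice,
  // for example once per execution path.
  def runGlobalCheckTwice(rows: Seq[Int])(check: Seq[Int] => Unit): Unit = {
    check(rows)
    check(rows)
  }

  def main(args: Array[String]): Unit = {
    val k = 5
    val predictions = Seq(0, 1, 2, 3, 4, 0, 1) // illustrative cluster ids
    val cost = 0.05 // placeholder for model.computeCost(dataset)

    // Row-independent check: kept outside the helper, evaluated once.
    assert(cost < 0.1)

    // Only the checks that actually read `rows` stay inside the closure.
    runGlobalCheckTwice(predictions) { rows =>
      checkCalls += 1
      val clusters = rows.toSet
      assert(clusters.size == k)
      assert(clusters == Set(0, 1, 2, 3, 4))
    }
  }
}
```

In the suite itself this would presumably amount to moving the computeCost assertion to just before or after the testTransformerByGlobalCheckFunc call, leaving the cluster-set assertions inside the block.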
---