Github user holdenk commented on a diff in the pull request:
https://github.com/apache/spark/pull/20907#discussion_r178382724
--- Diff: mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala ---
@@ -202,6 +207,26 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
  KMeansSuite.allParamSettings, checkModelData)
}
+
+ test("pmml export") {
+ val clusterCenters = Array(
+ MLlibVectors.dense(1.0, 2.0, 6.0),
+ MLlibVectors.dense(1.0, 3.0, 0.0),
+ MLlibVectors.dense(1.0, 4.0, 6.0))
+ val oldKmeansModel = new MLlibKMeansModel(clusterCenters)
+ val kmeansModel = new KMeansModel("", oldKmeansModel)
+ def checkModel(pmml: PMML): Unit = {
+ // Check that the header description is what we expect
+ assert(pmml.getHeader.getDescription === "k-means clustering")
+ // check that the number of fields matches the single vector size
+ assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size)
+ // This verifies that there is a model attached to the pmml object and that the model is a
+ // clustering one. It also verifies that the pmml model has the same number of clusters as
+ // the Spark model.
+ val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
+ assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
+ }
--- End diff ---
Oh yeah :( Thanks for catching that.
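
For reference, a minimal sketch of how a checker like this could be driven once the writer has produced a PMML file. The save path and the JPMML PMMLUtil.unmarshal plumbing below are illustrative assumptions, not the actual test harness used in this PR:

    import java.nio.file.{Files, Paths}

    import org.dmg.pmml.PMML
    import org.jpmml.model.PMMLUtil

    // Hypothetical location of a PMML file written out for the k-means model.
    val in = Files.newInputStream(Paths.get("/tmp/kmeans-pmml/model.pmml"))
    try {
      // Parse the document and reuse the structural assertions from checkModel above.
      val pmml: PMML = PMMLUtil.unmarshal(in)
      checkModel(pmml)
    } finally {
      in.close()
    }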
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]