Github user BryanCutler commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20907#discussion_r178204492
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala ---
    @@ -202,6 +207,26 @@ class KMeansSuite extends SparkFunSuite with 
MLlibTestSparkContext with DefaultR
         testEstimatorAndModelReadWrite(kmeans, dataset, 
KMeansSuite.allParamSettings,
           KMeansSuite.allParamSettings, checkModelData)
       }
    +
    +  test("pmml export") {
    +    val clusterCenters = Array(
    +      MLlibVectors.dense(1.0, 2.0, 6.0),
    +      MLlibVectors.dense(1.0, 3.0, 0.0),
    +      MLlibVectors.dense(1.0, 4.0, 6.0))
    +    val oldKmeansModel = new MLlibKMeansModel(clusterCenters)
    +    val kmeansModel = new KMeansModel("", oldKmeansModel)
    +    def checkModel(pmml: PMML): Unit = {
    +      // Check the header description is what we expect
    +      assert(pmml.getHeader.getDescription === "k-means clustering")
    +      // check that the number of fields matches the single vector size
    +      assert(pmml.getDataDictionary.getNumberOfFields === 
clusterCenters(0).size)
    +      // This verifies that there is a model attached to the pmml object and 
the model is a clustering
    +      // one. It also verifies that the pmml model has the same number of 
clusters as the Spark
    +      // model.
    +      val pmmlClusteringModel = 
pmml.getModels.get(0).asInstanceOf[ClusteringModel]
    +      assert(pmmlClusteringModel.getNumberOfClusters === 
clusterCenters.length)
    +    }
    --- End diff --
    
    Isn't this missing a call to `testPMMLWrite`?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to