Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/15770#discussion_r178992899
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala ---
    @@ -0,0 +1,171 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.clustering
    +
    +import scala.collection.mutable
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.ml.linalg.{Vector, Vectors}
    +import org.apache.spark.ml.util.DefaultReadWriteTest
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
    +
    +class PowerIterationClusteringSuite extends SparkFunSuite
    +  with MLlibTestSparkContext with DefaultReadWriteTest {
    +
    +  @transient var data: Dataset[_] = _
    +  @transient var malData: Dataset[_] = _
    +  final val r1 = 1.0
    +  final val n1 = 10
    +  final val r2 = 4.0
    +  final val n2 = 40
    +
    +  override def beforeAll(): Unit = {
    +    super.beforeAll()
    +
    +    data = PowerIterationClusteringSuite.generatePICData(spark, r1, r2, 
n1, n2)
    +  }
    +
    +  test("default parameters") {
    +    val pic = new PowerIterationClustering()
    +
    +    assert(pic.getK === 2)
    +    assert(pic.getMaxIter === 20)
    +    assert(pic.getInitMode === "random")
    +    assert(pic.getFeaturesCol === "features")
    +    assert(pic.getPredictionCol === "prediction")
    +    assert(pic.getIdCol === "id")
    +    assert(pic.getWeightCol === "weight")
    +    assert(pic.getNeighborCol === "neighbor")
    +  }
    +
    +  test("set parameters") {
    +    val pic = new PowerIterationClustering()
    +      .setK(9)
    +      .setMaxIter(33)
    +      .setInitMode("degree")
    +      .setFeaturesCol("test_feature")
    +      .setPredictionCol("test_prediction")
    +      .setIdCol("test_id")
    +      .setWeightCol("test_weight")
    +      .setNeighborCol("test_neighbor")
    +
    +    assert(pic.getK === 9)
    +    assert(pic.getMaxIter === 33)
    +    assert(pic.getInitMode === "degree")
    +    assert(pic.getFeaturesCol === "test_feature")
    +    assert(pic.getPredictionCol === "test_prediction")
    +    assert(pic.getIdCol === "test_id")
    +    assert(pic.getWeightCol === "test_weight")
    +    assert(pic.getNeighborCol === "test_neighbor")
    +  }
    +
    +  test("parameters validation") {
    +    intercept[IllegalArgumentException] {
    +      new PowerIterationClustering().setK(1)
    +    }
    +    intercept[IllegalArgumentException] {
    +      new PowerIterationClustering().setInitMode("no_such_a_mode")
    +    }
    +  }
    +
    +  test("power iteration clustering") {
    +    val n = n1 + n2
    +
    +    val model = new PowerIterationClustering()
    +      .setK(2)
    +      .setMaxIter(40)
    +    val result = model.transform(data)
    +
    +    val predictions = Array.fill(2)(mutable.Set.empty[Long])
    +    result.select("id", "prediction").collect().foreach {
    +      case Row(id: Long, cluster: Integer) => predictions(cluster) += id
    +    }
    +    assert(predictions.toSet == Set((1 until n1).toSet, (n1 until 
n).toSet))
    +
    +    val result2 = new PowerIterationClustering()
    +      .setK(2)
    +      .setMaxIter(10)
    +      .setInitMode("degree")
    +      .transform(data)
    +    val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
    +    result2.select("id", "prediction").collect().foreach {
    +      case Row(id: Long, cluster: Integer) => predictions2(cluster) += id
    +    }
    +    assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until 
n).toSet))
    +
    +    val expectedColumns = Array("id", "prediction")
    --- End diff --
    
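    No need to check this, since it's already checked above by `result2.select("id", "prediction")`: that select would fail with an `AnalysisException` if either column were missing from the output.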


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to