Github user wangmiao1981 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21090#discussion_r182243819
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
 ---
    @@ -0,0 +1,239 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.clustering
    +
    +import scala.collection.mutable
    +
    +import org.apache.spark.ml.util.DefaultReadWriteTest
    +import org.apache.spark.mllib.util.MLlibTestSparkContext
    +import org.apache.spark.sql.functions.col
    +import org.apache.spark.sql.types._
    +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
    +import org.apache.spark.{SparkException, SparkFunSuite}
    +
    +
    +class PowerIterationClusteringSuite extends SparkFunSuite
    +  with MLlibTestSparkContext with DefaultReadWriteTest {
    +
    +  @transient var data: Dataset[_] = _
    +  final val r1 = 1.0
    +  final val n1 = 10
    +  final val r2 = 4.0
    +  final val n2 = 40
    +
    +  override def beforeAll(): Unit = {
    +    super.beforeAll()
    +
    +    data = PowerIterationClusteringSuite.generatePICData(spark, r1, r2, 
n1, n2)
    +  }
    +
    +  test("default parameters") {
    +    val pic = new PowerIterationClustering()
    +
    +    assert(pic.getK === 2)
    +    assert(pic.getMaxIter === 20)
    +    assert(pic.getInitMode === "random")
    +    assert(pic.getPredictionCol === "prediction")
    +    assert(pic.getIdCol === "id")
    +    assert(pic.getNeighborsCol === "neighbors")
    +    assert(pic.getSimilaritiesCol === "similarities")
    +  }
    +
    +  test("parameter validation") {
    +    intercept[IllegalArgumentException] {
    +      new PowerIterationClustering().setK(1)
    +    }
    +    intercept[IllegalArgumentException] {
    +      new PowerIterationClustering().setInitMode("no_such_a_mode")
    +    }
    +    intercept[IllegalArgumentException] {
    +      new PowerIterationClustering().setIdCol("")
    +    }
    +    intercept[IllegalArgumentException] {
    +      new PowerIterationClustering().setNeighborsCol("")
    +    }
    +    intercept[IllegalArgumentException] {
    +      new PowerIterationClustering().setSimilaritiesCol("")
    +    }
    +  }
    +
    +  test("power iteration clustering") {
    +    val n = n1 + n2
    +
    +    val model = new PowerIterationClustering()
    +      .setK(2)
    +      .setMaxIter(40)
    +    val result = model.transform(data)
    +
    +    val predictions = Array.fill(2)(mutable.Set.empty[Long])
    +    result.select("id", "prediction").collect().foreach {
    +      case Row(id: Long, cluster: Integer) => predictions(cluster) += id
    +    }
    +    assert(predictions.toSet == Set((1 until n1).toSet, (n1 until 
n).toSet))
    +
    +    val result2 = new PowerIterationClustering()
    +      .setK(2)
    +      .setMaxIter(10)
    +      .setInitMode("degree")
    +      .transform(data)
    +    val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
    +    result2.select("id", "prediction").collect().foreach {
    +      case Row(id: Long, cluster: Integer) => predictions2(cluster) += id
    +    }
    +    assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until 
n).toSet))
    +  }
    +
    +  test("supported input types") {
    +    val model = new PowerIterationClustering()
    +      .setK(2)
    +      .setMaxIter(1)
    +
    +    def runTest(idType: DataType, neighborType: DataType, similarityType: 
DataType): Unit = {
    +      val typedData = data.select(
    +        col("id").cast(idType).alias("id"),
    +        col("neighbors").cast(ArrayType(neighborType, containsNull = 
false)).alias("neighbors"),
    +        col("similarities").cast(ArrayType(similarityType, containsNull = 
false))
    +          .alias("similarities")
    +      )
    +      model.transform(typedData).collect()
    +    }
    +
    +    for (idType <- Seq(IntegerType, LongType)) {
    +      runTest(idType, LongType, DoubleType)
    +    }
    +    for (neighborType <- Seq(IntegerType, LongType)) {
    +      runTest(LongType, neighborType, DoubleType)
    +    }
    +    for (similarityType <- Seq(FloatType, DoubleType)) {
    +      runTest(LongType, LongType, similarityType)
    +    }
    +  }
    +
    +  test("invalid input: wrong types") {
    +    val model = new PowerIterationClustering()
    +      .setK(2)
    +      .setMaxIter(1)
    +    intercept[IllegalArgumentException] {
    +      val typedData = data.select(
    +        col("id").cast(DoubleType).alias("id"),
    +        col("neighbors"),
    +        col("similarities")
    +      )
    +      model.transform(typedData)
    +    }
    +    intercept[IllegalArgumentException] {
    +      val typedData = data.select(
    +        col("id"),
    +        col("neighbors").cast(ArrayType(DoubleType, containsNull = 
false)).alias("neighbors"),
    +        col("similarities")
    +      )
    +      model.transform(typedData)
    +    }
    +    intercept[IllegalArgumentException] {
    +
    --- End diff --
    
    Remove this blank line, or add a blank line after line 139, for consistency? 


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to