Github user yanboliang commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19156#discussion_r156564200
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala ---
    @@ -19,149 +19,165 @@ package org.apache.spark.ml.stat
     
     import org.scalatest.exceptions.TestFailedException
     
    -import org.apache.spark.{SparkException, SparkFunSuite}
    +import org.apache.spark.SparkFunSuite
     import org.apache.spark.ml.linalg.{Vector, Vectors}
     import org.apache.spark.ml.util.TestingUtils._
     import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
     import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, Statistics}
     import org.apache.spark.mllib.util.MLlibTestSparkContext
     import org.apache.spark.sql.{DataFrame, Row}
    -import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
     
     class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     
       import testImplicits._
       import Summarizer._
       import SummaryBuilderImpl._
     
    -  private case class ExpectedMetrics(
    -      mean: Seq[Double],
    -      variance: Seq[Double],
    -      count: Long,
    -      numNonZeros: Seq[Long],
    -      max: Seq[Double],
    -      min: Seq[Double],
    -      normL2: Seq[Double],
    -      normL1: Seq[Double])
    -
       /**
    -   * The input is expected to be either a sparse vector, a dense vector or an array of doubles
    -   * (which will be converted to a dense vector)
    -   * The expected is the list of all the known metrics.
    +   * The input is expected to be either a sparse vector, a dense vector.
        *
    -   * The tests take an list of input vectors and a list of all the summary values that
    -   * are expected for this input. They currently test against some fixed subset of the
    -   * metrics, but should be made fuzzy in the future.
    +   * The tests take an list of input vectors, and compare results with
    +   * `mllib.stat.MultivariateOnlineSummarizer`. They currently test against some fixed subset
    +   * of the metrics, but should be made fuzzy in the future.
        */
    -  private def testExample(name: String, input: Seq[Any], exp: ExpectedMetrics): Unit = {
    +  private def testExample(name: String, inputVec: Seq[(Vector, Double)]): Unit = {
     
    -    def inputVec: Seq[Vector] = input.map {
    -      case x: Array[Double @unchecked] => Vectors.dense(x)
    -      case x: Seq[Double @unchecked] => Vectors.dense(x.toArray)
    -      case x: Vector => x
    -      case x => throw new Exception(x.toString)
    +    val summarizer = {
    +      val _summarizer = new MultivariateOnlineSummarizer
    +      inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v._1), v._2))
    +      _summarizer
         }
     
    -    val summarizer = {
    +    val summarizerWithoutWeight = {
           val _summarizer = new MultivariateOnlineSummarizer
    -      inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v)))
    +      inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v._1)))
           _summarizer
         }
     
         // Because the Spark context is reset between tests, we cannot hold a reference onto it.
         def wrappedInit() = {
    -      val df = inputVec.map(Tuple1.apply).toDF("features")
    -      val col = df.col("features")
    -      (df, col)
    +      val df = inputVec.toDF("features", "weight")
    +      val featuresCol = df.col("features")
    +      val weightCol = df.col("weight")
    +      (df, featuresCol, weightCol)
         }
     
         registerTest(s"$name - mean only") {
    -      val (df, c) = wrappedInit()
    -      compare(df.select(metrics("mean").summary(c), mean(c)), Seq(Row(exp.mean), summarizer.mean))
    +      val (df, c, weight) = wrappedInit()
    +      compare(df.select(metrics("mean").summary(c, weight), mean(c, weight)),
    +        Seq(Row(summarizer.mean), summarizer.mean))
         }
     
    -    registerTest(s"$name - mean only (direct)") {
    -      val (df, c) = wrappedInit()
    -      compare(df.select(mean(c)), Seq(exp.mean))
    +    registerTest(s"$name - mean only w/o weight") {
    +      val (df, c, _) = wrappedInit()
    +      compare(df.select(metrics("mean").summary(c), mean(c)),
    +        Seq(Row(summarizerWithoutWeight.mean), summarizerWithoutWeight.mean))
         }
     
         registerTest(s"$name - variance only") {
    -      val (df, c) = wrappedInit()
    -      compare(df.select(metrics("variance").summary(c), variance(c)),
    -        Seq(Row(exp.variance), summarizer.variance))
    +      val (df, c, weight) = wrappedInit()
    --- End diff --
    
    nit: ```weight``` can be abbreviated to ```w```.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to