GitHub user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20686#discussion_r171759634
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala ---
    @@ -151,29 +146,30 @@ class OneHotEncoderEstimatorSuite
     
         val df = spark.createDataFrame(sc.parallelize(data), schema)
     
    -    val dfWithTypes = df
    -      .withColumn("shortInput", df("input").cast(ShortType))
    -      .withColumn("longInput", df("input").cast(LongType))
    -      .withColumn("intInput", df("input").cast(IntegerType))
    -      .withColumn("floatInput", df("input").cast(FloatType))
    -      .withColumn("decimalInput", df("input").cast(DecimalType(10, 0)))
    -
    -    val cols = Array("input", "shortInput", "longInput", "intInput",
    -      "floatInput", "decimalInput")
    -    for (col <- cols) {
    -      val encoder = new OneHotEncoderEstimator()
    -        .setInputCols(Array(col))
    +    class NumericTypeWithEncoder[A](val numericType: NumericType)
    +      (implicit val encoder: Encoder[(A, Vector)])
    +
    +    val types = Seq(
    +      new NumericTypeWithEncoder[Short](ShortType),
    +      new NumericTypeWithEncoder[Long](LongType),
    +      new NumericTypeWithEncoder[Int](IntegerType),
    +      new NumericTypeWithEncoder[Float](FloatType),
    +      new NumericTypeWithEncoder[Byte](ByteType),
    +      new NumericTypeWithEncoder[Double](DoubleType),
    +      new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()))
    --- End diff --
    
    Oh, I see. This is a syntax issue: `testTransformer` needs a generic type parameter. When I designed the `testTransformer` helper function, I couldn't eliminate that generic parameter, which makes things difficult here.
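A minimal, Spark-free sketch of why a wrapper like `NumericTypeWithEncoder` helps when the helper needs a type parameter. All names here are hypothetical stand-ins, not Spark APIs: `Enc` plays the role of `Encoder`, `checkRows` plays the role of `testTransformer`, and `TypeWithEnc` plays the role of the wrapper class in the diff. The idea is that the implicit instance is captured at construction time, while the concrete type is still known, and is later passed explicitly.

```scala
// Hypothetical stand-in for a type class such as Spark's Encoder;
// it only knows how to render a value of type T.
trait Enc[T] {
  def describe(value: T): String
}

object Enc {
  // Small factory so each instance below is a one-liner.
  def instance[T](f: T => String): Enc[T] = new Enc[T] {
    def describe(value: T): String = f(value)
  }

  implicit val shortEnc: Enc[Short] = instance[Short](v => s"short:$v")
  implicit val longEnc: Enc[Long] = instance[Long](v => s"long:$v")
  implicit val doubleEnc: Enc[Double] = instance[Double](v => s"double:$v")
}

// Hypothetical wrapper: resolves Enc[A] when A is still concrete and keeps it
// as a member, similar in spirit to NumericTypeWithEncoder.
class TypeWithEnc[A](val name: String, val values: Seq[A])(implicit val enc: Enc[A])

object EncoderWrapperDemo {
  // Hypothetical generic helper: like testTransformer, it cannot be called
  // without a type argument A, because it needs an implicit Enc[A].
  def checkRows[A](rows: Seq[A])(implicit enc: Enc[A]): Seq[String] =
    rows.map(enc.describe)

  // Heterogeneous list: the element type must be existential, so there is
  // no single A to pass to checkRows directly.
  val types: Seq[TypeWithEnc[_]] = Seq(
    new TypeWithEnc[Short]("short", Seq(1.toShort, 2.toShort)),
    new TypeWithEnc[Long]("long", Seq(3L)),
    new TypeWithEnc[Double]("double", Seq(0.5))
  )

  // A generic method binds A for each element, and the captured
  // instance is handed to checkRows explicitly.
  def runOne[A](t: TypeWithEnc[A]): Seq[String] = checkRows(t.values)(t.enc)

  val results: Seq[Seq[String]] = types.map(t => runOne(t))

  def main(args: Array[String]): Unit =
    types.zip(results).foreach { case (t, r) =>
      println(s"${t.name}: ${r.mkString(", ")}")
    }
}
```

Passing `t.enc` explicitly sidesteps implicit resolution at the loop's call site, where the concrete type is no longer visible. That is presumably why the diff stores `encoder` as an `implicit val` member rather than relying on an import.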

