Github user WeichenXu123 commented on a diff in the pull request: https://github.com/apache/spark/pull/20686#discussion_r171759634 --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala --- @@ -151,29 +146,30 @@ class OneHotEncoderEstimatorSuite val df = spark.createDataFrame(sc.parallelize(data), schema) - val dfWithTypes = df - .withColumn("shortInput", df("input").cast(ShortType)) - .withColumn("longInput", df("input").cast(LongType)) - .withColumn("intInput", df("input").cast(IntegerType)) - .withColumn("floatInput", df("input").cast(FloatType)) - .withColumn("decimalInput", df("input").cast(DecimalType(10, 0))) - - val cols = Array("input", "shortInput", "longInput", "intInput", - "floatInput", "decimalInput") - for (col <- cols) { - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array(col)) + class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Vector)]) + + val types = Seq( + new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) --- End diff -- Oh, I see. This is a syntax issue: `testTransformer` needs a generic parameter. When I designed the `testTransformer` helper function, I could not eliminate the generic parameter, which makes things difficult.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org