Github user attilapiros commented on a diff in the pull request:
https://github.com/apache/spark/pull/20686#discussion_r171615512
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
---
@@ -151,29 +146,30 @@ class OneHotEncoderEstimatorSuite
val df = spark.createDataFrame(sc.parallelize(data), schema)
- val dfWithTypes = df
- .withColumn("shortInput", df("input").cast(ShortType))
- .withColumn("longInput", df("input").cast(LongType))
- .withColumn("intInput", df("input").cast(IntegerType))
- .withColumn("floatInput", df("input").cast(FloatType))
- .withColumn("decimalInput", df("input").cast(DecimalType(10, 0)))
-
- val cols = Array("input", "shortInput", "longInput", "intInput",
- "floatInput", "decimalInput")
- for (col <- cols) {
- val encoder = new OneHotEncoderEstimator()
- .setInputCols(Array(col))
+ class NumericTypeWithEncoder[A](val numericType: NumericType)
+ (implicit val encoder: Encoder[(A, Vector)])
+
+ val types = Seq(
+ new NumericTypeWithEncoder[Short](ShortType),
+ new NumericTypeWithEncoder[Long](LongType),
+ new NumericTypeWithEncoder[Int](IntegerType),
+ new NumericTypeWithEncoder[Float](FloatType),
+ new NumericTypeWithEncoder[Byte](ByteType),
+ new NumericTypeWithEncoder[Double](DoubleType),
+ new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()))
--- End diff --
The reason behind this is that we cannot use runtime values (ShortType, LongType, ...) to choose the type parameter of the generic function `testTransformer`. Luckily, [context bounds are resolved to an implicit parameter](https://docs.scala-lang.org/tutorials/FAQ/context-bounds.html#how-are-context-bounds-implemented), and that implicit parameter is `t.encoder`, which is passed explicitly as the last argument.
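A minimal, self-contained sketch of the same trick, assuming Scala 2.12/2.13 and no Spark dependency. `Show`, `describe` and `TypeWithShow` are hypothetical stand-ins for `Encoder`, `testTransformer` and `NumericTypeWithEncoder`. The wrapper resolves the implicit instance at construction time, while the type parameter is still statically known, and the loop then passes that instance explicitly in place of the context-bound evidence:

```scala
// Sketch only (assumes Scala 2.12/2.13). Show, describe and TypeWithShow are
// hypothetical stand-ins for Spark's Encoder, testTransformer and NumericTypeWithEncoder.
object ContextBoundDemo {

  // Hypothetical type class standing in for Spark's Encoder[T].
  trait Show[A] { def typeName: String }

  object Show {
    // Instances in the companion object are always in implicit scope for Show[X].
    implicit val showShort: Show[Short] = new Show[Short] { val typeName = "Short" }
    implicit val showInt: Show[Int] = new Show[Int] { val typeName = "Int" }
    implicit val showDouble: Show[Double] = new Show[Double] { val typeName = "Double" }
  }

  // Written with a context bound ...
  def describe[A: Show](label: String): String =
    s"$label -> ${implicitly[Show[A]].typeName}"

  // ... which the compiler expands to an extra implicit parameter list like this one.
  def describeDesugared[A](label: String)(implicit ev: Show[A]): String =
    s"$label -> ${ev.typeName}"

  // Analogue of NumericTypeWithEncoder: Show[A] is resolved here, while A is
  // still a static type, and kept around as a public val.
  class TypeWithShow[A](val label: String)(implicit val show: Show[A])

  def main(args: Array[String]): Unit = {
    val types = Seq(
      new TypeWithShow[Short]("short"),
      new TypeWithShow[Int]("int"),
      new TypeWithShow[Double]("double"))

    for (t <- types) {
      // Inside the loop the element type is only known existentially, so the
      // compiler cannot summon Show[A] by itself; pass the captured instance explicitly.
      println(describe(t.label)(t.show))
    }
  }
}
```

Running `main` should print `short -> Short`, `int -> Int` and `double -> Double`, one per line.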