Github user attilapiros commented on a diff in the pull request:
https://github.com/apache/spark/pull/20686#discussion_r171615849
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
---
@@ -151,29 +146,30 @@ class OneHotEncoderEstimatorSuite
val df = spark.createDataFrame(sc.parallelize(data), schema)
- val dfWithTypes = df
- .withColumn("shortInput", df("input").cast(ShortType))
- .withColumn("longInput", df("input").cast(LongType))
- .withColumn("intInput", df("input").cast(IntegerType))
- .withColumn("floatInput", df("input").cast(FloatType))
- .withColumn("decimalInput", df("input").cast(DecimalType(10, 0)))
-
- val cols = Array("input", "shortInput", "longInput", "intInput",
- "floatInput", "decimalInput")
- for (col <- cols) {
- val encoder = new OneHotEncoderEstimator()
- .setInputCols(Array(col))
+ class NumericTypeWithEncoder[A](val numericType: NumericType)
+ (implicit val encoder: Encoder[(A, Vector)])
+
+ val types = Seq(
+ new NumericTypeWithEncoder[Short](ShortType),
+ new NumericTypeWithEncoder[Long](LongType),
+ new NumericTypeWithEncoder[Int](IntegerType),
+ new NumericTypeWithEncoder[Float](FloatType),
+ new NumericTypeWithEncoder[Byte](ByteType),
+ new NumericTypeWithEncoder[Double](DoubleType),
+ new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()))
+
+ for (t <- types) {
+ val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected"))
+ val estimator = new OneHotEncoderEstimator()
+ .setInputCols(Array("input"))
.setOutputCols(Array("output"))
.setDropLast(false)
- val model = encoder.fit(dfWithTypes)
- val encoded = model.transform(dfWithTypes)
-
- encoded.select("output", "expected").rdd.map { r =>
- (r.getAs[Vector](0), r.getAs[Vector](1))
- }.collect().foreach { case (vec1, vec2) =>
- assert(vec1 === vec2)
- }
+ val model = estimator.fit(dfWithTypes)
+ testTransformer(dfWithTypes, model, "output", "expected") {
+ case Row(output: Vector, expected: Vector) =>
+ assert(output === expected)
+ }(t.encoder)
--- End diff --
See previous comment.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]