Github user attilapiros commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20686#discussion_r171615849
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala ---
    @@ -151,29 +146,30 @@ class OneHotEncoderEstimatorSuite
     
         val df = spark.createDataFrame(sc.parallelize(data), schema)
     
    -    val dfWithTypes = df
    -      .withColumn("shortInput", df("input").cast(ShortType))
    -      .withColumn("longInput", df("input").cast(LongType))
    -      .withColumn("intInput", df("input").cast(IntegerType))
    -      .withColumn("floatInput", df("input").cast(FloatType))
    -      .withColumn("decimalInput", df("input").cast(DecimalType(10, 0)))
    -
    -    val cols = Array("input", "shortInput", "longInput", "intInput",
    -      "floatInput", "decimalInput")
    -    for (col <- cols) {
    -      val encoder = new OneHotEncoderEstimator()
    -        .setInputCols(Array(col))
    +    class NumericTypeWithEncoder[A](val numericType: NumericType)
    +      (implicit val encoder: Encoder[(A, Vector)])
    +
    +    val types = Seq(
    +      new NumericTypeWithEncoder[Short](ShortType),
    +      new NumericTypeWithEncoder[Long](LongType),
    +      new NumericTypeWithEncoder[Int](IntegerType),
    +      new NumericTypeWithEncoder[Float](FloatType),
    +      new NumericTypeWithEncoder[Byte](ByteType),
    +      new NumericTypeWithEncoder[Double](DoubleType),
    +      new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()))
    +
    +    for (t <- types) {
    +      val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected"))
    +      val estimator = new OneHotEncoderEstimator()
    +        .setInputCols(Array("input"))
             .setOutputCols(Array("output"))
             .setDropLast(false)
     
    -      val model = encoder.fit(dfWithTypes)
    -      val encoded = model.transform(dfWithTypes)
    -
    -      encoded.select("output", "expected").rdd.map { r =>
    -        (r.getAs[Vector](0), r.getAs[Vector](1))
    -      }.collect().foreach { case (vec1, vec2) =>
    -        assert(vec1 === vec2)
    -      }
    +      val model = estimator.fit(dfWithTypes)
    +      testTransformer(dfWithTypes, model, "output", "expected") {
    +        case Row(output: Vector, expected: Vector) =>
    +          assert(output === expected)
    +      }(t.encoder)
    --- End diff --
    
    See previous comment.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to