Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20686#discussion_r171545178
--- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala ---
@@ -119,29 +131,41 @@ class OneHotEncoderSuite
test("OneHotEncoder with varying types") {
val df = stringIndexed()
- val dfWithTypes = df
- .withColumn("shortLabel", df("labelIndex").cast(ShortType))
- .withColumn("longLabel", df("labelIndex").cast(LongType))
- .withColumn("intLabel", df("labelIndex").cast(IntegerType))
- .withColumn("floatLabel", df("labelIndex").cast(FloatType))
- .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0)))
- val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel",
- "floatLabel", "decimalLabel")
- for (col <- cols) {
+ val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
+ val expected = Seq(
+ (0, Vectors.sparse(3, Seq((0, 1.0)))),
+ (1, Vectors.sparse(3, Seq((2, 1.0)))),
+ (2, Vectors.sparse(3, Seq((1, 1.0)))),
+ (3, Vectors.sparse(3, Seq((0, 1.0)))),
+ (4, Vectors.sparse(3, Seq((0, 1.0)))),
+ (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected")
+
+ val withExpected = df.join(expected, "id")
+
+ class NumericTypeWithEncoder[A](val numericType: NumericType)
+ (implicit val encoder: Encoder[(A, Vector)])
+
+ val types = Seq(
+ new NumericTypeWithEncoder[Short](ShortType),
+ new NumericTypeWithEncoder[Long](LongType),
+ new NumericTypeWithEncoder[Int](IntegerType),
+ new NumericTypeWithEncoder[Float](FloatType),
+ new NumericTypeWithEncoder[Byte](ByteType),
+ new NumericTypeWithEncoder[Double](DoubleType),
+ new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()))
--- End diff --
ditto.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]