Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20686#discussion_r171519201
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
---
@@ -116,11 +110,12 @@ class OneHotEncoderEstimatorSuite
.setInputCols(Array("index"))
.setOutputCols(Array("encoded"))
val model = encoder.fit(df)
- val output = model.transform(df)
- val group = AttributeGroup.fromStructField(output.schema("encoded"))
- assert(group.size === 2)
- assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
- assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
+ testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows =>
+ val group = AttributeGroup.fromStructField(rows.head.schema("encoded"))
+ assert(group.size === 2)
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
+ }
--- End diff --
ditto.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]