Github user attilapiros commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20686#discussion_r171765261
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala ---
    @@ -103,11 +96,12 @@ class OneHotEncoderEstimatorSuite
           .setInputCols(Array("size"))
           .setOutputCols(Array("encoded"))
         val model = encoder.fit(df)
    -    val output = model.transform(df)
    -    val group = AttributeGroup.fromStructField(output.schema("encoded"))
    -    assert(group.size === 2)
    -    assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
    -    assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
    +    testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows =>
    +        val group = AttributeGroup.fromStructField(rows.head.schema("encoded"))
    +        assert(group.size === 2)
    +        assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
    +        assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
    +    }
    --- End diff --
    
    Thanks.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to