Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20686#discussion_r173593885
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala ---
    @@ -247,14 +253,18 @@ class StringIndexerSuite
           .setInputCol("label")
           .setOutputCol("labelIndex")
           .fit(df)
    -    val transformed = indexer.transform(df)
    +    val expected1 = Seq(0.0, 2.0, 1.0, 0.0, 0.0, 1.0).map(Tuple1(_)).toDF("labelIndex")
    +    testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "labelIndex") { rows =>
    +      assert(rows == expected1.collect().seq)
    +    }
    +
         val idx2str = new IndexToString()
           .setInputCol("labelIndex")
           .setOutputCol("sameLabel")
           .setLabels(indexer.labels)
    -    idx2str.transform(transformed).select("label", "sameLabel").collect().foreach {
    -      case Row(a: String, b: String) =>
    -        assert(a === b)
    +
    +    testTransformerByGlobalCheckFunc[(Double)](expected1, idx2str, "sameLabel") { rows =>
    --- End diff --
    
    You should be able to test per-row, rather than using a global check function.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to