GitHub user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20686#discussion_r173592811
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala ---
    @@ -70,36 +71,37 @@ class StringIndexerSuite
           .setInputCol("label")
           .setOutputCol("labelIndex")
           .fit(df)
    +
         // Verify we throw by default with unseen values
    -    intercept[SparkException] {
    -      indexer.transform(df2).collect()
    -    }
    +    testTransformerByInterceptingException[(Int, String)](
    +      df2,
    +      indexer,
    +      "Unseen label:",
    +      "labelIndex")
     
         indexer.setHandleInvalid("skip")
    -    // Verify that we skip the c record
    -    val transformedSkip = indexer.transform(df2)
    -    val attrSkip = 
Attribute.fromStructField(transformedSkip.schema("labelIndex"))
    -      .asInstanceOf[NominalAttribute]
    -    assert(attrSkip.values.get === Array("b", "a"))
    -    val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { 
r =>
    -      (r.getInt(0), r.getDouble(1))
    -    }.collect().toSet
    -    // a -> 1, b -> 0
    -    val expectedSkip = Set((0, 1.0), (1, 0.0))
    -    assert(outputSkip === expectedSkip)
    +
    +    testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", 
"labelIndex") { rows =>
    +      val attrSkip = 
Attribute.fromStructField(rows.head.schema("labelIndex"))
    +        .asInstanceOf[NominalAttribute]
    +      assert(attrSkip.values.get === Array("b", "a"))
    +      // Verify that we skip the c record
    +      // a -> 1, b -> 0
    +      val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF()
    +      assert(rows.seq === expectedSkip.collect().toSeq)
    +    }
     
         indexer.setHandleInvalid("keep")
    +
         // Verify that we keep the unseen records
    -    val transformedKeep = indexer.transform(df2)
    -    val attrKeep = 
Attribute.fromStructField(transformedKeep.schema("labelIndex"))
    -      .asInstanceOf[NominalAttribute]
    -    assert(attrKeep.values.get === Array("b", "a", "__unknown"))
    -    val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { 
r =>
    -      (r.getInt(0), r.getDouble(1))
    -    }.collect().toSet
    -    // a -> 1, b -> 0, c -> 2, d -> 3
    -    val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0))
    -    assert(outputKeep === expectedKeep)
    +    testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", 
"labelIndex") { rows =>
    +      val attrKeep = 
Attribute.fromStructField(rows.head.schema("labelIndex"))
    +        .asInstanceOf[NominalAttribute]
    +      assert(attrKeep.values.get === Array("b", "a", "__unknown"))
    +      // a -> 1, b -> 0, c -> 2, d -> 3
    +      val expectedKeep = Seq((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)).toDF()
    --- End diff --
    
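    Ditto: move this (the `expectedKeep` DataFrame construction) outside the checkFunc closure, as requested above for `expectedSkip`.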


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to