Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20686#discussion_r173592635
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala ---
@@ -70,36 +71,37 @@ class StringIndexerSuite
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
+
// Verify we throw by default with unseen values
- intercept[SparkException] {
- indexer.transform(df2).collect()
- }
+ testTransformerByInterceptingException[(Int, String)](
+ df2,
+ indexer,
+ "Unseen label:",
+ "labelIndex")
indexer.setHandleInvalid("skip")
- // Verify that we skip the c record
- val transformedSkip = indexer.transform(df2)
- val attrSkip =
Attribute.fromStructField(transformedSkip.schema("labelIndex"))
- .asInstanceOf[NominalAttribute]
- assert(attrSkip.values.get === Array("b", "a"))
- val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map {
r =>
- (r.getInt(0), r.getDouble(1))
- }.collect().toSet
- // a -> 1, b -> 0
- val expectedSkip = Set((0, 1.0), (1, 0.0))
- assert(outputSkip === expectedSkip)
+
+ testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id",
"labelIndex") { rows =>
+ val attrSkip =
Attribute.fromStructField(rows.head.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrSkip.values.get === Array("b", "a"))
+ // Verify that we skip the c record
+ // a -> 1, b -> 0
+ val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF()
--- End diff --
The `expectedSkip` value does not depend on `rows`, so it can be defined once outside the `testTransformerByGlobalCheckFunc` check closure rather than inside it.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]