Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20686#discussion_r173592811
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala ---
@@ -70,36 +71,37 @@ class StringIndexerSuite
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
+
// Verify we throw by default with unseen values
- intercept[SparkException] {
- indexer.transform(df2).collect()
- }
+ testTransformerByInterceptingException[(Int, String)](
+ df2,
+ indexer,
+ "Unseen label:",
+ "labelIndex")
indexer.setHandleInvalid("skip")
- // Verify that we skip the c record
- val transformedSkip = indexer.transform(df2)
- val attrSkip =
Attribute.fromStructField(transformedSkip.schema("labelIndex"))
- .asInstanceOf[NominalAttribute]
- assert(attrSkip.values.get === Array("b", "a"))
- val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map {
r =>
- (r.getInt(0), r.getDouble(1))
- }.collect().toSet
- // a -> 1, b -> 0
- val expectedSkip = Set((0, 1.0), (1, 0.0))
- assert(outputSkip === expectedSkip)
+
+ testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id",
"labelIndex") { rows =>
+ val attrSkip =
Attribute.fromStructField(rows.head.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrSkip.values.get === Array("b", "a"))
+ // Verify that we skip the c record
+ // a -> 1, b -> 0
+ val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF()
+ assert(rows.seq === expectedSkip.collect().toSeq)
+ }
indexer.setHandleInvalid("keep")
+
// Verify that we keep the unseen records
- val transformedKeep = indexer.transform(df2)
- val attrKeep =
Attribute.fromStructField(transformedKeep.schema("labelIndex"))
- .asInstanceOf[NominalAttribute]
- assert(attrKeep.values.get === Array("b", "a", "__unknown"))
- val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map {
r =>
- (r.getInt(0), r.getDouble(1))
- }.collect().toSet
- // a -> 1, b -> 0, c -> 2, d -> 3
- val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0))
- assert(outputKeep === expectedKeep)
+ testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id",
"labelIndex") { rows =>
+ val attrKeep =
Attribute.fromStructField(rows.head.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrKeep.values.get === Array("b", "a", "__unknown"))
+ // a -> 1, b -> 0, c -> 2, d -> 3
+ val expectedKeep = Seq((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)).toDF()
--- End diff --
Ditto: move this `expectedKeep` definition outside the checkFunc.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]