Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20686#discussion_r171762031
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala ---
@@ -538,21 +540,28 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
// Handle unseen labels.
val formula2 = new RFormula().setFormula("b ~ a + id")
- intercept[SparkException] {
- formula2.fit(df1).transform(df2).collect()
- }
+ testTransformerByInterceptingException[(Int, String, String)](
+ df2,
+ formula2.fit(df1),
+ "Unseen label:",
+ "label")
+
val model3 = formula2.setHandleInvalid("skip").fit(df1)
val model4 = formula2.setHandleInvalid("keep").fit(df1)
+ val attr = NominalAttribute.defaultAttr
val expected3 = Seq(
(1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0),
(2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0)
).toDF("id", "a", "b", "features", "label")
+ .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata()))
--- End diff --
nit: please fix the indentation of the `.select(...)` continuation above.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]