Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20686#discussion_r173593885
--- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala ---
@@ -247,14 +253,18 @@ class StringIndexerSuite
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
- val transformed = indexer.transform(df)
+ val expected1 = Seq(0.0, 2.0, 1.0, 0.0, 0.0,
1.0).map(Tuple1(_)).toDF("labelIndex")
+ testTransformerByGlobalCheckFunc[(Int, String)](df, indexer,
"labelIndex") { rows =>
+ assert(rows == expected1.collect().seq)
+ }
+
val idx2str = new IndexToString()
.setInputCol("labelIndex")
.setOutputCol("sameLabel")
.setLabels(indexer.labels)
- idx2str.transform(transformed).select("label",
"sameLabel").collect().foreach {
- case Row(a: String, b: String) =>
- assert(a === b)
+
+ testTransformerByGlobalCheckFunc[(Double)](expected1, idx2str,
"sameLabel") { rows =>
--- End diff --
You should be able to test this per row (e.g. with `testTransformer`) rather than using a global check function.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]