Github user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20686#discussion_r172408009
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
---
@@ -324,19 +352,46 @@ class QuantileDiscretizerSuite
.setStages(Array(discretizerForCol1, discretizerForCol2,
discretizerForCol3))
.fit(df)
- val resultForMultiCols = plForMultiCols.transform(df)
- .select("result1", "result2", "result3")
- .collect()
-
- val resultForSingleCol = plForSingleCol.transform(df)
- .select("result1", "result2", "result3")
- .collect()
+ val expected = Seq(
+ (0.0, 0.0, 0.0),
+ (0.0, 0.0, 1.0),
+ (0.0, 0.0, 1.0),
+ (0.0, 1.0, 2.0),
+ (0.0, 1.0, 2.0),
+ (0.0, 1.0, 2.0),
+ (0.0, 1.0, 3.0),
+ (0.0, 2.0, 4.0),
+ (0.0, 2.0, 4.0),
+ (1.0, 2.0, 5.0),
+ (1.0, 2.0, 5.0),
+ (1.0, 2.0, 5.0),
+ (1.0, 3.0, 6.0),
+ (1.0, 3.0, 6.0),
+ (1.0, 3.0, 7.0),
+ (1.0, 4.0, 8.0),
+ (1.0, 4.0, 8.0),
+ (1.0, 4.0, 9.0),
+ (1.0, 4.0, 9.0),
+ (1.0, 4.0, 9.0)
+ ).toDF("result1", "result2", "result3")
+ .collect().toSeq
--- End diff --
But I prefer to avoid hardcoding a big literal array, so that the code is
easier to maintain. I think the following code is enough:
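```
// collect() returns an Array; convert it to a Seq so that == compares
// contents against `rows`, which the check function receives as a Seq
val expected = plForSingleCol.transform(df)
  .select("result1", "result2", "result3")
  .collect().toSeq
testTransformerByGlobalCheckFunc[(Double, Double, Double)](
  df, plForSingleCol,
  "result1", "result2", "result3") { rows =>
  assert(rows == expected)
}
```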
There is a similar case here
https://github.com/apache/spark/pull/20121#discussion_r172288890
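
One caveat on the `rows == expected` comparison in the snippet above: `collect()` returns an `Array`, while I am assuming the check function receives its rows as a `Seq`. In Scala, `==` between a `Seq` and an `Array` is false even when the contents match. Here is a minimal sketch in plain Scala without Spark. The object name and the sample tuples are hypothetical stand-ins, not values from the test suite:

```scala
object ArrayVsSeqEquality {
  // Hypothetical stand-ins: `expected` mimics the Array returned by collect(),
  // `rows` mimics the Seq handed to the check function.
  val expected: Array[(Double, Double, Double)] =
    Array((0.0, 0.0, 0.0), (0.0, 1.0, 2.0))
  val rows: Seq[(Double, Double, Double)] =
    Seq((0.0, 0.0, 0.0), (0.0, 1.0, 2.0))

  // A Seq never equals a raw Array, so this is false despite equal contents.
  val naive: Boolean = rows == expected

  // Converting the Array to a Seq gives element-wise equality.
  val viaToSeq: Boolean = rows == expected.toSeq

  // sameElements also compares contents in order, without an explicit conversion.
  val viaSameElements: Boolean = rows.sameElements(expected)
}
```

So either `.collect().toSeq` on the expected side, or `rows.sameElements(expected)` inside the assertion, should make the check meaningful.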