GitHub user WeichenXu123 commented on a diff in the pull request:
https://github.com/apache/spark/pull/20121#discussion_r172080757
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
---
@@ -2579,10 +2519,13 @@ class LogisticRegressionSuite
val lr4 = new LogisticRegression()
.setInitialModel(model3).setMaxIter(5).setFamily("multinomial")
val model4 = lr4.fit(smallMultinomialDataset)
- val predictions3 = model3.transform(smallMultinomialDataset).select("prediction").collect()
- val predictions4 = model4.transform(smallMultinomialDataset).select("prediction").collect()
- predictions3.zip(predictions4).foreach { case (Row(p1: Double), Row(p2: Double)) =>
- assert(p1 === p2)
+ val multinomialExpected = model3.transform(smallMultinomialDataset).select("prediction")
+ .collect().map(_.getDouble(0))
+ for (model <- Seq(model3, model4)) {
--- End diff --
Same reason as above.
---