GitHub user zhengruifeng commented on the issue:
https://github.com/apache/spark/pull/19927
Test code:
```
// Micro-benchmark: wall-clock time of OneVsRestModel.transform as the number of classes grows.
// Written to be pasted into spark-shell, which supplies the `spark` session.
import org.apache.spark.ml.classification._
// Imported explicitly so `rand()` resolves even where spark-shell's automatic imports are absent.
import org.apache.spark.sql.functions.rand

val df = spark.read.format("libsvm").load(
  "/Users/zrf/Dev/OpenSource/spark/data/mllib/sample_multiclass_classification_data.txt")
val classifier = new LogisticRegression().setMaxIter(1).setTol(1E-6).setFitIntercept(true)
val ovr = new OneVsRest().setClassifier(classifier)

// Union the data with itself ten times (the row count doubles on every pass, 1024x overall),
// then spread the result over 64 partitions.
val df2 = (0 until 10).foldLeft(df) { (acc, _) => acc.union(acc) }.repartition(64)
df2.persist()
// Run an action up front so the persisted input is computed before any timing starts.
df2.count

Seq(3, 5, 7, 10, 20).foreach { k =>
  // Overwrite the label column with a random integer in [0, k), giving up to k classes.
  val dfk = df2.withColumn("label", (rand() * k).cast("int"))
  val ovrModelK = ovr.fit(dfk)
  // Only the ten transform + count passes are timed; fitting happens outside the window.
  val start = System.nanoTime
  (0 until 10).foreach { _ => ovrModelK.transform(dfk).count }
  val end = System.nanoTime
  val duration = (end - start) / 1e9 // nanoseconds -> seconds
  println(s"numClasses $k, duration $duration")
}
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]