GitHub user zhengruifeng commented on the issue:

    https://github.com/apache/spark/pull/19927
  
    Test code:
    
    ```
    import org.apache.spark.ml.classification._
    // `rand` is auto-imported by spark-shell only; import it explicitly so the
    // snippet also compiles outside the shell.
    import org.apache.spark.sql.functions.rand
    
    // NOTE(review): machine-specific path; point it at
    // data/mllib/sample_multiclass_classification_data.txt in your Spark checkout.
    val dataPath = "/Users/zrf/Dev/OpenSource/spark/data/mllib/sample_multiclass_classification_data.txt"
    val df = spark.read.format("libsvm").load(dataPath)
    
    // One LR iteration keeps fit() cheap; only transform() is timed below.
    val classifier = new LogisticRegression().setMaxIter(1).setTol(1E-6).setFitIntercept(true)
    val ovr = new OneVsRest().setClassifier(classifier)
    
    // Self-union 10 times (2^10 = 1024 copies of the input), spread the rows over
    // 64 partitions, then cache and materialize before any timing happens.
    val df2 = (0 until 10).foldLeft(df)((acc, _) => acc.union(acc)).repartition(64)
    df2.persist()
    df2.count()
    
    Seq(3, 5, 7, 10, 20).foreach { k =>
        // Random integer labels in [0, k); the fixed seed makes runs reproducible.
        val dfk = df2.withColumn("label", (rand(42L) * k).cast("int"))
        val ovrModelK = ovr.fit(dfk)
        // Warm-up pass so one-off costs of the first transform are not timed.
        ovrModelK.transform(dfk).count()
        val start = System.nanoTime
        (0 until 10).foreach { _ => ovrModelK.transform(dfk).count() }
        val end = System.nanoTime
        // Total wall-clock seconds for the 10 timed transform passes.
        val duration = (end - start) / 1e9
        println(s"numClasses $k, duration $duration")
    }
    df2.unpersist()
    
    ```


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to