zhengruifeng created SPARK-29754:
------------------------------------

             Summary: LoR/AFT/LiR/SVC use Summarizer instead of MultivariateOnlineSummarizer
                 Key: SPARK-29754
                 URL: https://issues.apache.org/jira/browse/SPARK-29754
             Project: Spark
          Issue Type: Improvement
          Components: ML
    Affects Versions: 3.0.0
            Reporter: zhengruifeng


Before iterating, LoR/AFT/LiR/SVC use MultivariateOnlineSummarizer to summarize 
the input dataset. However, MultivariateOnlineSummarizer computes many more 
statistics than these algorithms need.
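
To give a rough sense of scale, the sketch below estimates the size of the per-feature buffers held by a single summarizer instance for a 2,014,669-feature dataset such as the one used in this report. It rests on assumptions that are not stated in the report: that MultivariateOnlineSummarizer keeps eight dense per-feature arrays of 8-byte values, and that a summarizer restricted to a few metrics would need about three. The helper names are hypothetical, and the estimate ignores how many summarizer instances are alive at once during aggregation.

```scala
// Back-of-the-envelope size of a summarizer's per-feature buffers.
// ASSUMPTION: every tracked statistic is one dense array with one 8-byte slot per feature.
def bufferBytes(numFeatures: Long, numArrays: Int): Long =
  numFeatures * 8L * numArrays

def toMiB(bytes: Long): Double = bytes / (1024.0 * 1024.0)

// Feature count of the dataset used in this report.
val numFeatures = 2014669L

// ASSUMPTION: eight per-feature arrays (mean, M2n, M2, L1, weight sum, nnz, max, min).
val fullBytes = bufferBytes(numFeatures, 8)

// HYPOTHETICAL: three per-feature arrays when only a few metrics are requested.
val slimBytes = bufferBytes(numFeatures, 3)

println(f"all statistics:  ${toMiB(fullBytes)}%.1f MiB per summarizer instance")
println(f"few statistics:  ${toMiB(slimBytes)}%.1f MiB per summarizer instance")
```

Under these assumptions the buffers scale linearly with the number of tracked statistics, so dropping unused statistics reduces the footprint of every instance that is created or cloned during a merge.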

example:

bin/spark-shell --driver-memory=4G
{code:java}
import org.apache.spark.ml.feature._
import org.apache.spark.ml.regression._
import org.apache.spark.ml.classification._

scala> val df = spark.read.format("libsvm").load("/data1/Datasets/kdda/kdda.t")
19/11/05 13:47:02 WARN LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.
df: org.apache.spark.sql.DataFrame = [label: double, features: vector]
scala> df.persist()
res0: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, features: vector]
scala> df.count
res1: Long = 510302
scala> df.show(3)
+-----+--------------------+
|label|            features|
+-----+--------------------+
|  1.0|(2014669,[0,1,2,3...|
|  1.0|(2014669,[1,2,3,4...|
|  0.0|(2014669,[1,2,3,4...|
+-----+--------------------+

val lr = new LogisticRegression().setMaxIter(1)
val tic = System.currentTimeMillis
val model = lr.fit(df)
val toc = System.currentTimeMillis
toc - tic
{code}
The input dataset is available here: 
[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#kdd2010%20(algebra)]

#instance=510302, #features=2014669

 

The LogisticRegression fit above fails with an OutOfMemoryError:
{code:java}
Caused by: java.lang.OutOfMemoryError: Java heap space
        at java.lang.Object.clone(Native Method)
        at org.apache.spark.mllib.stat.MultivariateOnlineSummarizer.merge(MultivariateOnlineSummarizer.scala:174)
        at org.apache.spark.ml.classification.LogisticRegression.$anonfun$train$3(LogisticRegression.scala:511)
        at org.apache.spark.ml.classification.LogisticRegression$$Lambda$4111/1818679131.apply(Unknown Source)
        at org.apache.spark.rdd.PairRDDFunctions.$anonfun$foldByKey$3(PairRDDFunctions.scala:218)
        at org.apache.spark.rdd.PairRDDFunctions$$Lambda$4139/1537760275.apply(Unknown Source)
        at org.apache.spark.util.collection.ExternalSorter.$anonfun$insertAll$1(ExternalSorter.scala:190)
        at org.apache.spark.util.collection.ExternalSorter.$anonfun$insertAll$1$adapted(ExternalSorter.scala:189)
        at org.apache.spark.util.collection.ExternalSorter$$Lambda$4180/1672153085.apply(Unknown Source)
        at org.apache.spark.util.collection.AppendOnlyMap.changeValue(AppendOnlyMap.scala:144)
        at org.apache.spark.util.collection.SizeTrackingAppendOnlyMap.changeValue(SizeTrackingAppendOnlyMap.scala:32)
        at org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:195)
        at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:62)
        at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
        at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
        at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
        at org.apache.spark.scheduler.Task.run(Task.scala:127)
        at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:462)
        at org.apache.spark.executor.Executor$TaskRunner$$Lambda$2799/542333665.apply(Unknown Source)
        at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:465)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
        at java.lang.Thread.run(Thread.java:748)
 {code}
 

If we use {{ml.Summarizer}} instead, 3G of memory is enough to fit this 
LR model.
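
As an illustration of the alternative, the sketch below asks {{ml.stat.Summarizer}} for only a few named metrics through its public DataFrame API. It is not the change proposed for LoR/AFT/LiR/SVC internally. The tiny in-memory dataset, the local SparkSession and the choice of metrics ("mean", "variance", "count") are assumptions made for the example.

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Local session for the sketch; in spark-shell this returns the existing session.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("summarizer-sketch")
  .getOrCreate()

// HYPOTHETICAL toy data standing in for the libsvm dataset in this report.
val df = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(1.0, 0.0, 3.0)),
  (0.0, Vectors.dense(3.0, 0.0, 5.0))
)).toDF("label", "features")

// Request only the metrics that are needed; other statistics are not requested.
val row = df
  .select(Summarizer.metrics("mean", "variance", "count")
    .summary(col("features")).as("summary"))
  .select("summary.mean", "summary.variance", "summary.count")
  .first()

val mean = row.getAs[Vector](0)      // per-feature mean
val variance = row.getAs[Vector](1)  // per-feature sample variance
val count = row.getLong(2)           // number of rows summarized

println(s"mean=$mean variance=$variance count=$count")
```

The metric names passed to {{Summarizer.metrics}} select which statistics are returned, which is what makes it possible to avoid the statistics an algorithm does not use.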

 


