zhengruifeng edited a comment on pull request #29255:
URL: https://github.com/apache/spark/pull/29255#issuecomment-664216405
Test code:
```
// Benchmark of LogisticRegressionModel prediction paths: single-vector predict,
// predictRaw, predictProbability, and DataFrame transform + count.
// NOTE(review): `start`/`end` are re-declared with `val` several times below; this
// only works when pasted into a REPL such as spark-shell, not as one compiled file.
// NOTE(review): `col` is not imported here; presumably it comes from spark-shell's
// default imports (org.apache.spark.sql.functions._) — confirm.
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.classification._
import org.apache.spark.storage.StorageLevel
// Load the a9a dataset in LIBSVM format and rewrite each label as (label + 1) / 2.
// Presumably a9a labels are -1/+1, so this maps them to 0/1 — verify against the data.
val df =
spark.read.format("libsvm").load("/data1/Datasets/a9a/a9a").withColumn("label",
(col("label")+1)/2)
// Cache the input and run count to materialize it before any timing starts.
df.persist(StorageLevel.MEMORY_AND_DISK)
df.count
// Train the model being benchmarked (10 iterations); training itself is not timed.
val lr = new LogisticRegression().setMaxIter(10)
val model = lr.fit(df)
// Collect every feature vector to the driver for the single-instance loops below.
val vecs = df.select("features").rdd.map(row => row.getAs[Vector](0)).collect
// Scenario 1: single threshold of 0.2.
model.setThreshold(0.2)
// Each of the next three timings makes 1000 passes over all of `vecs`; the trailing
// `end - start` expression is the elapsed wall-clock time in milliseconds.
// NOTE(review): no separate warm-up pass is run before timing.
val start = System.currentTimeMillis; Seq.range(0, 1000).foreach{i =>
vecs.foreach{vec => model.predict(vec)}}; val end = System.currentTimeMillis;
end - start
val start = System.currentTimeMillis; Seq.range(0, 1000).foreach{i =>
vecs.foreach{vec => model.predictRaw(vec)}}; val end =
System.currentTimeMillis; end - start
val start = System.currentTimeMillis; Seq.range(0, 1000).foreach{i =>
vecs.foreach{vec => model.predictProbability(vec)}}; val end =
System.currentTimeMillis; end - start
// 100 runs of transform + count over the cached DataFrame (elapsed ms).
val start = System.currentTimeMillis; Seq.range(0, 100).foreach{i =>
model.transform(df).count}; val end = System.currentTimeMillis; end - start
// Scenario 2: thresholds set to Array(1, 10); the same four timings are repeated.
model.setThresholds(Array(1, 10))
val start = System.currentTimeMillis; Seq.range(0, 1000).foreach{i =>
vecs.foreach{vec => model.predict(vec)}}; val end = System.currentTimeMillis;
end - start
val start = System.currentTimeMillis; Seq.range(0, 1000).foreach{i =>
vecs.foreach{vec => model.predictRaw(vec)}}; val end =
System.currentTimeMillis; end - start
val start = System.currentTimeMillis; Seq.range(0, 1000).foreach{i =>
vecs.foreach{vec => model.predictProbability(vec)}}; val end =
System.currentTimeMillis; end - start
val start = System.currentTimeMillis; Seq.range(0, 100).foreach{i =>
model.transform(df).count}; val end = System.currentTimeMillis; end - start
```
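Results (wall-clock milliseconds; the `predict`, `predictRaw` and `predictProbability` columns each cover 1000 passes over all collected vectors, and the `transform(df).count` column covers 100 runs):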
With `model.setThreshold(0.2)`:

|Durations| model.predict | model.predictRaw | model.predictProbability | model.transform(df).count |
|------|----------|------------|----------|------------|
|This PR|3895|5740|7502|3791|
|Master|28139|5878|9121|4049|
With `model.setThresholds(Array(1, 10))`:

|Durations| model.predict | model.predictRaw | model.predictProbability | model.transform(df).count |
|------|----------|------------|----------|------------|
|This PR|3693|5886|7523|3687|
|Master|21052|5828|9086|3857|
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]