Repository: spark Updated Branches: refs/heads/master 3ce46e94f -> c76da36c2
[SPARK-5858][MLLIB] Remove unnecessary first() call in GLM `numFeatures` is only used by multinomial logistic regression. Calling `.first()` for every GLM causes performance regression, especially in Python. Author: Xiangrui Meng <m...@databricks.com> Closes #4647 from mengxr/SPARK-5858 and squashes the following commits: 036dc7f [Xiangrui Meng] remove unnecessary first() call 12c5548 [Xiangrui Meng] check numFeatures only once Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c76da36c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c76da36c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c76da36c Branch: refs/heads/master Commit: c76da36c2163276b5c34e59fbb139eeb34ed0faa Parents: 3ce46e9 Author: Xiangrui Meng <m...@databricks.com> Authored: Tue Feb 17 10:17:45 2015 -0800 Committer: Xiangrui Meng <m...@databricks.com> Committed: Tue Feb 17 10:17:45 2015 -0800 ---------------------------------------------------------------------- .../spark/mllib/classification/LogisticRegression.scala | 6 +++++- .../spark/mllib/regression/GeneralizedLinearAlgorithm.scala | 7 ++++--- 2 files changed, 9 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c76da36c/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 420d6e2..b787667 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -355,6 +355,10 @@ class LogisticRegressionWithLBFGS } override protected def createModel(weights: Vector, intercept: Double) = { - new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1) + if (numOfLinearPredictor == 1) { + new LogisticRegressionModel(weights, intercept) + } else { + new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1) + } } } http://git-wip-us.apache.org/repos/asf/spark/blob/c76da36c/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 2b71453..7c66e8c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -126,7 +126,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * The dimension of training features. */ - protected var numFeatures: Int = 0 + protected var numFeatures: Int = -1 /** * Set if the algorithm should use feature scaling to improve the convergence during optimization. @@ -163,7 +163,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * RDD of LabeledPoint entries. */ def run(input: RDD[LabeledPoint]): M = { - numFeatures = input.first().features.size + if (numFeatures < 0) { + numFeatures = input.map(_.features.size).first() + } /** * When `numOfLinearPredictor > 1`, the intercepts are encapsulated into weights, @@ -193,7 +195,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * of LabeledPoint entries starting from the initial weights provided. */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { - numFeatures = input.first().features.size if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org