Repository: spark Updated Branches: refs/heads/master 0d4a69527 -> 21cb59f1c
[SPARK-17835][ML][MLLIB] Optimize NaiveBayes mllib wrapper to eliminate extra pass on data ## What changes were proposed in this pull request? [SPARK-14077](https://issues.apache.org/jira/browse/SPARK-14077) copied the ```NaiveBayes``` implementation from mllib to ml and left mllib as a wrapper. However, there are some differences between how mllib and ml handle labels: * mllib allows input labels such as {-1, +1}, whereas ml assumes the input labels are in the range [0, numClasses). * mllib ```NaiveBayesModel``` exposes ```labels``` but ml does not, due to the assumption mentioned above. During the copy in [SPARK-14077](https://issues.apache.org/jira/browse/SPARK-14077), we used ```val labels = data.map(_.label).distinct().collect().sorted``` to first get the distinct labels, and then encoded the labels for training. This involves an extra Spark job compared with the original implementation. Since ```NaiveBayes``` only does a one-pass aggregation during training, adding another pass seems less efficient. We can get the labels in a single pass along with ```NaiveBayes``` training and send them to the MLlib side. ## How was this patch tested? Existing tests. Author: Yanbo Liang <yblia...@gmail.com> Closes #15402 from yanboliang/spark-17835. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/21cb59f1 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/21cb59f1 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/21cb59f1 Branch: refs/heads/master Commit: 21cb59f1cd137d96b2596f1abe691b544581cf59 Parents: 0d4a695 Author: Yanbo Liang <yblia...@gmail.com> Authored: Wed Oct 12 19:56:40 2016 -0700 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Wed Oct 12 19:56:40 2016 -0700 ---------------------------------------------------------------------- .../spark/ml/classification/NaiveBayes.scala | 46 ++++++++++++++++---- .../spark/mllib/classification/NaiveBayes.scala | 15 +++---- 2 files changed, 43 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/21cb59f1/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index e565a6f..994ed99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -110,16 +110,28 @@ class NaiveBayes @Since("1.5.0") ( @Since("2.1.0") def setWeightCol(value: String): this.type = set(weightCol, value) - override protected def train(dataset: Dataset[_]): NaiveBayesModel = { - val numClasses = getNumClasses(dataset) + /** + * ml assumes input labels in range [0, numClasses). But this implementation + * is also called by mllib NaiveBayes which allows other kinds of input labels + * such as {-1, +1}. Here we use this parameter to switch between different processing logic. + * It should be removed when we remove mllib NaiveBayes. 
+ */ + private[spark] var isML: Boolean = true - if (isDefined(thresholds)) { - require($(thresholds).length == numClasses, this.getClass.getSimpleName + - ".train() called with non-matching numClasses and thresholds.length." + - s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") - } + private[spark] def setIsML(isML: Boolean): this.type = { + this.isML = isML + this + } - val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { + if (isML) { + val numClasses = getNumClasses(dataset) + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + } val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { @@ -153,6 +165,7 @@ class NaiveBayes @Since("1.5.0") ( } } + val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) // Aggregates term frequencies per label. 
@@ -176,6 +189,7 @@ class NaiveBayes @Since("1.5.0") ( val numLabels = aggregated.length val numDocuments = aggregated.map(_._2._1).sum + val labelArray = new Array[Double](numLabels) val piArray = new Array[Double](numLabels) val thetaArray = new Array[Double](numLabels * numFeatures) @@ -183,6 +197,7 @@ class NaiveBayes @Since("1.5.0") ( val piLogDenom = math.log(numDocuments + numLabels * lambda) var i = 0 aggregated.foreach { case (label, (n, sumTermFreqs)) => + labelArray(i) = label piArray(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = $(modelType) match { case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) @@ -201,7 +216,7 @@ class NaiveBayes @Since("1.5.0") ( val pi = Vectors.dense(piArray) val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true) - new NaiveBayesModel(uid, pi, theta) + new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray) } @Since("1.5.0") @@ -240,6 +255,19 @@ class NaiveBayesModel private[ml] ( import NaiveBayes.{Bernoulli, Multinomial} /** + * mllib NaiveBayes is a wrapper of ml implementation currently. + * Input labels of mllib could be {-1, +1} and mllib NaiveBayesModel exposes labels, + * both of which are different from ml, so we should store the labels sequentially + * to be called by mllib. This should be removed when we remove mllib NaiveBayes. + */ + private[spark] var oldLabels: Array[Double] = null + + private[spark] def setOldLabels(labels: Array[Double]): this.type = { + this.oldLabels = labels + this + } + + /** * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra * application of this condition (in predict function). 
http://git-wip-us.apache.org/repos/asf/spark/blob/21cb59f1/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 32d6968..33561be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -364,15 +364,10 @@ class NaiveBayes private ( val nb = new NewNaiveBayes() .setModelType(modelType) .setSmoothing(lambda) + .setIsML(false) - val labels = data.map(_.label).distinct().collect().sorted - - // Input labels for [[org.apache.spark.ml.classification.NaiveBayes]] must be - // in range [0, numClasses). - val dataset = data.map { - case LabeledPoint(label, features) => - (labels.indexOf(label).toDouble, features.asML) - }.toDF("label", "features") + val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) } + .toDF("label", "features") val newModel = nb.fit(dataset) @@ -383,7 +378,9 @@ class NaiveBayes private ( theta(i)(j) = v } - new NaiveBayesModel(labels, pi, theta, modelType) + require(newModel.oldLabels != null, + "The underlying ML NaiveBayes training does not produce labels.") + new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org