Repository: spark
Updated Branches:
  refs/heads/master 0d4a69527 -> 21cb59f1c


[SPARK-17835][ML][MLLIB] Optimize NaiveBayes mllib wrapper to eliminate extra pass on data

## What changes were proposed in this pull request?
[SPARK-14077](https://issues.apache.org/jira/browse/SPARK-14077) copied the 
```NaiveBayes``` implementation from mllib to ml and left mllib as a wrapper. 
However, there are some differences between how mllib and ml handle labels:
* mllib allows input labels such as {-1, +1}, whereas ml assumes the input 
labels are in the range [0, numClasses); see the sketch below for how the 
wrapper previously bridged this gap.
* mllib ```NaiveBayesModel``` exposes ```labels```, but ml does not, due to 
the assumption mentioned above.
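
For illustration, this is (roughly) the re-encoding the mllib wrapper 
performed before this patch; it corresponds to the code removed in the diff 
below:

```scala
// Simplified from the pre-patch mllib wrapper (assumes spark.implicits._ is
// in scope for toDF). Collecting the distinct labels costs an extra Spark job.
val labels = data.map(_.label).distinct().collect().sorted
// Map each raw label to its index so ml sees labels in [0, numClasses),
// e.g. {-1.0, +1.0} becomes {0.0, 1.0}.
val dataset = data.map { case LabeledPoint(label, features) =>
  (labels.indexOf(label).toDouble, features.asML)
}.toDF("label", "features")
```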

During the copy in 
[SPARK-14077](https://issues.apache.org/jira/browse/SPARK-14077), we used
```val labels = data.map(_.label).distinct().collect().sorted```
to collect the distinct labels first, and then encoded the labels for 
training. This involves an extra Spark job compared with the original 
implementation. Since ```NaiveBayes``` training performs only a single 
aggregation pass over the data, adding another job is relatively inefficient. 
Instead, we can collect the labels in the same single pass as ```NaiveBayes``` 
training and return them to the mllib side.
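
As a minimal sketch of the idea (simplified: the real patch accumulates the 
full per-label term-frequency statistics in the same aggregation, not just 
counts, and works on a DataFrame), the distinct labels fall out of the 
training pass for free:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// One reduceByKey pass both aggregates per-label statistics and yields the
// distinct labels, so no separate distinct() job is needed.
def aggregateWithLabels(
    data: RDD[LabeledPoint]): (Array[(Double, Long)], Array[Double]) = {
  val aggregated = data
    .map(p => (p.label, 1L))  // (label, count); the patch also sums term frequencies
    .reduceByKey(_ + _)
    .collect()
    .sortBy(_._1)
  // The labels come out of the same single pass over the data.
  (aggregated, aggregated.map(_._1))
}
```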

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #15402 from yanboliang/spark-17835.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/21cb59f1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/21cb59f1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/21cb59f1

Branch: refs/heads/master
Commit: 21cb59f1cd137d96b2596f1abe691b544581cf59
Parents: 0d4a695
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Oct 12 19:56:40 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Wed Oct 12 19:56:40 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/NaiveBayes.scala    | 46 ++++++++++++++++----
 .../spark/mllib/classification/NaiveBayes.scala | 15 +++----
 2 files changed, 43 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/21cb59f1/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index e565a6f..994ed99 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -110,16 +110,28 @@ class NaiveBayes @Since("1.5.0") (
   @Since("2.1.0")
   def setWeightCol(value: String): this.type = set(weightCol, value)
 
-  override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
-    val numClasses = getNumClasses(dataset)
+  /**
+   * ml assumes input labels in range [0, numClasses). But this implementation
+   * is also called by mllib NaiveBayes which allows other kinds of input labels
+   * such as {-1, +1}. Here we use this parameter to switch between different processing logic.
+   * It should be removed when we remove mllib NaiveBayes.
+   */
+  private[spark] var isML: Boolean = true
 
-    if (isDefined(thresholds)) {
-      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
-        ".train() called with non-matching numClasses and thresholds.length." +
-        s" numClasses=$numClasses, but thresholds has length 
${$(thresholds).length}")
-    }
+  private[spark] def setIsML(isML: Boolean): this.type = {
+    this.isML = isML
+    this
+  }
 
-    val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
+  override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
+    if (isML) {
+      val numClasses = getNumClasses(dataset)
+      if (isDefined(thresholds)) {
+        require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+          ".train() called with non-matching numClasses and thresholds.length." +
+          s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+      }
+    }
 
     val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
       val values = v match {
@@ -153,6 +165,7 @@ class NaiveBayes @Since("1.5.0") (
       }
     }
 
+    val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
 
     // Aggregates term frequencies per label.
@@ -176,6 +189,7 @@ class NaiveBayes @Since("1.5.0") (
     val numLabels = aggregated.length
     val numDocuments = aggregated.map(_._2._1).sum
 
+    val labelArray = new Array[Double](numLabels)
     val piArray = new Array[Double](numLabels)
     val thetaArray = new Array[Double](numLabels * numFeatures)
 
@@ -183,6 +197,7 @@ class NaiveBayes @Since("1.5.0") (
     val piLogDenom = math.log(numDocuments + numLabels * lambda)
     var i = 0
     aggregated.foreach { case (label, (n, sumTermFreqs)) =>
+      labelArray(i) = label
       piArray(i) = math.log(n + lambda) - piLogDenom
       val thetaLogDenom = $(modelType) match {
         case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
@@ -201,7 +216,7 @@ class NaiveBayes @Since("1.5.0") (
 
     val pi = Vectors.dense(piArray)
     val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
-    new NaiveBayesModel(uid, pi, theta)
+    new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray)
   }
 
   @Since("1.5.0")
@@ -240,6 +255,19 @@ class NaiveBayesModel private[ml] (
   import NaiveBayes.{Bernoulli, Multinomial}
 
   /**
+   * mllib NaiveBayes is a wrapper of ml implementation currently.
+   * Input labels of mllib could be {-1, +1} and mllib NaiveBayesModel exposes labels,
+   * both of which are different from ml, so we should store the labels sequentially
+   * to be called by mllib. This should be removed when we remove mllib NaiveBayes.
+   */
+  private[spark] var oldLabels: Array[Double] = null
+
+  private[spark] def setOldLabels(labels: Array[Double]): this.type = {
+    this.oldLabels = labels
+    this
+  }
+
+  /**
    * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
    * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
    * application of this condition (in predict function).

http://git-wip-us.apache.org/repos/asf/spark/blob/21cb59f1/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 32d6968..33561be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -364,15 +364,10 @@ class NaiveBayes private (
     val nb = new NewNaiveBayes()
       .setModelType(modelType)
       .setSmoothing(lambda)
+      .setIsML(false)
 
-    val labels = data.map(_.label).distinct().collect().sorted
-
-    // Input labels for [[org.apache.spark.ml.classification.NaiveBayes]] must be
-    // in range [0, numClasses).
-    val dataset = data.map {
-      case LabeledPoint(label, features) =>
-        (labels.indexOf(label).toDouble, features.asML)
-    }.toDF("label", "features")
+    val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) }
+      .toDF("label", "features")
 
     val newModel = nb.fit(dataset)
 
@@ -383,7 +378,9 @@ class NaiveBayes private (
         theta(i)(j) = v
     }
 
-    new NaiveBayesModel(labels, pi, theta, modelType)
+    require(newModel.oldLabels != null,
+      "The underlying ML NaiveBayes training does not produce labels.")
+    new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType)
   }
 }
 

