srowen commented on a change in pull request #26413: [SPARK-16872][ML][PYSPARK]
Impl Gaussian Naive Bayes Classifier
URL: https://github.com/apache/spark/pull/26413#discussion_r344246733
##########
File path:
mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
##########
@@ -204,7 +224,80 @@ class NaiveBayes @Since("1.5.0") (
val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
- new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray)
+ new NaiveBayesModel(uid, pi.compressed, theta.compressed, null)
+ .setOldLabels(labelArray)
+ }
+
+ private def trainGaussianImpl(
+ dataset: Dataset[_],
+ instr: Instrumentation): NaiveBayesModel = {
+ val spark = dataset.sparkSession
+ import spark.implicits._
+
+ val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
+ col($(weightCol)).cast(DoubleType)
+ } else {
+ lit(1.0)
+ }
+
+ // Aggregates mean vector and square-sum vector per label.
+ // TODO: Summarizer directly returns square-sum vector.
+ val aggregated = dataset.groupBy(col($(labelCol)))
+ .agg(sum(w).as("weightSum"), Summarizer.metrics("mean", "normL2")
+ .summary(col($(featuresCol)), w).as("summary"))
+ .select($(labelCol), "weightSum", "summary.mean", "summary.normL2")
+ .as[(Double, Double, Vector, Vector)]
+ .map { case (label, weightSum, mean, normL2) =>
+ (label, weightSum, mean, Vectors.dense(normL2.toArray.map(v => v * v)))
+ }.collect().sortBy(_._1)
+
+ val numFeatures = aggregated.head._3.size
+ instr.logNumFeatures(numFeatures)
+
+ val numLabels = aggregated.length
+ instr.logNumClasses(numLabels)
+
+ val numInstances = aggregated.map(_._2).sum
+
+ // If the ratio of data variance between dimensions is too small, it
+ // will cause numerical errors. To address this, we artificially
+ // boost the variance by epsilon, a small fraction of the standard
+ // deviation of the largest dimension.
+ // Refer to scikit-learn's implementation
+ // [https://github.com/scikit-learn/scikit-learn/blob/0.21.X/sklearn/naive_bayes.py#L348]
+ // and discussion [https://github.com/scikit-learn/scikit-learn/pull/5349] for detail.
+ val epsilon = Iterator.range(0, numFeatures).map { j =>
+ val globalSum = aggregated.map(t => t._3(j) * t._2).sum
+ val globalSqrSum = aggregated.map(t => t._4(j)).sum
+ globalSqrSum / numInstances -
+ globalSum * globalSum / numInstances / numInstances
+ }.max * 1e-9
+
+ val piArray = new Array[Double](numLabels)
+
+ // thetaArray in Gaussian NB stores the means of features per label
+ val thetaArray = new Array[Double](numLabels * numFeatures)
+
+ // sigmaArray in Gaussian NB stores the variances of features per label
+ val sigmaArray = new Array[Double](numLabels * numFeatures)
+
+ var i = 0
+ aggregated.foreach { case (_, n, mean, squareSum) =>
+ piArray(i) = math.log(n / numInstances)
+ var j = 0
+ while (j < numFeatures) {
+ val m = mean(j)
+ thetaArray(i * numFeatures + j) = m
Review comment:
Also, I don't know if it matters, but `i * numFeatures` is loop-invariant and
could be hoisted out of the inner loop, or at least the offset could be
computed once per label.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]