srowen commented on a change in pull request #26413: [SPARK-16872][ML][PYSPARK]
Impl Gaussian Naive Bayes Classifier
URL: https://github.com/apache/spark/pull/26413#discussion_r344249771
##########
File path:
mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
##########
@@ -204,7 +224,80 @@ class NaiveBayes @Since("1.5.0") (
val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
- new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray)
+ new NaiveBayesModel(uid, pi.compressed, theta.compressed, null)
+ .setOldLabels(labelArray)
+ }
+
+ private def trainGaussianImpl(
+ dataset: Dataset[_],
+ instr: Instrumentation): NaiveBayesModel = {
+ val spark = dataset.sparkSession
+ import spark.implicits._
+
+ val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
+ col($(weightCol)).cast(DoubleType)
+ } else {
+ lit(1.0)
+ }
+
+ // Aggregates mean vector and square-sum vector per label.
+ // TODO: Summarizer directly returns square-sum vector.
+ val aggregated = dataset.groupBy(col($(labelCol)))
+ .agg(sum(w).as("weightSum"), Summarizer.metrics("mean", "normL2")
+ .summary(col($(featuresCol)), w).as("summary"))
+ .select($(labelCol), "weightSum", "summary.mean", "summary.normL2")
+ .as[(Double, Double, Vector, Vector)]
+ .map { case (label, weightSum, mean, normL2) =>
+ (label, weightSum, mean, Vectors.dense(normL2.toArray.map(v => v * v)))
+ }.collect().sortBy(_._1)
+
+ val numFeatures = aggregated.head._3.size
+ instr.logNumFeatures(numFeatures)
+
+ val numLabels = aggregated.length
+ instr.logNumClasses(numLabels)
+
+ val numInstances = aggregated.map(_._2).sum
+
+ // If the ratio of data variance between dimensions is too small, it
+ // will cause numerical errors. To address this, we artificially
+ // boost the variance by epsilon, a small fraction of the standard
+ // deviation of the largest dimension.
+ // Refer to scikit-learn's implementation
+ //
[https://github.com/scikit-learn/scikit-learn/blob/0.21.X/sklearn/naive_bayes.py#L348]
+ // and discussion [https://github.com/scikit-learn/scikit-learn/pull/5349]
for detail.
+ val epsilon = Iterator.range(0, numFeatures).map { j =>
+ val globalSum = aggregated.map(t => t._3(j) * t._2).sum
Review comment:
It doesn't matter much, but in a few places around here it might be clearer
to destructure the tuple with a pattern such as `case (_, x, y, _)` to give the
fields meaningful names in the expression, rather than using `._2`, `._3`, etc.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]