zhengruifeng commented on a change in pull request #26413:
[SPARK-16872][ML][PYSPARK] Impl Gaussian Naive Bayes Classifier
URL: https://github.com/apache/spark/pull/26413#discussion_r343489454
##########
File path:
mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
##########
@@ -138,44 +150,55 @@ class NaiveBayes @Since("1.5.0") (
s" numClasses=$numClasses, but thresholds has length
${$(thresholds).length}")
}
- val validateInstance = $(modelType) match {
- case Multinomial =>
- (instance: Instance) => requireNonnegativeValues(instance.features)
- case Bernoulli =>
- (instance: Instance) =>
requireZeroOneBernoulliValues(instance.features)
+ $(modelType) match {
+ case Bernoulli | Multinomial =>
+ trainDiscreteImpl(dataset, instr)
+ case Gaussian =>
+ trainGaussianImpl(dataset, instr)
case _ =>
// This should never happen.
throw new IllegalArgumentException(s"Invalid modelType:
${$(modelType)}.")
}
+ }
- instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol,
rawPredictionCol,
- probabilityCol, modelType, smoothing, thresholds)
+ private def trainDiscreteImpl(
+ dataset: Dataset[_],
+ instr: Instrumentation): NaiveBayesModel = {
+ val spark = dataset.sparkSession
+ import spark.implicits._
- val numFeatures =
dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
- instr.logNumFeatures(numFeatures)
+ val validateUDF = $(modelType) match {
+ case Multinomial =>
+ udf { vector: Vector => requireNonnegativeValues(vector); vector }
+ case Bernoulli =>
+ udf { vector: Vector => requireZeroOneBernoulliValues(vector); vector }
+ }
+
+ val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
+ col($(weightCol)).cast(DoubleType)
+ } else {
+ lit(1.0)
+ }
// Aggregates term frequencies per label.
- // TODO: Calling aggregateByKey and collect creates two stages, we can
implement something
- // TODO: similar to reduceByKeyLocally to save one stage.
- val aggregated = extractInstances(dataset, validateInstance).map {
instance =>
- (instance.label, (instance.weight, instance.features))
- }.aggregateByKey[(Double, DenseVector, Long)]((0.0,
Vectors.zeros(numFeatures).toDense, 0L))(
- seqOp = {
- case ((weightSum, featureSum, count), (weight, features)) =>
- BLAS.axpy(weight, features, featureSum)
- (weightSum + weight, featureSum, count + 1)
- },
- combOp = {
- case ((weightSum1, featureSum1, count1), (weightSum2, featureSum2,
count2)) =>
- BLAS.axpy(1.0, featureSum2, featureSum1)
- (weightSum1 + weightSum2, featureSum1, count1 + count2)
- }).collect().sortBy(_._1)
-
- val numSamples = aggregated.map(_._2._3).sum
+ // TODO: Summarizer directly returns sum vector.
Review comment:
Here I use Dataset ops instead of the previous RDD-based implementation, and found that the speeds are almost the same:
```scala
import org.apache.spark.ml.feature._
import org.apache.spark.ml.regression._
import org.apache.spark.ml.classification._

var df = spark.read.format("libsvm").load("/data1/Datasets/a9a/a9a")
df.persist()
df.count

// double the dataset 8 times (256x the original rows)
(0 until 8).foreach(_ => df = df.union(df))
df.count

val nb = new NaiveBayes()

// run 50 fits, timing each one in milliseconds
val durations = (0 until 50).map { i =>
  val tic = System.currentTimeMillis
  val model = nb.fit(df)
  val toc = System.currentTimeMillis
  toc - tic
}

// average duration over the last 30 runs
durations.takeRight(30).sum.toDouble / 30
```
Previous RDD-based impl: 25535.0 ms
This PR (Dataset-based): 25261.7 ms

So I use the Dataset-based one for consistency and simplicity.
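For reference, the per-label aggregation with Dataset ops can be sketched roughly as below. This is only an illustration of the approach (group by label, aggregate with `Summarizer`), not the exact code in this PR; the `label`/`features` column names and the unit weight are assumptions:
```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.functions._

val spark = df.sparkSession
import spark.implicits._

// Hypothetical sketch: per-label (weightSum, count, weighted mean) in a single Dataset
// aggregation instead of RDD aggregateByKey; the weighted feature sum per label can then
// be recovered as mean * weightSum on the driver.
val aggregated = df
  .select(col("label"), lit(1.0).as("weight"), col("features"))
  .groupBy("label")
  .agg(
    sum(col("weight")).as("weightSum"),
    count(lit(1L)).as("count"),
    Summarizer.metrics("mean").summary(col("features"), col("weight")).as("summary"))
  .select("label", "weightSum", "count", "summary.mean")
  .as[(Double, Double, Long, Vector)]
  .collect()
  .sortBy(_._1)
```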