Github user srowen commented on a diff in the pull request:
https://github.com/apache/spark/pull/21561#discussion_r209256004
--- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala ---
@@ -157,11 +157,15 @@ class NaiveBayes @Since("1.5.0") (
instr.logNumFeatures(numFeatures)
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
else col($(weightCol))
+ val countAccum = dataset.sparkSession.sparkContext.longAccumulator
+
// Aggregates term frequencies per label.
// TODO: Calling aggregateByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
- .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
+ .map { row =>
+ countAccum.add(1L)
--- End diff ---
Is this guaranteed to work correctly, given that it runs inside a map
transformation? Wondering whether this introduces a correctness issue, or
whether this number is available elsewhere.
---
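For context on the concern: Spark guarantees exactly-once accumulator updates only for accumulators used inside actions; inside a transformation such as this `map`, a retried or recomputed task may apply its updates more than once, inflating the count. A plain-Scala sketch (no Spark; the sample data is hypothetical) of the alternative the reviewer hints at: carry the count through the aggregation itself, so it is derived from the aggregated result rather than from a side-effecting counter.

```scala
// Hypothetical stand-in for the RDD rows: (label, weight) pairs.
val rows = Seq((0.0, 1.0), (0.0, 1.0), (1.0, 1.0))

// Instead of bumping an accumulator inside map (which can over-count if
// tasks are re-executed), fold a per-label row count into the aggregated
// value. groupBy stands in for aggregateByKey here.
val aggregated: Map[Double, (Long, Double)] = rows
  .groupBy(_._1)
  .map { case (label, xs) => (label, (xs.size.toLong, xs.map(_._2).sum)) }

// The total row count is recovered from the aggregation, no accumulator needed.
val numRows: Long = aggregated.values.map(_._1).sum
```

In the real Spark code the same idea would mean extending the `aggregateByKey` value to carry a count, then summing the per-label counts after `collect`, which gives a deterministic number regardless of task retries.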
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]