Github user ymazari commented on a diff in the pull request:
https://github.com/apache/spark/pull/20367#discussion_r164275721
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala ---
@@ -155,24 +182,47 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String)
transformSchema(dataset.schema, logging = true)
val vocSize = $(vocabSize)
val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
+ val filteringRequired = isSet(minDF) || isSet(maxDF)
+ val maybeInputSize = if (filteringRequired) {
+ Some(input.cache().count())
+ } else {
+ None
+ }
val minDf = if ($(minDF) >= 1.0) {
$(minDF)
} else {
- $(minDF) * input.cache().count()
+ $(minDF) * maybeInputSize.getOrElse(1L)
}
- val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) =>
+ val maxDf = if ($(maxDF) >= 1.0) {
+ $(maxDF)
+ } else {
+ $(maxDF) * maybeInputSize.getOrElse(1L)
+ }
+ require(maxDf >= minDf, "maxDF must be >= minDF.")
+ val allWordCounts = input.flatMap { case (tokens) =>
val wc = new OpenHashMap[String, Long]
tokens.foreach { w =>
wc.changeValue(w, 1L, _ + 1L)
}
wc.map { case (word, count) => (word, (count, 1)) }
}.reduceByKey { case ((wc1, df1), (wc2, df2)) =>
(wc1 + wc2, df1 + df2)
- }.filter { case (word, (wc, df)) =>
- df >= minDf
- }.map { case (word, (count, dfCount)) =>
- (word, count)
- }.cache()
+ }
+
+ val maybeFilteredWordCounts = if (filteringRequired) {
+ allWordCounts.filter { case (word, (wc, df)) => (df >= minDf) && (df <= maxDf) }
--- End diff --
Changed.
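
For context on the semantics being implemented here, below is a minimal usage sketch of the resulting public API. It is illustrative only: the toy data, the chosen parameter values, and the printed vocabulary are assumptions rather than anything taken from the PR (setMaxDF is the parameter this change introduces; setMinDF already exists).

import org.apache.spark.ml.feature.CountVectorizer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("CountVectorizerMinMaxDF")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

// Three documents; "the" appears in all of them.
val docs = Seq(
  (0, Seq("the", "cat", "sat")),
  (1, Seq("the", "dog", "ran")),
  (2, Seq("the", "cat", "ran"))
).toDF("id", "words")

val cv = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .setMinDF(1.0)  // >= 1.0 is an absolute count: keep terms appearing in at least 1 document
  .setMaxDF(0.9)  // < 1.0 is a fraction of the corpus: drop terms in more than 90% of documents

val model = cv.fit(docs)
// "the" has document frequency 3 > 0.9 * 3 = 2.7, so it is filtered out of the vocabulary.
println(model.vocabulary.mkString(", "))

As in the diff above, a fractional minDF/maxDF is scaled by the corpus size (hence the input.cache().count() guarded by filteringRequired), while values >= 1.0 are treated as absolute document counts.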
---