GitHub user ymazari commented on a diff in the pull request:
https://github.com/apache/spark/pull/20367#discussion_r164275712
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala ---
@@ -155,24 +182,47 @@ class CountVectorizer @Since("1.5.0")
(@Since("1.5.0") override val uid: String)
transformSchema(dataset.schema, logging = true)
val vocSize = $(vocabSize)
val input =
dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
+ val filteringRequired = isSet(minDF) || isSet(maxDF)
+ val maybeInputSize = if (filteringRequired) {
+ Some(input.cache().count())
+ } else {
+ None
+ }
val minDf = if ($(minDF) >= 1.0) {
$(minDF)
} else {
- $(minDF) * input.cache().count()
+ $(minDF) * maybeInputSize.getOrElse(1L)
}
- val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) =>
+ val maxDf = if ($(maxDF) >= 1.0) {
+ $(maxDF)
+ } else {
+ $(maxDF) * maybeInputSize.getOrElse(1L)
--- End diff --
Changed.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org