Repository: spark
Updated Branches:
  refs/heads/master b8c32dc57 -> c40fda9e4


[SPARK-23166][ML] Add maxDF Parameter to CountVectorizer

## What changes were proposed in this pull request?
Currently, the CountVectorizer has a minDF parameter.

This change adds a maxDF parameter, used as an upper threshold for filtering
out terms that occur very frequently in the corpus; such terms carry little
information and may even be stop words.

This is analogous to the max_df parameter of scikit-learn's CountVectorizer.
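
For illustration, here is a minimal usage sketch of the new parameter (the session setup, column names, and toy data are hypothetical, not part of this patch):

```scala
import org.apache.spark.ml.feature.CountVectorizer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("MaxDFExample").master("local[2]").getOrCreate()
import spark.implicits._

// "a" appears in all 3 documents, "b" in 2, "c" in 1.
val docs = Seq(
  (0, Seq("a", "b", "c")),
  (1, Seq("a", "b")),
  (2, Seq("a"))
).toDF("id", "words")

// setMaxDF(2): drop terms appearing in more than 2 documents, so "a" is excluded.
// A value in [0, 1), e.g. 0.75, would instead be read as a fraction of documents.
val model = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .setMaxDF(2)
  .fit(docs)

println(model.vocabulary.mkString(", "))  // b, c
```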

Other changes (a simplified sketch of this flow follows below):
- Refactored the code so that `filter()` is invoked only when minDF or maxDF is set.
- Refactored the code to unpersist the input once counting is done.
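
A self-contained sketch of that counting/filtering flow, assuming DF bounds are passed as plain arguments (this is a simplification, not the actual Spark internals):

```scala
import org.apache.spark.rdd.RDD

// Fractional thresholds require counting the corpus once; filter() runs only
// when a DF bound was explicitly set; the cached input is unpersisted as soon
// as counting is done.
def countAndFilter(
    docs: RDD[Seq[String]],
    minDF: Double,
    maxDF: Double,
    dfBoundsSet: Boolean): RDD[(String, Long)] = {
  val countingRequired = minDF < 1.0 || maxDF < 1.0
  val maybeInputSize = if (countingRequired) Some(docs.cache().count()) else None
  val minDf = if (minDF >= 1.0) minDF else minDF * maybeInputSize.get
  val maxDf = if (maxDF >= 1.0) maxDF else maxDF * maybeInputSize.get
  require(maxDf >= minDf, "maxDF must be >= minDF.")

  // (term, (totalTermCount, documentFrequency))
  val allWordCounts = docs.flatMap { tokens =>
    tokens.groupBy(identity).map { case (w, ws) => (w, (ws.size.toLong, 1)) }
  }.reduceByKey { case ((wc1, df1), (wc2, df2)) => (wc1 + wc2, df1 + df2) }

  val maybeFiltered = if (dfBoundsSet) {
    allWordCounts.filter { case (_, (_, df)) => df >= minDf && df <= maxDf }
  } else {
    allWordCounts
  }

  val wordCounts = maybeFiltered.map { case (w, (count, _)) => (w, count) }.cache()
  if (countingRequired) docs.unpersist()
  wordCounts
}
```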

## How was this patch tested?
Unit tests.

Author: Yacine Mazari <y.maz...@gmail.com>

Closes #20367 from ymazari/SPARK-23166.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c40fda9e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c40fda9e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c40fda9e

Branch: refs/heads/master
Commit: c40fda9e4cf32d6cd17af2ace959bbbbe7c782a4
Parents: b8c32dc
Author: Yacine Mazari <y.maz...@gmail.com>
Authored: Sun Jan 28 10:27:59 2018 -0600
Committer: Sean Owen <so...@cloudera.com>
Committed: Sun Jan 28 10:27:59 2018 -0600

----------------------------------------------------------------------
 .../spark/ml/feature/CountVectorizer.scala      | 67 +++++++++++++++---
 .../spark/ml/feature/CountVectorizerSuite.scala | 72 ++++++++++++++++++++
 2 files changed, 131 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c40fda9e/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 1ebe297..60a4f91 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -69,6 +69,25 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
   /** @group getParam */
   def getMinDF: Double = $(minDF)
 
+  /**
+   * Specifies the maximum number of different documents a term must appear in to be included
+   * in the vocabulary.
+   * If this is an integer greater than or equal to 1, this specifies the number of documents
+   * the term must appear in; if this is a double in [0,1), then this specifies the fraction of
+   * documents.
+   *
+   * Default: (2^63^) - 1
+   * @group param
+   */
+  val maxDF: DoubleParam = new DoubleParam(this, "maxDF", "Specifies the maximum number of" +
+    " different documents a term must appear in to be included in the vocabulary." +
+    " If this is an integer >= 1, this specifies the number of documents the term must" +
+    " appear in; if this is a double in [0,1), then this specifies the fraction of documents.",
+    ParamValidators.gtEq(0.0))
+
+  /** @group getParam */
+  def getMaxDF: Double = $(maxDF)
+
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
     val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
@@ -113,7 +132,11 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
   /** @group getParam */
   def getBinary: Boolean = $(binary)
 
-  setDefault(vocabSize -> (1 << 18), minDF -> 1.0, minTF -> 1.0, binary -> false)
+  setDefault(vocabSize -> (1 << 18),
+    minDF -> 1.0,
+    maxDF -> Long.MaxValue,
+    minTF -> 1.0,
+    binary -> false)
 }
 
 /**
@@ -143,6 +166,10 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   def setMinDF(value: Double): this.type = set(minDF, value)
 
   /** @group setParam */
+  @Since("2.4.0")
+  def setMaxDF(value: Double): this.type = set(maxDF, value)
+
+  /** @group setParam */
   @Since("1.5.0")
   def setMinTF(value: Double): this.type = set(minTF, value)
 
@@ -155,12 +182,24 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String)
     transformSchema(dataset.schema, logging = true)
     val vocSize = $(vocabSize)
     val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
+    val countingRequired = $(minDF) < 1.0 || $(maxDF) < 1.0
+    val maybeInputSize = if (countingRequired) {
+      Some(input.cache().count())
+    } else {
+      None
+    }
     val minDf = if ($(minDF) >= 1.0) {
       $(minDF)
     } else {
-      $(minDF) * input.cache().count()
+      $(minDF) * maybeInputSize.get
     }
-    val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) =>
+    val maxDf = if ($(maxDF) >= 1.0) {
+      $(maxDF)
+    } else {
+      $(maxDF) * maybeInputSize.get
+    }
+    require(maxDf >= minDf, "maxDF must be >= minDF.")
+    val allWordCounts = input.flatMap { case (tokens) =>
       val wc = new OpenHashMap[String, Long]
       tokens.foreach { w =>
         wc.changeValue(w, 1L, _ + 1L)
@@ -168,11 +207,23 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String)
       wc.map { case (word, count) => (word, (count, 1)) }
     }.reduceByKey { case ((wc1, df1), (wc2, df2)) =>
       (wc1 + wc2, df1 + df2)
-    }.filter { case (word, (wc, df)) =>
-      df >= minDf
-    }.map { case (word, (count, dfCount)) =>
-      (word, count)
-    }.cache()
+    }
+
+    val filteringRequired = isSet(minDF) || isSet(maxDF)
+    val maybeFilteredWordCounts = if (filteringRequired) {
+      allWordCounts.filter { case (_, (_, df)) => df >= minDf && df <= maxDf }
+    } else {
+      allWordCounts
+    }
+
+    val wordCounts = maybeFilteredWordCounts
+      .map { case (word, (count, _)) => (word, count) }
+      .cache()
+
+    if (countingRequired) {
+      input.unpersist()
+    }
+
     val fullVocabSize = wordCounts.count()
 
     val vocab = wordCounts

http://git-wip-us.apache.org/repos/asf/spark/blob/c40fda9e/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index f213145..1784c07 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -119,6 +119,78 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
     }
   }
 
+  test("CountVectorizer maxDF") {
+    val df = Seq(
+      (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
+      (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
+      (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0)))),
+      (3, split("a"), Vectors.sparse(3, Seq()))
+    ).toDF("id", "words", "expected")
+
+    // maxDF: ignore terms with count more than 3
+    val cvModel = new CountVectorizer()
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setMaxDF(3)
+      .fit(df)
+    assert(cvModel.vocabulary === Array("b", "c", "d"))
+
+    cvModel.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
+
+    // maxDF: ignore terms with freq > 0.75
+    val cvModel2 = new CountVectorizer()
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setMaxDF(0.75)
+      .fit(df)
+    assert(cvModel2.vocabulary === Array("b", "c", "d"))
+
+    cvModel2.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
+  }
+
+  test("CountVectorizer using both minDF and maxDF") {
+    // Ignore terms with count more than 3 AND less than 2
+    val df = Seq(
+      (0, split("a b c d"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))),
+      (1, split("a b c"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))),
+      (2, split("a b"), Vectors.sparse(2, Seq((0, 1.0)))),
+      (3, split("a"), Vectors.sparse(2, Seq()))
+    ).toDF("id", "words", "expected")
+
+    val cvModel = new CountVectorizer()
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setMinDF(2)
+      .setMaxDF(3)
+      .fit(df)
+    assert(cvModel.vocabulary === Array("b", "c"))
+
+    cvModel.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
+
+    // Ignore terms with frequency higher than 0.75 AND less than 0.5
+    val cvModel2 = new CountVectorizer()
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setMinDF(0.5)
+      .setMaxDF(0.75)
+      .fit(df)
+    assert(cvModel2.vocabulary === Array("b", "c"))
+
+    cvModel2.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
+  }
+
   test("CountVectorizer throws exception when vocab is empty") {
     intercept[IllegalArgumentException] {
       val df = Seq(

