Github user ymazari commented on a diff in the pull request:
https://github.com/apache/spark/pull/20367#discussion_r163465302
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala ---
@@ -119,6 +119,41 @@ class CountVectorizerSuite extends SparkFunSuite with
MLlibTestSparkContext
}
}
+ test("CountVectorizer maxDF") {
+ val df = Seq(
+ (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0), (2,
1.0)))),
+ (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
+ (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0)))),
+ (3, split("a"), Vectors.sparse(3, Seq()))
+ ).toDF("id", "words", "expected")
+
+ // maxDF: ignore terms that appear in more than 3 documents
+ val cvModel = new CountVectorizer()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setMaxDF(3)
+ .fit(df)
+ assert(cvModel.vocabulary === Array("b", "c", "d"))
+
+ cvModel.transform(df).select("features", "expected").collect().foreach
{
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14)
+ }
+
+ // maxDF: ignore terms that appear in more than 75% of the documents
+ val cvModel2 = new CountVectorizer()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setMaxDF(0.75)
+ .fit(df)
+ assert(cvModel2.vocabulary === Array("b", "c", "d"))
+
+ cvModel2.transform(df).select("features",
"expected").collect().foreach {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14)
--- End diff --
Done.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]