huaxingao commented on a change in pull request #29482:
URL: https://github.com/apache/spark/pull/29482#discussion_r474144300
##########
File path: mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
##########
@@ -305,4 +290,63 @@ class CountVectorizerSuite extends MLTest with DefaultReadWriteTest {
.setOutputCol("features")
interaction.transform(df1)
}
+
+ test("SPARK-32662: Test on empty dataset") {
+ // Fit on a dataset with the expected (id, words) schema but zero rows.
+ val df = Seq[(Int, Array[String])]().toDF("id", "words")
+ val cvModel = new CountVectorizer()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .fit(df)
+ // Fitting on no documents must succeed and yield an empty vocabulary.
+ assert(cvModel.vocabulary.isEmpty)
+ // Check the transform through the suite's shared MLTest helper, matching the
+ // pattern used by the other tests here: an empty input must produce no rows.
+ testTransformerByGlobalCheckFunc[(Int, Array[String])](df, cvModel, "features") { rows =>
+ assert(rows.isEmpty)
+ }
+ }
+
+ test("SPARK-32662: Remove requirement for minimum vocabulary size") {
+ val df = Seq(
+ (0, Array[String]()),
+ (1, Array[String]())
+ ).toDF("id", "words")
+ val cvModel = new CountVectorizer()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .fit(df)
+ assert(cvModel.vocabulary.isEmpty === true)
+ cvModel.transform(df).select("features").collect().foreach {
+ case Row(features: Vector) =>
+ assert(features === Vectors.sparse(0, Seq()))
+ }
+
+ val df2 = Seq(
+ (0, Array("a", "b", "c")),
+ (1, Array("d", "e")),
+ (2, Array[String]())
+ ).toDF("id", "words")
+ val cvModel2 = new CountVectorizer()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setMinDF(2)
+ .fit(df2)
+ assert(cvModel2.vocabulary.isEmpty === true)
+ cvModel2.transform(df2).select("features").collect().foreach {
+ case Row(features: Vector) =>
+ assert(features === Vectors.sparse(0, Seq()))
+ }
+
+ val df3 = Seq(
+ (0, Array("a")),
+ (1, Array("a")),
+ (2, Array("a"))
+ ).toDF("id", "words")
+ val cvModel3 = new CountVectorizer()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setMaxDF(2)
+ .fit(df3)
+ assert(cvModel3.vocabulary.isEmpty === true)
+ cvModel3.transform(df3).select("features").collect().foreach {
+ case Row(features: Vector) =>
+ assert(features === Vectors.sparse(0, Seq()))
+ }
Review comment:
nit: could you please call `testTransformer` so your new tests have the
same pattern as other tests? Thanks!
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]