Github user holdenk commented on a diff in the pull request:
https://github.com/apache/spark/pull/16770#discussion_r175179845
--- Diff: python/pyspark/ml/tests.py ---
@@ -640,6 +640,34 @@ def test_count_vectorizer_with_binary(self):
feature, expected = r
self.assertEqual(feature, expected)
+ def test_count_vectorizer_from_vocab(self):
+ model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words",
+ outputCol="features", minTF=2)
+ self.assertEqual(model.vocabulary, ["a", "b", "c"])
+ self.assertEqual(model.getMinTF(), 2)
+
+ dataset = self.spark.createDataFrame([
+ (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),),
+ (1, "a a".split(' '), SparseVector(3, {0: 2.0}),),
+ (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"])
+
+ transformed_list = model.transform(dataset).select("features", "expected").collect()
+
+ for r in transformed_list:
+ feature, expected = r
+ self.assertEqual(feature, expected)
+
+ # Test an empty vocabulary
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"):
+ CountVectorizerModel.from_vocabulary([], inputCol="words")
+
+ # Test model with default settings can transform
+ model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words")
+ transformed_list = model_default.transform(dataset)\
+ .select(model_default.getOrDefault(model_default.outputCol)).collect()
+ self.assertEqual(len(transformed_list), 3)
--- End diff --
sgtm
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]