Repository: spark
Updated Branches:
  refs/heads/master bd201bf61 -> 8a72734f3


[SPARK-15009][PYTHON][ML] Construct a CountVectorizerModel from a vocabulary list

## What changes were proposed in this pull request?

Added a class method `from_vocabulary` to construct a CountVectorizerModel from a
list of vocabulary strings, equivalent to the Scala version. Introduced a common
params base class `_CountVectorizerParams` so that the Python model also owns the
parameters, matching the Scala class hierarchy.
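
Usage sketch (not part of the patch text itself; it assumes an active SparkSession
bound to `spark`, and the column names are illustrative):

```python
from pyspark.ml.feature import CountVectorizerModel

# Build the model directly from a known vocabulary; no fitting step is needed,
# though an active SparkContext is required (created here via the SparkSession).
model = CountVectorizerModel.from_vocabulary(["a", "b", "c"],
                                             inputCol="words", outputCol="features")

df = spark.createDataFrame([(0, "a b c a".split(" "))], ["id", "words"])
model.transform(df).show(truncate=False)

# With the shared _CountVectorizerParams base class the model owns the params
# too, so the getters work directly on it:
print(model.getMinTF())  # 1.0 by default
```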

## How was this patch tested?

Added a CountVectorizer doctest that runs a transform on a model constructed from
a vocabulary, and a unit test to verify that the params and vocabulary are set
correctly.
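
In outline, the unit test asserts along these lines (condensed from the
`python/pyspark/ml/tests.py` change below):

```python
model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words",
                                             outputCol="features", minTF=2)
assert model.vocabulary == ["a", "b", "c"]
assert model.getMinTF() == 2
```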

Author: Bryan Cutler <cutl...@gmail.com>

Closes #16770 from BryanCutler/pyspark-CountVectorizerModel-vocab_ctor-SPARK-15009.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8a72734f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8a72734f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8a72734f

Branch: refs/heads/master
Commit: 8a72734f33f6a0abbd3207b0d661633c8b25d9ad
Parents: bd201bf
Author: Bryan Cutler <cutl...@gmail.com>
Authored: Fri Mar 16 11:42:57 2018 -0700
Committer: Holden Karau <hol...@pigscanfly.ca>
Committed: Fri Mar 16 11:42:57 2018 -0700

----------------------------------------------------------------------
 python/pyspark/ml/feature.py | 168 +++++++++++++++++++++++++-------------
 python/pyspark/ml/tests.py   |  32 +++++++-
 2 files changed, 142 insertions(+), 58 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8a72734f/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index f2e357f..a1ceb7f 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -19,12 +19,12 @@ import sys
 if sys.version > '3':
     basestring = str
 
-from pyspark import since, keyword_only
+from pyspark import since, keyword_only, SparkContext
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.ml.linalg import _convert_to_vector
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import JavaMLReadable, JavaMLWritable
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm
 from pyspark.ml.common import inherit_doc
 
 __all__ = ['Binarizer',
@@ -403,8 +403,69 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
         return self.getOrDefault(self.splits)
 
 
+class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
+    """
+    Params for :py:attr:`CountVectorizer` and :py:attr:`CountVectorizerModel`.
+    """
+
+    minTF = Param(
+        Params._dummy(), "minTF", "Filter to ignore rare words in" +
+        " a document. For each document, terms with frequency/count less than the given" +
+        " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
+        " times the term must appear in the document); if this is a double in [0,1), then this " +
+        "specifies a fraction (out of the document's token count). Note that the parameter is " +
+        "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
+        typeConverter=TypeConverters.toFloat)
+    minDF = Param(
+        Params._dummy(), "minDF", "Specifies the minimum number of" +
+        " different documents a term must appear in to be included in the vocabulary." +
+        " If this is an integer >= 1, this specifies the number of documents the term must" +
+        " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
+        " Default 1.0", typeConverter=TypeConverters.toFloat)
+    vocabSize = Param(
+        Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
+        typeConverter=TypeConverters.toInt)
+    binary = Param(
+        Params._dummy(), "binary", "Binary toggle to control the output vector values." +
+        " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
+        " for discrete probabilistic models that model binary events rather than integer counts." +
+        " Default False", typeConverter=TypeConverters.toBoolean)
+
+    def __init__(self, *args):
+        super(_CountVectorizerParams, self).__init__(*args)
+        self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
+
+    @since("1.6.0")
+    def getMinTF(self):
+        """
+        Gets the value of minTF or its default value.
+        """
+        return self.getOrDefault(self.minTF)
+
+    @since("1.6.0")
+    def getMinDF(self):
+        """
+        Gets the value of minDF or its default value.
+        """
+        return self.getOrDefault(self.minDF)
+
+    @since("1.6.0")
+    def getVocabSize(self):
+        """
+        Gets the value of vocabSize or its default value.
+        """
+        return self.getOrDefault(self.vocabSize)
+
+    @since("2.0.0")
+    def getBinary(self):
+        """
+        Gets the value of binary or its default value.
+        """
+        return self.getOrDefault(self.binary)
+
+
 @inherit_doc
-class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable):
     """
     Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`.
 
@@ -437,33 +498,20 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
     >>> loadedModel = CountVectorizerModel.load(modelPath)
     >>> loadedModel.vocabulary == model.vocabulary
     True
+    >>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"],
+    ...     inputCol="raw", outputCol="vectors")
+    >>> fromVocabModel.transform(df).show(truncate=False)
+    +-----+---------------+-------------------------+
+    |label|raw            |vectors                  |
+    +-----+---------------+-------------------------+
+    |0    |[a, b, c]      |(3,[0,1,2],[1.0,1.0,1.0])|
+    |1    |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
+    +-----+---------------+-------------------------+
+    ...
 
     .. versionadded:: 1.6.0
     """
 
-    minTF = Param(
-        Params._dummy(), "minTF", "Filter to ignore rare words in" +
-        " a document. For each document, terms with frequency/count less than the given" +
-        " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
-        " times the term must appear in the document); if this is a double in [0,1), then this " +
-        "specifies a fraction (out of the document's token count). Note that the parameter is " +
-        "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
-        typeConverter=TypeConverters.toFloat)
-    minDF = Param(
-        Params._dummy(), "minDF", "Specifies the minimum number of" +
-        " different documents a term must appear in to be included in the vocabulary." +
-        " If this is an integer >= 1, this specifies the number of documents the term must" +
-        " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
-        " Default 1.0", typeConverter=TypeConverters.toFloat)
-    vocabSize = Param(
-        Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
-        typeConverter=TypeConverters.toInt)
-    binary = Param(
-        Params._dummy(), "binary", "Binary toggle to control the output vector values." +
-        " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
-        " for discrete probabilistic models that model binary events rather than integer counts." +
-        " Default False", typeConverter=TypeConverters.toBoolean)
-
     @keyword_only
     def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,
                  outputCol=None):
@@ -474,7 +522,6 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
         super(CountVectorizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer",
                                             self.uid)
-        self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
@@ -498,13 +545,6 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
         return self._set(minTF=value)
 
     @since("1.6.0")
-    def getMinTF(self):
-        """
-        Gets the value of minTF or its default value.
-        """
-        return self.getOrDefault(self.minTF)
-
-    @since("1.6.0")
     def setMinDF(self, value):
         """
         Sets the value of :py:attr:`minDF`.
@@ -512,26 +552,12 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
         return self._set(minDF=value)
 
     @since("1.6.0")
-    def getMinDF(self):
-        """
-        Gets the value of minDF or its default value.
-        """
-        return self.getOrDefault(self.minDF)
-
-    @since("1.6.0")
     def setVocabSize(self, value):
         """
         Sets the value of :py:attr:`vocabSize`.
         """
         return self._set(vocabSize=value)
 
-    @since("1.6.0")
-    def getVocabSize(self):
-        """
-        Gets the value of vocabSize or its default value.
-        """
-        return self.getOrDefault(self.vocabSize)
-
     @since("2.0.0")
     def setBinary(self, value):
         """
@@ -539,24 +565,40 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
         """
         return self._set(binary=value)
 
-    @since("2.0.0")
-    def getBinary(self):
-        """
-        Gets the value of binary or its default value.
-        """
-        return self.getOrDefault(self.binary)
-
     def _create_model(self, java_model):
         return CountVectorizerModel(java_model)
 
 
-class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable):
+@inherit_doc
+class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable):
     """
     Model fitted by :py:class:`CountVectorizer`.
 
     .. versionadded:: 1.6.0
     """
 
+    @classmethod
+    @since("2.4.0")
+    def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None):
+        """
+        Construct the model directly from a vocabulary list of strings,
+        requires an active SparkContext.
+        """
+        sc = SparkContext._active_spark_context
+        java_class = sc._gateway.jvm.java.lang.String
+        jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class)
+        model = CountVectorizerModel._create_from_java_class(
+            "org.apache.spark.ml.feature.CountVectorizerModel", jvocab)
+        model.setInputCol(inputCol)
+        if outputCol is not None:
+            model.setOutputCol(outputCol)
+        if minTF is not None:
+            model.setMinTF(minTF)
+        if binary is not None:
+            model.setBinary(binary)
+        model._set(vocabSize=len(vocabulary))
+        return model
+
     @property
     @since("1.6.0")
     def vocabulary(self):
@@ -565,6 +607,20 @@ class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable):
         """
         return self._call_java("vocabulary")
 
+    @since("2.4.0")
+    def setMinTF(self, value):
+        """
+        Sets the value of :py:attr:`minTF`.
+        """
+        return self._set(minTF=value)
+
+    @since("2.4.0")
+    def setBinary(self, value):
+        """
+        Sets the value of :py:attr:`binary`.
+        """
+        return self._set(binary=value)
+
 
 @inherit_doc
 class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):

http://git-wip-us.apache.org/repos/asf/spark/blob/8a72734f/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 6dee693..fd45fd0 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -679,6 +679,34 @@ class FeatureTests(SparkSessionTestCase):
             feature, expected = r
             self.assertEqual(feature, expected)
 
+    def test_count_vectorizer_from_vocab(self):
+        model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words",
+                                                     outputCol="features", minTF=2)
+        self.assertEqual(model.vocabulary, ["a", "b", "c"])
+        self.assertEqual(model.getMinTF(), 2)
+
+        dataset = self.spark.createDataFrame([
+            (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),),
+            (1, "a a".split(' '), SparseVector(3, {0: 2.0}),),
+            (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"])
+
+        transformed_list = model.transform(dataset).select("features", "expected").collect()
+
+        for r in transformed_list:
+            feature, expected = r
+            self.assertEqual(feature, expected)
+
+        # Test an empty vocabulary
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"):
+                CountVectorizerModel.from_vocabulary([], inputCol="words")
+
+        # Test model with default settings can transform
+        model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words")
+        transformed_list = model_default.transform(dataset)\
+            .select(model_default.getOrDefault(model_default.outputCol)).collect()
+        self.assertEqual(len(transformed_list), 3)
+
     def test_rformula_force_index_label(self):
         df = self.spark.createDataFrame([
             (1.0, 1.0, "a"),
@@ -2019,8 +2047,8 @@ class DefaultValuesTests(PySparkTestCase):
                    pyspark.ml.regression]
         for module in modules:
             for name, cls in inspect.getmembers(module, inspect.isclass):
-                if not name.endswith('Model') and issubclass(cls, JavaParams)\
-                        and not inspect.isabstract(cls):
+                if not name.endswith('Model') and not name.endswith('Params')\
+                        and issubclass(cls, JavaParams) and not inspect.isabstract(cls):
                     # NOTE: disable check_params_exist until there is parity with Scala API
                     ParamTests.check_params(self, cls(), check_params_exist=False)
 

