Repository: spark
Updated Branches:
  refs/heads/master 30fcdc038 -> 816963043


[SPARK-22734][ML][PYSPARK] Added Python API for VectorSizeHint.

What changes were proposed in this pull request?

Python API for VectorSizeHint Transformer.

How was this patch tested?

Doc-tests.

Author: Bago Amirbekian <[email protected]>

Closes #20112 from MrBago/vectorSizeHint-PythonAPI.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/81696304
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/81696304
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/81696304

Branch: refs/heads/master
Commit: 816963043a09b082e8b9280dfa97fdaa19211015
Parents: 30fcdc0
Author: Bago Amirbekian <[email protected]>
Authored: Fri Dec 29 19:45:14 2017 -0800
Committer: Joseph K. Bradley <[email protected]>
Committed: Fri Dec 29 19:45:14 2017 -0800

----------------------------------------------------------------------
 .../spark/ml/feature/VectorSizeHint.scala       |  1 +
 python/pyspark/ml/feature.py                    | 79 ++++++++++++++++++++
 2 files changed, 80 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/81696304/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
index 1fe3cfc..f5947d6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.types.StructType
  * VectorAssembler needs size information for its input columns and cannot be used on streaming
  * dataframes without this metadata.
  *
+ * Note: VectorSizeHint modifies `inputCol` to include size metadata and does not have an outputCol.
  */
 @Experimental
 @Since("2.3.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/81696304/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 608f2a5..5094324 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -57,6 +57,7 @@ __all__ = ['Binarizer',
            'Tokenizer',
            'VectorAssembler',
            'VectorIndexer', 'VectorIndexerModel',
+           'VectorSizeHint',
            'VectorSlicer',
            'Word2Vec', 'Word2VecModel']
 
@@ -3466,6 +3467,84 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
         return self._call_java("selectedFeatures")
 
 
+@inherit_doc
+class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReadable,
+                     JavaMLWritable):
+    """
+    .. note:: Experimental
+
+    A feature transformer that adds size information to the metadata of a vector column.
+    VectorAssembler needs size information for its input columns and cannot be used on streaming
+    dataframes without this metadata.
+
+    .. note:: VectorSizeHint modifies `inputCol` to include size metadata and does not have an
+        outputCol.
+
+    >>> from pyspark.ml.linalg import Vectors
+    >>> from pyspark.ml import Pipeline, PipelineModel
+    >>> data = [(Vectors.dense([1., 2., 3.]), 4.)]
+    >>> df = spark.createDataFrame(data, ["vector", "float"])
+    >>>
+    >>> sizeHint = VectorSizeHint(inputCol="vector", size=3, handleInvalid="skip")
+    >>> vecAssembler = VectorAssembler(inputCols=["vector", "float"], outputCol="assembled")
+    >>> pipeline = Pipeline(stages=[sizeHint, vecAssembler])
+    >>>
+    >>> pipelineModel = pipeline.fit(df)
+    >>> pipelineModel.transform(df).head().assembled
+    DenseVector([1.0, 2.0, 3.0, 4.0])
+    >>> vectorSizeHintPath = temp_path + "/vector-size-hint-pipeline"
+    >>> pipelineModel.save(vectorSizeHintPath)
+    >>> loadedPipeline = PipelineModel.load(vectorSizeHintPath)
+    >>> loaded = loadedPipeline.transform(df).head().assembled
+    >>> expected = pipelineModel.transform(df).head().assembled
+    >>> loaded == expected
+    True
+
+    .. versionadded:: 2.3.0
+    """
+
+    size = Param(Params._dummy(), "size", "Size of vectors in column.",
+                 typeConverter=TypeConverters.toInt)
+
+    handleInvalid = Param(Params._dummy(), "handleInvalid",
+                          "How to handle invalid vectors in inputCol. Invalid vectors include "
+                          "nulls and vectors with the wrong size. The options are `skip` (filter "
+                          "out rows with invalid vectors), `error` (throw an error) and "
+                          "`optimistic` (do not check the vector size, and keep all rows). "
+                          "`error` by default.",
+                          TypeConverters.toString)
+
+    @keyword_only
+    def __init__(self, inputCol=None, size=None, handleInvalid="error"):
+        """
+        __init__(self, inputCol=None, size=None, handleInvalid="error")
+        """
+        super(VectorSizeHint, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSizeHint", self.uid)
+        self._setDefault(handleInvalid="error")
+        self.setParams(**self._input_kwargs)
+
+    @keyword_only
+    @since("2.3.0")
+    def setParams(self, inputCol=None, size=None, handleInvalid="error"):
+        """
+        setParams(self, inputCol=None, size=None, handleInvalid="error")
+        Sets params for this VectorSizeHint.
+        """
+        kwargs = self._input_kwargs
+        return self._set(**kwargs)
+
+    @since("2.3.0")
+    def getSize(self):
+        """ Gets size param, the size of vectors in `inputCol` (returns the param's value)."""
+        return self.getOrDefault(self.size)
+
+    @since("2.3.0")
+    def setSize(self, value):
+        """ Sets size param, the size of vectors in `inputCol`, and returns this instance."""
+        return self._set(size=value)
+
+
 if __name__ == "__main__":
     import doctest
     import tempfile


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to