Repository: spark
Updated Branches:
  refs/heads/master adb222b95 -> 4f1e8b9bb


[SPARK-23871][ML][PYTHON] add python api for VectorAssembler handleInvalid

## What changes were proposed in this pull request?

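Add a Python API for the `handleInvalid` parameter of `VectorAssembler`, and update the Scala parameter description to mention NaN values as well as NULL values.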

## How was this patch tested?

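Added doctests that exercise `handleInvalid="keep"` and `handleInvalid="skip"` on a DataFrame containing NULL and NaN values.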

Author: Huaxin Gao <huax...@us.ibm.com>

Closes #21003 from huaxingao/spark-23871.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4f1e8b9b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4f1e8b9b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4f1e8b9b

Branch: refs/heads/master
Commit: 4f1e8b9bb7d795d4ca3d5cd5dcc0f9419e52dfae
Parents: adb222b
Author: Huaxin Gao <huax...@us.ibm.com>
Authored: Tue Apr 10 15:41:45 2018 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Apr 10 15:41:45 2018 -0700

----------------------------------------------------------------------
 .../spark/ml/feature/VectorAssembler.scala      | 12 +++---
 python/pyspark/ml/feature.py                    | 42 +++++++++++++++++---
 2 files changed, 43 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4f1e8b9b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 6bf4aa3..4061154 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -71,12 +71,12 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") 
override val uid: String)
    */
   @Since("2.4.0")
   override val handleInvalid: Param[String] = new Param[String](this, 
"handleInvalid",
-    """Param for how to handle invalid data (NULL values). Options are 'skip' 
(filter out rows with
-      |invalid data), 'error' (throw an error), or 'keep' (return relevant 
number of NaN in the
-      |output). Column lengths are taken from the size of ML Attribute Group, 
which can be set using
-      |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths 
can also be inferred
-      |from first rows of the data since it is safe to do so but only in case 
of 'error' or 'skip'.
-      |""".stripMargin.replaceAll("\n", " "),
+    """Param for how to handle invalid data (NULL and NaN values). Options are 
'skip' (filter out
+      |rows with invalid data), 'error' (throw an error), or 'keep' (return 
relevant number of NaN
+      |in the output). Column lengths are taken from the size of ML Attribute 
Group, which can be
+      |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. 
Column lengths can also
+      |be inferred from first rows of the data since it is safe to do so but 
only in case of 'error'
+      |or 'skip'.""".stripMargin.replaceAll("\n", " "),
     ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
 
   setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)

http://git-wip-us.apache.org/repos/asf/spark/blob/4f1e8b9b/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 5a3e0dd..cdda30c 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2701,7 +2701,8 @@ class Tokenizer(JavaTransformer, HasInputCol, 
HasOutputCol, JavaMLReadable, Java
 
 
 @inherit_doc
-class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, 
JavaMLReadable, JavaMLWritable):
+class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, 
HasHandleInvalid, JavaMLReadable,
+                      JavaMLWritable):
     """
     A feature transformer that merges multiple columns into a vector column.
 
@@ -2719,25 +2720,56 @@ class VectorAssembler(JavaTransformer, HasInputCols, 
HasOutputCol, JavaMLReadabl
     >>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath)
     >>> loadedAssembler.transform(df).head().freqs == 
vecAssembler.transform(df).head().freqs
     True
+    >>> dfWithNullsAndNaNs = spark.createDataFrame(
+    ...    [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], 
["a", "b", "c"])
+    >>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], 
outputCol="features",
+    ...    handleInvalid="keep")
+    >>> vecAssembler2.transform(dfWithNullsAndNaNs).show()
+    +---+---+----+-------------+
+    |  a|  b|   c|     features|
+    +---+---+----+-------------+
+    |1.0|2.0|null|[1.0,2.0,NaN]|
+    |3.0|NaN| 4.0|[3.0,NaN,4.0]|
+    |5.0|6.0| 7.0|[5.0,6.0,7.0]|
+    +---+---+----+-------------+
+    ...
+    >>> 
vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show()
+    +---+---+---+-------------+
+    |  a|  b|  c|     features|
+    +---+---+---+-------------+
+    |5.0|6.0|7.0|[5.0,6.0,7.0]|
+    +---+---+---+-------------+
+    ...
 
     .. versionadded:: 1.4.0
     """
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle 
invalid data (NULL " +
+                          "and NaN values). Options are 'skip' (filter out 
rows with invalid " +
+                          "data), 'error' (throw an error), or 'keep' (return 
relevant number " +
+                          "of NaN in the output). Column lengths are taken 
from the size of ML " +
+                          "Attribute Group, which can be set using 
`VectorSizeHint` in a " +
+                          "pipeline before `VectorAssembler`. Column lengths 
can also be " +
+                          "inferred from first rows of the data since it is 
safe to do so but " +
+                          "only in case of 'error' or 'skip').",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
-    def __init__(self, inputCols=None, outputCol=None):
+    def __init__(self, inputCols=None, outputCol=None, handleInvalid="error"):
         """
-        __init__(self, inputCols=None, outputCol=None)
+        __init__(self, inputCols=None, outputCol=None, handleInvalid="error")
         """
         super(VectorAssembler, self).__init__()
         self._java_obj = 
self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid)
+        self._setDefault(handleInvalid="error")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, inputCols=None, outputCol=None):
+    def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"):
         """
-        setParams(self, inputCols=None, outputCol=None)
+        setParams(self, inputCols=None, outputCol=None, handleInvalid="error")
         Sets params for this VectorAssembler.
         """
         kwargs = self._input_kwargs


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to