Repository: spark
Updated Branches:
  refs/heads/branch-1.4 2bbb685f4 -> 653db0a1b


[SPARK-6876] [PySpark] [SQL] add DataFrame na.replace in pyspark

Author: Daoyuan Wang <[email protected]>

Closes #6003 from adrian-wang/pynareplace and squashes the following commits:

672efba [Daoyuan Wang] remove py2.7 feature
4a148f7 [Daoyuan Wang] to_replace support dict, value support single value, and add full tests
9e232e7 [Daoyuan Wang] rename scala map
af0268a [Daoyuan Wang] remove na
63ac579 [Daoyuan Wang] add na.replace in pyspark

(cherry picked from commit d86ce845840a92b4dde7975082738ed94ab8c570)
Signed-off-by: Reynold Xin <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/653db0a1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/653db0a1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/653db0a1

Branch: refs/heads/branch-1.4
Commit: 653db0a1bddea120f7b2184ae6f737544b035bd4
Parents: 2bbb685
Author: Daoyuan Wang <[email protected]>
Authored: Tue May 12 10:23:41 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Tue May 12 10:23:57 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/api/python/PythonUtils.scala   |  7 ++
 python/pyspark/sql/dataframe.py                 | 85 ++++++++++++++++++++
 python/pyspark/sql/tests.py                     | 48 +++++++++++
 3 files changed, 140 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/653db0a1/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index 9eff0a2..efb6b93 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -53,4 +53,11 @@ private[spark] object PythonUtils {
   def toSeq[T](cols: JList[T]): Seq[T] = {
     cols.toList.toSeq
   }
+
+  /**
+   * Convert java map of K, V into Map of K, V (for calling API with varargs)
+   */
+  def toScalaMap[K, V](jm: java.util.Map[K, V]): Map[K, V] = {
+    jm.toMap
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/653db0a1/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 72180f6..078acfd 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -578,6 +578,10 @@ class DataFrame(object):
         """Return a JVM Seq of Columns from a list of Column or names"""
         return _to_seq(self.sql_ctx._sc, cols, converter)
 
+    def _jmap(self, jm):
+        """Return a JVM Scala Map from a dict"""
+        return _to_scala_map(self.sql_ctx._sc, jm)
+
     def _jcols(self, *cols):
         """Return a JVM Seq of Columns from a list of Column or column names
 
@@ -924,6 +928,80 @@ class DataFrame(object):
 
             return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
 
+    def replace(self, to_replace, value, subset=None):
+        """Returns a new :class:`DataFrame` replacing a value with another value.
+
+        :param to_replace: int, long, float, string, or list.
+            Value to be replaced.
+            If the value is a dict, then `value` is ignored and `to_replace` must be a
+            mapping from column name (string) to replacement value. The value to be
+            replaced must be an int, long, float, or string.
+        :param value: int, long, float, string, or list.
+            Value to use to replace holes.
+            The replacement value must be an int, long, float, or string. If `value` is a
+            list or tuple, `value` should be of the same length with `to_replace`.
+        :param subset: optional list of column names to consider.
+            Columns specified in subset that do not have matching data type are ignored.
+            For example, if `value` is a string, and subset contains a non-string column,
+            then the non-string column is simply ignored.
+        >>> df4.replace(10, 20).show()
+        +----+------+-----+
+        | age|height| name|
+        +----+------+-----+
+        |  20|    80|Alice|
+        |   5|  null|  Bob|
+        |null|  null|  Tom|
+        |null|  null| null|
+        +----+------+-----+
+
+        >>> df4.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
+        +----+------+----+
+        | age|height|name|
+        +----+------+----+
+        |  10|    80|   A|
+        |   5|  null|   B|
+        |null|  null| Tom|
+        |null|  null|null|
+        +----+------+----+
+        """
+        if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)):
+            raise ValueError(
+                "to_replace should be a float, int, long, string, list, tuple, or dict")
+
+        if not isinstance(value, (float, int, long, basestring, list, tuple)):
+            raise ValueError("value should be a float, int, long, string, list, or tuple")
+
+        rep_dict = dict()
+
+        if isinstance(to_replace, (float, int, long, basestring)):
+            to_replace = [to_replace]
+
+        if isinstance(to_replace, tuple):
+            to_replace = list(to_replace)
+
+        if isinstance(value, tuple):
+            value = list(value)
+
+        if isinstance(to_replace, list) and isinstance(value, list):
+            if len(to_replace) != len(value):
+                raise ValueError("to_replace and value lists should be of the same length")
+            rep_dict = dict(zip(to_replace, value))
+        elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)):
+            rep_dict = dict([(tr, value) for tr in to_replace])
+        elif isinstance(to_replace, dict):
+            rep_dict = to_replace
+
+        if subset is None:
+            return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
+        elif isinstance(subset, basestring):
+            subset = [subset]
+
+        if not isinstance(subset, (list, tuple)):
+            raise ValueError("subset should be a list or tuple of column names")
+
+        return DataFrame(
+            self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
+
     def corr(self, col1, col2, method=None):
         """
         Calculates the correlation of two columns of a DataFrame as a double value. Currently only
@@ -1226,6 +1304,13 @@ def _to_seq(sc, cols, converter=None):
     return sc._jvm.PythonUtils.toSeq(cols)
 
 
+def _to_scala_map(sc, jm):
+    """
+    Convert a dict into a JVM Map.
+    """
+    return sc._jvm.PythonUtils.toScalaMap(jm)
+
+
 def _unary_op(name, doc="unary operator"):
     """ Create a method for given unary operator """
     def _(self):

http://git-wip-us.apache.org/repos/asf/spark/blob/653db0a1/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7e63f4d..1922d03 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -665,6 +665,54 @@ class SQLTests(ReusedPySparkTestCase):
         result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
         self.assertEqual(~75, result['~b'])
 
+    def test_replace(self):
+        schema = StructType([
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+            StructField("height", DoubleType(), True)])
+
+        # replace with int
+        row = self.sqlCtx.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
+        self.assertEqual(row.age, 20)
+        self.assertEqual(row.height, 20.0)
+
+        # replace with double
+        row = self.sqlCtx.createDataFrame(
+            [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
+        self.assertEqual(row.age, 82)
+        self.assertEqual(row.height, 82.1)
+
+        # replace with string
+        row = self.sqlCtx.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
+        self.assertEqual(row.name, u"Ann")
+        self.assertEqual(row.age, 10)
+
+        # replace with subset specified by a string of a column name w/ actual change
+        row = self.sqlCtx.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
+        self.assertEqual(row.age, 20)
+
+        # replace with subset specified by a string of a column name w/o actual change
+        row = self.sqlCtx.createDataFrame(
+            [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
+        self.assertEqual(row.age, 10)
+
+        # replace with subset specified with one column replaced, another column not in subset
+        # stays unchanged.
+        row = self.sqlCtx.createDataFrame(
+            [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
+        self.assertEqual(row.name, u'Alice')
+        self.assertEqual(row.age, 20)
+        self.assertEqual(row.height, 10.0)
+
+        # replace with subset specified but no column will be replaced
+        row = self.sqlCtx.createDataFrame(
+            [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
+        self.assertEqual(row.name, u'Alice')
+        self.assertEqual(row.age, 10)
+        self.assertEqual(row.height, None)
+
 
 class HiveContextSQLTests(ReusedPySparkTestCase):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to