This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new ae13b45  [SPARK-38763][PYTHON] Support lambda `column` parameter of 
`DataFrame.rename`
ae13b45 is described below

commit ae13b453f6b239af4c7f57cff99e7b8ef939cc9e
Author: Xinrong Meng <xinrong.m...@databricks.com>
AuthorDate: Sun Apr 3 09:52:48 2022 +0900

    [SPARK-38763][PYTHON] Support lambda `column` parameter of 
`DataFrame.rename`
    
    ### What changes were proposed in this pull request?
    Support lambda `column` parameter of `DataFrame.rename`.
    
    We may want to backport this to 3.3 since this is a regression.
    
    ### Why are the changes needed?
    To reach parity with pandas.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The regression is fixed; lambda `column` is supported again.
    
    ```py
    >>> psdf = ps.DataFrame({'x': [1, 2], 'y': [3, 4]})
    
    >>> psdf.rename(columns=lambda x: x + 'o')
       xo  yo
    0   1   3
    1   2   4
    ```
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #36042 from xinrong-databricks/frame.rename.
    
    Authored-by: Xinrong Meng <xinrong.m...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 037d07c8acb864f495ea74afba94531c28c163ce)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/pandas/frame.py                | 16 ++++++++++------
 python/pyspark/pandas/tests/test_dataframe.py |  9 +++++++++
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index b355708..6e8f69a 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -10580,7 +10580,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
 
         def gen_mapper_fn(
-            mapper: Union[Dict, Callable[[Any], Any]]
+            mapper: Union[Dict, Callable[[Any], Any]], skip_return_type: bool 
= False
         ) -> Tuple[Callable[[Any], Any], Dtype, DataType]:
             if isinstance(mapper, dict):
                 mapper_dict = mapper
@@ -10598,21 +10598,25 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                             raise KeyError("Index include value which is not 
in the `mapper`")
                         return x
 
+                return mapper_fn, dtype, spark_return_type
             elif callable(mapper):
                 mapper_callable = cast(Callable, mapper)
-                return_type = cast(ScalarType, infer_return_type(mapper))
-                dtype = return_type.dtype
-                spark_return_type = return_type.spark_type
 
                 def mapper_fn(x: Any) -> Any:
                     return mapper_callable(x)
 
+                if skip_return_type:
+                    return mapper_fn, None, None
+                else:
+                    return_type = cast(ScalarType, infer_return_type(mapper))
+                    dtype = return_type.dtype
+                    spark_return_type = return_type.spark_type
+                    return mapper_fn, dtype, spark_return_type
             else:
                 raise ValueError(
                     "`mapper` or `index` or `columns` should be "
                     "either dict-like or function type."
                 )
-            return mapper_fn, dtype, spark_return_type
 
         index_mapper_fn = None
         index_mapper_ret_stype = None
@@ -10633,7 +10637,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                     index
                 )
             if columns:
-                columns_mapper_fn, _, _ = gen_mapper_fn(columns)
+                columns_mapper_fn, _, _ = gen_mapper_fn(columns, 
skip_return_type=True)
 
             if not index and not columns:
                 raise ValueError("Either `index` or `columns` should be 
provided.")
diff --git a/python/pyspark/pandas/tests/test_dataframe.py 
b/python/pyspark/pandas/tests/test_dataframe.py
index 6f3c1c4..1cc03bf 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -817,11 +817,20 @@ class DataFrameTest(ComparisonTestBase, SQLTestUtils):
             pdf1.rename(columns=str_lower, index={1: 10, 2: 20}),
         )
 
+        self.assert_eq(
+            psdf1.rename(columns=lambda x: str.lower(x)),
+            pdf1.rename(columns=lambda x: str.lower(x)),
+        )
+
         idx = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C"), 
("Y", "D")])
         pdf2 = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], columns=idx)
         psdf2 = ps.from_pandas(pdf2)
 
         self.assert_eq(psdf2.rename(columns=str_lower), 
pdf2.rename(columns=str_lower))
+        self.assert_eq(
+            psdf2.rename(columns=lambda x: str.lower(x)),
+            pdf2.rename(columns=lambda x: str.lower(x)),
+        )
 
         self.assert_eq(
             psdf2.rename(columns=str_lower, level=0), 
pdf2.rename(columns=str_lower, level=0)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to