This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new ae13b45  [SPARK-38763][PYTHON] Support lambda `column` parameter of `DataFrame.rename`
ae13b45 is described below

commit ae13b453f6b239af4c7f57cff99e7b8ef939cc9e
Author: Xinrong Meng <xinrong.m...@databricks.com>
AuthorDate: Sun Apr 3 09:52:48 2022 +0900

    [SPARK-38763][PYTHON] Support lambda `column` parameter of `DataFrame.rename`

    ### What changes were proposed in this pull request?
    Support lambda `column` parameter of `DataFrame.rename`.

    We may want to backport this to 3.3 since this is a regression.

    ### Why are the changes needed?
    To reach parity with Pandas.

    ### Does this PR introduce _any_ user-facing change?
    Yes. The regression is fixed; lambda `column` is supported again.

    ```py
    >>> psdf = ps.DataFrame({'x': [1, 2], 'y': [3, 4]})
    >>> psdf.rename(columns=lambda x: x + 'o')
       xo  yo
    0   1   3
    1   2   4
    ```

    ### How was this patch tested?
    Unit tests.

    Closes #36042 from xinrong-databricks/frame.rename.

    Authored-by: Xinrong Meng <xinrong.m...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 037d07c8acb864f495ea74afba94531c28c163ce)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/pandas/frame.py                | 16 ++++++++++------
 python/pyspark/pandas/tests/test_dataframe.py |  9 +++++++++
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index b355708..6e8f69a 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -10580,7 +10580,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
 
         def gen_mapper_fn(
-            mapper: Union[Dict, Callable[[Any], Any]]
+            mapper: Union[Dict, Callable[[Any], Any]], skip_return_type: bool = False
         ) -> Tuple[Callable[[Any], Any], Dtype, DataType]:
             if isinstance(mapper, dict):
                 mapper_dict = mapper
@@ -10598,21 +10598,25 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                             raise KeyError("Index include value which is not in the `mapper`")
                         return x
 
+                return mapper_fn, dtype, spark_return_type
             elif callable(mapper):
                 mapper_callable = cast(Callable, mapper)
-                return_type = cast(ScalarType, infer_return_type(mapper))
-                dtype = return_type.dtype
-                spark_return_type = return_type.spark_type
 
                 def mapper_fn(x: Any) -> Any:
                     return mapper_callable(x)
 
+                if skip_return_type:
+                    return mapper_fn, None, None
+                else:
+                    return_type = cast(ScalarType, infer_return_type(mapper))
+                    dtype = return_type.dtype
+                    spark_return_type = return_type.spark_type
+                    return mapper_fn, dtype, spark_return_type
             else:
                 raise ValueError(
                     "`mapper` or `index` or `columns` should be "
                     "either dict-like or function type."
                 )
-            return mapper_fn, dtype, spark_return_type
 
         index_mapper_fn = None
         index_mapper_ret_stype = None
@@ -10633,7 +10637,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 index
             )
         if columns:
-            columns_mapper_fn, _, _ = gen_mapper_fn(columns)
+            columns_mapper_fn, _, _ = gen_mapper_fn(columns, skip_return_type=True)
 
         if not index and not columns:
             raise ValueError("Either `index` or `columns` should be provided.")
diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py
index 6f3c1c4..1cc03bf 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -817,11 +817,20 @@ class DataFrameTest(ComparisonTestBase, SQLTestUtils):
             pdf1.rename(columns=str_lower, index={1: 10, 2: 20}),
         )
 
+        self.assert_eq(
+            psdf1.rename(columns=lambda x: str.lower(x)),
+            pdf1.rename(columns=lambda x: str.lower(x)),
+        )
+
         idx = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C"), ("Y", "D")])
         pdf2 = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], columns=idx)
         psdf2 = ps.from_pandas(pdf2)
 
         self.assert_eq(psdf2.rename(columns=str_lower), pdf2.rename(columns=str_lower))
+        self.assert_eq(
+            psdf2.rename(columns=lambda x: str.lower(x)),
+            pdf2.rename(columns=lambda x: str.lower(x)),
+        )
 
         self.assert_eq(
             psdf2.rename(columns=str_lower, level=0), pdf2.rename(columns=str_lower, level=0)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org