This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 032e78297b0 [SPARK-46260][PYTHON][SQL] `DataFrame.withColumnsRenamed` should respect the dict ordering
032e78297b0 is described below

commit 032e78297b02adb4266818776b55e09057705084
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Dec 6 17:16:07 2023 +0900

    [SPARK-46260][PYTHON][SQL] `DataFrame.withColumnsRenamed` should respect the dict ordering
    
    ### What changes were proposed in this pull request?
    Make `DataFrame.withColumnsRenamed` respect the dict ordering
    
    ### Why are the changes needed?
    The ordering in `withColumnsRenamed` matters, because an earlier rename can produce a column name that a later rename then matches.
    
    In Scala:
    ```
    scala> val df = spark.range(1000)
    val df: org.apache.spark.sql.Dataset[Long] = [id: bigint]
    
    scala> df.withColumnsRenamed(Map("id" -> "a", "a" -> "b"))
    val res0: org.apache.spark.sql.DataFrame = [b: bigint]
    
    scala> df.withColumnsRenamed(Map("a" -> "b", "id" -> "a"))
    val res1: org.apache.spark.sql.DataFrame = [a: bigint]
    ```
    
    However, the Py4J conversion from a Python `dict` to a JVM `Map` does not guarantee the ordering.
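    
    A rough sketch of where the ordering can be lost (illustrative only; it assumes a running Py4J gateway and uses `MapConverter`, Py4J's stock dict-to-Java-map converter):
    ```python
    from py4j.java_gateway import JavaGateway
    from py4j.java_collections import MapConverter

    # Assumption for illustration: a Py4J gateway server is already up.
    gateway = JavaGateway()

    # Py4J copies the Python dict into a java.util.HashMap on the JVM side.
    jmap = MapConverter().convert({"id": "a", "a": "b"}, gateway._gateway_client)

    # HashMap iteration order is unspecified, so the JVM side may apply the
    # renames in a different order than the dict's insertion order.
    ```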
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this is a behavior change.
    
    Before this PR:
    ```
    In [1]: df = spark.range(10)
    
    In [2]: df.withColumnsRenamed({"id": "a", "a": "b"})
    Out[2]: DataFrame[a: bigint]
    
    In [3]: df.withColumnsRenamed({"a": "b", "id": "a"})
    Out[3]: DataFrame[a: bigint]
    ```
    
    After this PR:
    ```
    In [1]: df = spark.range(10)
    
    In [2]: df.withColumnsRenamed({"id": "a", "a": "b"})
    Out[2]: DataFrame[b: bigint]
    
    In [3]: df.withColumnsRenamed({"a": "b", "id": "a"})
    Out[3]: DataFrame[a: bigint]
    ```
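    
    The fix avoids the lossy dict -> `Map` conversion by unzipping the dict into two parallel name sequences and applying the renames sequentially (see the diff below). A minimal standalone sketch of the idea, where `rename_columns` is a hypothetical helper for illustration, not PySpark API:
    ```python
    def rename_columns(columns, cols_map):
        # Python dicts preserve insertion order (3.7+), so iterating over
        # cols_map.items() applies the renames in the order they were given.
        for existing, new in cols_map.items():
            columns = [new if c == existing else c for c in columns]
        return columns

    assert rename_columns(["id"], {"id": "a", "a": "b"}) == ["b"]
    assert rename_columns(["id"], {"a": "b", "id": "a"}) == ["a"]
    ```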
    
    ### How was this patch tested?
    Added a unit test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #44177 from zhengruifeng/sql_withColumnsRenamed_sql.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/dataframe.py                    | 13 ++++++++++-
 .../sql/tests/connect/test_parity_dataframe.py     |  5 ++++
 python/pyspark/sql/tests/test_dataframe.py         |  9 ++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 27 +++++++++++++++-------
 .../org/apache/spark/sql/DataFrameSuite.scala      |  7 ++++++
 5 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 5211d874ba3..1419d1f3cb6 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -6272,7 +6272,18 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 message_parameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__},
             )
 
-        return DataFrame(self._jdf.withColumnsRenamed(colsMap), self.sparkSession)
+        col_names: List[str] = []
+        new_col_names: List[str] = []
+        for k, v in colsMap.items():
+            col_names.append(k)
+            new_col_names.append(v)
+
+        return DataFrame(
+            self._jdf.withColumnsRenamed(
+                _to_seq(self._sc, col_names), _to_seq(self._sc, new_col_names)
+            ),
+            self.sparkSession,
+        )
 
     def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame":
         """Returns a new :class:`DataFrame` by updating an existing column 
with metadata.
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index b7b4fdcd287..fbef282e0b9 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -77,6 +77,11 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_toDF_with_string(self):
         super().test_toDF_with_string()
 
+    # TODO(SPARK-46261): Python Client withColumnsRenamed should respect the dict ordering
+    @unittest.skip("Fails in Spark Connect, should enable.")
+    def test_ordering_of_with_columns_renamed(self):
+        super().test_ordering_of_with_columns_renamed()
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 52806f4f4a3..c25fe60ad17 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -163,6 +163,15 @@ class DataFrameTestsMixin:
             message_parameters={"arg_name": "colsMap", "arg_type": "tuple"},
         )
 
+    def test_ordering_of_with_columns_renamed(self):
+        df = self.spark.range(10)
+
+        df1 = df.withColumnsRenamed({"id": "a", "a": "b"})
+        self.assertEqual(df1.columns, ["b"])
+
+        df2 = df.withColumnsRenamed({"a": "b", "id": "a"})
+        self.assertEqual(df2.columns, ["a"])
+
     def test_drop_duplicates(self):
         # SPARK-36034 test that drop duplicates throws a type error when incorrect type provided
         df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 293f20c453a..cacc193885d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2922,18 +2922,29 @@ class Dataset[T] private[sql](
    */
   @throws[AnalysisException]
   def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = withOrigin {
+    val (colNames, newColNames) = colsMap.toSeq.unzip
+    withColumnsRenamed(colNames, newColNames)
+  }
+
+  private def withColumnsRenamed(
+    colNames: Seq[String],
+    newColNames: Seq[String]): DataFrame = withOrigin {
+    require(colNames.size == newColNames.size,
+      s"The size of existing column names: ${colNames.size} isn't equal to " +
+        s"the size of new column names: ${newColNames.size}")
+
     val resolver = sparkSession.sessionState.analyzer.resolver
     val output: Seq[NamedExpression] = queryExecution.analyzed.output
 
-    val projectList = colsMap.foldLeft(output) {
+    val projectList = colNames.zip(newColNames).foldLeft(output) {
       case (attrs, (existingName, newName)) =>
-      attrs.map(attr =>
-        if (resolver(attr.name, existingName)) {
-          Alias(attr, newName)()
-        } else {
-          attr
-        }
-      )
+        attrs.map(attr =>
+          if (resolver(attr.name, existingName)) {
+            Alias(attr, newName)()
+          } else {
+            attr
+          }
+        )
     }
     SchemaUtils.checkColumnNameDuplication(
       projectList.map(_.name),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b732f6631a7..25ecefd28cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
 import java.util.{Locale, UUID}
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.collection.immutable.ListMap
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.Random
 
@@ -987,6 +988,12 @@ class DataFrameSuite extends QueryTest
       parameters = Map("columnName" -> "`age`"))
   }
 
+  test("SPARK-46260: withColumnsRenamed should respect the Map ordering") {
+    val df = spark.range(10).toDF()
+    assert(df.withColumnsRenamed(ListMap("id" -> "a", "a" -> "b")).columns === Array("b"))
+    assert(df.withColumnsRenamed(ListMap("a" -> "b", "id" -> "a")).columns === Array("a"))
+  }
+
   test("SPARK-20384: Value class filter") {
     val df = spark.sparkContext
       .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))

