This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2c629020592 [SPARK-40311][SQL][PYTHON] Add `withColumnsRenamed` to scala and pyspark API
2c629020592 is described below

commit 2c6290205928521e8d7404bb9a9cbccff0d35674
Author: santosh <3813695+santosh-d3vp...@users.noreply.github.com>
AuthorDate: Thu Oct 6 00:29:07 2022 -0700

    [SPARK-40311][SQL][PYTHON] Add `withColumnsRenamed` to scala and pyspark API
    
    ### What changes were proposed in this pull request?
    This change adds the ability to rename multiple columns in a single call.
    **Scala:**
    ```scala
    withColumnsRenamed(colsMap: Map[String, String]): DataFrame
    ```
    **Java:**
    ```java
    withColumnsRenamed(colsMap: java.util.Map[String, String]): DataFrame
    ```
    **Python:**
    ```python
    withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame"
    ```
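
    For illustration, a minimal PySpark usage sketch (assuming an active `spark` session; the DataFrame and column names here are only examples, not part of this change):
    ```python
    df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])

    # Rename several columns in one call instead of chaining withColumnRenamed.
    df.withColumnsRenamed({"age": "years", "name": "first_name"}).show()
    ```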
    
    ### Why are the changes needed?
    We have seen that the Catalyst optimizer struggles with large plans. In our setup, the largest contribution to these plans comes from `withColumnRenamed`, `drop`, and `withColumn` being called in a for loop by unsuspecting users. The `master` branch of Spark already handles `withColumns` and `drop` for multiple columns; the missing piece of the puzzle is `withColumnRenamed`.
    
    With a large number of columns, either the JVM gets killed or a StackOverflowError occurs. Those cases are skipped in the following benchmark, which focuses on a column count that works in both the old and the new implementation. The following example shows the performance impact with 100 columns:
    **Old approach (chained `withColumnRenamed`) with 100 columns**
    ```python
    import datetime
    import numpy as np
    import pandas as pd
    
    num_rows = 2
    num_columns = 100
    data = np.zeros((num_rows, num_columns))
    columns = map(str, range(num_columns))
    raw = spark.createDataFrame(pd.DataFrame(data, columns=columns))
    
    a = datetime.datetime.now()
    
    for col in raw.columns:
        raw = raw.withColumnRenamed(col, f"prefix_{col}")
    
    b = datetime.datetime.now()
    for col in raw.columns:
        raw = raw.withColumnRenamed(col, f"prefix_{col}")
    
    c = datetime.datetime.now()
    for col in raw.columns:
        raw = raw.withColumnRenamed(col, f"prefix_{col}")
    
    d = datetime.datetime.now()
    for col in raw.columns:
        raw = raw.withColumnRenamed(col, f"prefix_{col}")
    
    e = datetime.datetime.now()
    for col in raw.columns:
        raw = raw.withColumnRenamed(col, f"prefix_{col}")
    
    f = datetime.datetime.now()
    for col in raw.columns:
        raw = raw.withColumnRenamed(col, f"prefix_{col}")
    
    g = datetime.datetime.now()
    g-a
    datetime.timedelta(seconds=12, microseconds=480021)
    ```
    
    **New implementation with 100 columns**
    ```python
    import datetime
    import numpy as np
    import pandas as pd
    
    num_rows = 2
    num_columns = 100
    data = np.zeros((num_rows, num_columns))
    columns = map(str, range(num_columns))
    raw = spark.createDataFrame(pd.DataFrame(data, columns=columns))
    
    a = datetime.datetime.now()
    raw = raw.withColumnsRenamed({col: f"prefix_{col}" for col in raw.columns})
    b = datetime.datetime.now()
    raw = raw.withColumnsRenamed({col: f"prefix_{col}" for col in raw.columns})
    c = datetime.datetime.now()
    raw = raw.withColumnsRenamed({col: f"prefix_{col}" for col in raw.columns})
    d = datetime.datetime.now()
    raw = raw.withColumnsRenamed({col: f"prefix_{col}" for col in raw.columns})
    e = datetime.datetime.now()
    raw = raw.withColumnsRenamed({col: f"prefix_{col}" for col in raw.columns})
    f = datetime.datetime.now()
    raw = raw.withColumnsRenamed({col: f"prefix_{col}" for col in raw.columns})
    g = datetime.datetime.now()
    g-a
    datetime.timedelta(microseconds=210400)
    ```
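
    As a rough illustration of why the batched call scales better (a sketch under the same setup as above, assuming an active `spark` session; exact plan output varies by Spark version), the analyzed plan for the loop version stacks one `Project` per rename, while the batched version produces a single `Project`:
    ```python
    import numpy as np
    import pandas as pd

    num_rows, num_columns = 2, 100
    data = np.zeros((num_rows, num_columns))
    columns = [str(i) for i in range(num_columns)]

    # Loop of withColumnRenamed: each call wraps the plan in another Project.
    looped = spark.createDataFrame(pd.DataFrame(data, columns=columns))
    for c in looped.columns:
        looped = looped.withColumnRenamed(c, f"prefix_{c}")
    looped.explain(extended=True)   # analyzed plan shows ~100 nested Project nodes

    # Single withColumnsRenamed: one Project covering all the renames.
    batched = spark.createDataFrame(pd.DataFrame(data, columns=columns))
    batched = batched.withColumnsRenamed({c: f"prefix_{c}" for c in batched.columns})
    batched.explain(extended=True)  # analyzed plan shows a single Project
    ```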
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds a method to rename multiple columns efficiently in a single call.
    
    ### How was this patch tested?
    Added unit tests
    
    Closes #37761 from santosh-d3vpl3x/master.
    
    Lead-authored-by: santosh <3813695+santosh-d3vp...@users.noreply.github.com>
    Co-authored-by: Santosh Pingale <3813695+santosh-d3vp...@users.noreply.github.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../source/reference/pyspark.sql/dataframe.rst     |  1 +
 python/pyspark/sql/dataframe.py                    | 40 +++++++++++++++++
 python/pyspark/sql/tests/test_dataframe.py         | 16 +++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 47 ++++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala      | 51 ++++++++++++++++++++++
 5 files changed, 155 insertions(+)

diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst b/python/docs/source/reference/pyspark.sql/dataframe.rst
index fdb79f72fc7..e647704158f 100644
--- a/python/docs/source/reference/pyspark.sql/dataframe.rst
+++ b/python/docs/source/reference/pyspark.sql/dataframe.rst
@@ -119,6 +119,7 @@ DataFrame
     DataFrame.withColumn
     DataFrame.withColumns
     DataFrame.withColumnRenamed
+    DataFrame.withColumnsRenamed
     DataFrame.withMetadata
     DataFrame.withWatermark
     DataFrame.write
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 23dfd4e7ec8..7c3cc92d393 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -4430,6 +4430,46 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sparkSession)
 
+    def withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame":
+        """
+        Returns a new :class:`DataFrame` by renaming multiple columns.
+        This is a no-op if schema doesn't contain the given column names.
+
+        .. versionadded:: 3.4.0
+           Added support for multiple columns renaming
+
+        Parameters
+        ----------
+        colsMap : dict
+            a dict of existing column names and corresponding desired column names.
+            Currently, only single map is supported.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            DataFrame with renamed columns.
+
+        See Also
+        --------
+        :meth:`withColumnRenamed`
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
+        >>> df = df.withColumns({'age2': df.age + 2, 'age3': df.age + 3})
+        >>> df.withColumnsRenamed({'age2': 'age4', 'age3': 'age5'}).show()
+        +---+-----+----+----+
+        |age| name|age4|age5|
+        +---+-----+----+----+
+        |  2|Alice|   4|   5|
+        |  5|  Bob|   7|   8|
+        +---+-----+----+----+
+        """
+        if not isinstance(colsMap, dict):
+            raise TypeError("colsMap must be dict of existing column name and new column name.")
+
+        return DataFrame(self._jdf.withColumnsRenamed(colsMap), self.sparkSession)
+
     def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame":
         """Returns a new :class:`DataFrame` by updating an existing column with metadata.
 
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index d15ba442ab4..be5784114fb 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -97,6 +97,22 @@ class DataFrameTests(ReusedSQLTestCase):
         self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"])
         self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"])
 
+    def test_with_columns_renamed(self):
+        df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])
+
+        # rename both columns
+        renamed_df1 = df.withColumnsRenamed({"name": "naam", "age": "leeftijd"})
+        self.assertEqual(renamed_df1.columns, ["naam", "leeftijd"])
+
+        # rename one column with one missing name
+        renamed_df2 = df.withColumnsRenamed({"name": "naam", "address": "adres"})
+        self.assertEqual(renamed_df2.columns, ["naam", "age"])
+
+        # negative test for incorrect type
+        type_error_msg = "colsMap must be dict of existing column name and new column name."
+        with self.assertRaisesRegex(TypeError, type_error_msg):
+            df.withColumnsRenamed(("name", "x"))
+
     def test_drop_duplicates(self):
         # SPARK-36034 test that drop duplicates throws a type error when in correct type provided
         df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 18aea40f556..6a07db71428 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2808,6 +2808,53 @@ class Dataset[T] private[sql](
     }
   }
 
+  /**
+   * (Scala-specific)
+   * Returns a new Dataset with columns renamed.
+   * This is a no-op if schema doesn't contain existingName.
+   *
+   * `colsMap` is a map of existing column name and new column name.
+   *
+   * @throws AnalysisException if there are duplicate names in resulting projection
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @throws[AnalysisException]
+  def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = {
+    val resolver = sparkSession.sessionState.analyzer.resolver
+    val output: Seq[NamedExpression] = queryExecution.analyzed.output
+
+    val projectList = colsMap.foldLeft(output) {
+      case (attrs, (existingName, newName)) =>
+      attrs.map(attr =>
+        if (resolver(attr.name, existingName)) {
+          Alias(attr, newName)()
+        } else {
+          attr
+        }
+      )
+    }
+    SchemaUtils.checkColumnNameDuplication(
+      projectList.map(_.name),
+      "in given column names for withColumnsRenamed",
+      sparkSession.sessionState.conf.caseSensitiveAnalysis)
+    withPlan(Project(projectList, logicalPlan))
+  }
+
+  /**
+   * (Java-specific)
+   * Returns a new Dataset with columns renamed.
+   * This is a no-op if schema doesn't contain existingName.
+   *
+   * `colsMap` is a map of existing column name and new column name.
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  def withColumnsRenamed(colsMap: java.util.Map[String, String]): DataFrame =
+    withColumnsRenamed(colsMap.asScala.toMap)
+
   /**
    * Returns a new Dataset by updating an existing column with metadata.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b29b5c2b341..0fcbbe6fa69 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -895,6 +895,57 @@ class DataFrameSuite extends QueryTest
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
+  test("SPARK-40311: withColumnsRenamed") {
+      val df = testData.toDF().withColumns(Seq("newCol1", "newCOL2"),
+        Seq(col("key") + 1, col("key") + 2))
+        .withColumnsRenamed(Map("newCol1" -> "renamed1", "newCol2" -> "renamed2"))
+      checkAnswer(
+        df,
+        testData.collect().map { case Row(key: Int, value: String) =>
+          Row(key, value, key + 1, key + 2)
+        }.toSeq)
+      assert(df.columns === Array("key", "value", "renamed1", "renamed2"))
+  }
+
+  test("SPARK-40311: withColumnsRenamed case sensitive") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      val df = testData.toDF().withColumns(Seq("newCol1", "newCOL2"),
+        Seq(col("key") + 1, col("key") + 2))
+        .withColumnsRenamed(Map("newCol1" -> "renamed1", "newCol2" -> "renamed2"))
+      checkAnswer(
+        df,
+        testData.collect().map { case Row(key: Int, value: String) =>
+          Row(key, value, key + 1, key + 2)
+        }.toSeq)
+      assert(df.columns === Array("key", "value", "renamed1", "newCOL2"))
+    }
+  }
+
+  test("SPARK-40311: withColumnsRenamed duplicate column names simple") {
+    val e = intercept[AnalysisException] {
+      person.withColumnsRenamed(Map("id" -> "renamed", "name" -> "renamed"))
+    }
+    assert(e.getMessage.contains("Found duplicate column(s)"))
+    assert(e.getMessage.contains("in given column names for 
withColumnsRenamed:"))
+    assert(e.getMessage.contains("`renamed`"))
+  }
+
+  test("SPARK-40311: withColumnsRenamed duplicate column names simple case 
sensitive") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      val df = person.withColumnsRenamed(Map("id" -> "renamed", "name" -> "Renamed"))
+      assert(df.columns === Array("renamed", "Renamed", "age"))
+    }
+  }
+
+  test("SPARK-40311: withColumnsRenamed duplicate column names indirect") {
+    val e = intercept[AnalysisException] {
+      person.withColumnsRenamed(Map("id" -> "renamed1", "renamed1" -> "age"))
+    }
+    assert(e.getMessage.contains("Found duplicate column(s)"))
+    assert(e.getMessage.contains("in given column names for 
withColumnsRenamed:"))
+    assert(e.getMessage.contains("`age`"))
+  }
+
   test("SPARK-20384: Value class filter") {
     val df = spark.sparkContext
       .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
