HyukjinKwon closed pull request #22428: [SPARK-25430][SQL] Add map parameter 
for withColumnRenamed
URL: https://github.com/apache/spark/pull/22428
 
 
   

This pull request was merged from a forked repository. Because GitHub
hides the original diff once such a pull request is merged, the diff is
reproduced below for the sake of provenance:

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fa14aa14ee968..d3a170896a142 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2300,6 +2300,37 @@ class Dataset[T] private[sql](
     }
   }
 
+  /**
+   * Returns a new Dataset with columns renamed.
+   * This is a no-op if schema doesn't contain existingNames in columnMap.
+   * {{{
+   *   df.withColumnRenamed(Map(
+   *     "c1" -> "first_column",
+   *     "c2" -> "second_column"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.0.0
+   */
+  def withColumnRenamed(columnMap: Map[String, String]): DataFrame = {
+    val resolver = sparkSession.sessionState.analyzer.resolver
+    val output = queryExecution.analyzed.output
+    val existingNames = columnMap.keys.toSeq
+    val shouldRename = !output.map(_.name).intersect(existingNames).isEmpty
+    if (shouldRename) {
+      val columns = output.map { col =>
+        columnMap.get(col.name) match {
+          case Some(newName) => Column(col).as(newName)
+          case _ => Column(col)
+        }
+      }
+      select(columns : _*)
+    } else {
+      toDF()
+    }
+  }
+
   /**
    * Returns a new Dataset with a column dropped. This is a no-op if schema 
doesn't contain
    * column name.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f001b138f4b8e..525f139e74a9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1021,6 +1021,18 @@ class DataFrameSuite extends QueryTest with 
SharedSQLContext {
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
+  test("SPARK-25430: Add map parameter for withColumnRenamed") {
+    val df = testData.toDF().withColumn("newCol", col("key") + 1)
+      .withColumnRenamed(Map("value"->"valueRenamed", 
"newCol"->"newColRenamed",
+        "newCol2"->"newColRenamed2"))
+    checkAnswer(
+      df,
+      testData.collect().map { case Row(key: Int, value: String) =>
+        Row(key, value, key + 1)
+      }.toSeq)
+    assert(df.schema.map(_.name) === Seq("key", "valueRenamed", 
"newColRenamed"))
+  }
+
   private lazy val person2: DataFrame = Seq(
     ("Bob", 16, 176),
     ("Alice", 32, 164),


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to