This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 032e78297b0 [SPARK-46260][PYTHON][SQL] `DataFrame.withColumnsRenamed`
should respect the dict ordering
032e78297b0 is described below
commit 032e78297b02adb4266818776b55e09057705084
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Dec 6 17:16:07 2023 +0900
[SPARK-46260][PYTHON][SQL] `DataFrame.withColumnsRenamed` should respect the
dict ordering
### What changes were proposed in this pull request?
Make `DataFrame.withColumnsRenamed` respect the dict ordering
### Why are the changes needed?
the ordering in `withColumnsRenamed` matters
in scala
```
scala> val df = spark.range(1000)
val df: org.apache.spark.sql.Dataset[Long] = [id: bigint]
scala> df.withColumnsRenamed(Map("id" -> "a", "a" -> "b"))
val res0: org.apache.spark.sql.DataFrame = [b: bigint]
scala> df.withColumnsRenamed(Map("a" -> "b", "id" -> "a"))
val res1: org.apache.spark.sql.DataFrame = [a: bigint]
```
However, in py4j the Python `dict` -> JVM `map` conversion cannot
guarantee the ordering
### Does this PR introduce _any_ user-facing change?
yes, behavior change
before this PR
```
In [1]: df = spark.range(10)
In [2]: df.withColumnsRenamed({"id": "a", "a": "b"})
Out[2]: DataFrame[a: bigint]
In [3]: df.withColumnsRenamed({"a": "b", "id": "a"})
Out[3]: DataFrame[a: bigint]
```
after this PR
```
In [1]: df = spark.range(10)
In [2]: df.withColumnsRenamed({"id": "a", "a": "b"})
Out[2]: DataFrame[b: bigint]
In [3]: df.withColumnsRenamed({"a": "b", "id": "a"})
Out[3]: DataFrame[a: bigint]
```
### How was this patch tested?
added ut
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44177 from zhengruifeng/sql_withColumnsRenamed_sql.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/dataframe.py | 13 ++++++++++-
.../sql/tests/connect/test_parity_dataframe.py | 5 ++++
python/pyspark/sql/tests/test_dataframe.py | 9 ++++++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 27 +++++++++++++++-------
.../org/apache/spark/sql/DataFrameSuite.scala | 7 ++++++
5 files changed, 52 insertions(+), 9 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 5211d874ba3..1419d1f3cb6 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -6272,7 +6272,18 @@ class DataFrame(PandasMapOpsMixin,
PandasConversionMixin):
message_parameters={"arg_name": "colsMap", "arg_type":
type(colsMap).__name__},
)
- return DataFrame(self._jdf.withColumnsRenamed(colsMap),
self.sparkSession)
+ col_names: List[str] = []
+ new_col_names: List[str] = []
+ for k, v in colsMap.items():
+ col_names.append(k)
+ new_col_names.append(v)
+
+ return DataFrame(
+ self._jdf.withColumnsRenamed(
+ _to_seq(self._sc, col_names), _to_seq(self._sc, new_col_names)
+ ),
+ self.sparkSession,
+ )
def withMetadata(self, columnName: str, metadata: Dict[str, Any]) ->
"DataFrame":
"""Returns a new :class:`DataFrame` by updating an existing column
with metadata.
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index b7b4fdcd287..fbef282e0b9 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -77,6 +77,11 @@ class DataFrameParityTests(DataFrameTestsMixin,
ReusedConnectTestCase):
def test_toDF_with_string(self):
super().test_toDF_with_string()
+ # TODO(SPARK-46261): Python Client withColumnsRenamed should respect the
dict ordering
+ @unittest.skip("Fails in Spark Connect, should enable.")
+ def test_ordering_of_with_columns_renamed(self):
+ super().test_ordering_of_with_columns_renamed()
+
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/sql/tests/test_dataframe.py
b/python/pyspark/sql/tests/test_dataframe.py
index 52806f4f4a3..c25fe60ad17 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -163,6 +163,15 @@ class DataFrameTestsMixin:
message_parameters={"arg_name": "colsMap", "arg_type": "tuple"},
)
+ def test_ordering_of_with_columns_renamed(self):
+ df = self.spark.range(10)
+
+ df1 = df.withColumnsRenamed({"id": "a", "a": "b"})
+ self.assertEqual(df1.columns, ["b"])
+
+ df2 = df.withColumnsRenamed({"a": "b", "id": "a"})
+ self.assertEqual(df2.columns, ["a"])
+
def test_drop_duplicates(self):
# SPARK-36034 test that drop duplicates throws a type error when in
correct type provided
df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)],
["name", "age"])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 293f20c453a..cacc193885d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2922,18 +2922,29 @@ class Dataset[T] private[sql](
*/
@throws[AnalysisException]
def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = withOrigin
{
+ val (colNames, newColNames) = colsMap.toSeq.unzip
+ withColumnsRenamed(colNames, newColNames)
+ }
+
+ private def withColumnsRenamed(
+ colNames: Seq[String],
+ newColNames: Seq[String]): DataFrame = withOrigin {
+ require(colNames.size == newColNames.size,
+ s"The size of existing column names: ${colNames.size} isn't equal to " +
+ s"the size of new column names: ${newColNames.size}")
+
val resolver = sparkSession.sessionState.analyzer.resolver
val output: Seq[NamedExpression] = queryExecution.analyzed.output
- val projectList = colsMap.foldLeft(output) {
+ val projectList = colNames.zip(newColNames).foldLeft(output) {
case (attrs, (existingName, newName)) =>
- attrs.map(attr =>
- if (resolver(attr.name, existingName)) {
- Alias(attr, newName)()
- } else {
- attr
- }
- )
+ attrs.map(attr =>
+ if (resolver(attr.name, existingName)) {
+ Alias(attr, newName)()
+ } else {
+ attr
+ }
+ )
}
SchemaUtils.checkColumnNameDuplication(
projectList.map(_.name),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b732f6631a7..25ecefd28cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
import java.util.{Locale, UUID}
import java.util.concurrent.atomic.AtomicLong
+import scala.collection.immutable.ListMap
import scala.reflect.runtime.universe.TypeTag
import scala.util.Random
@@ -987,6 +988,12 @@ class DataFrameSuite extends QueryTest
parameters = Map("columnName" -> "`age`"))
}
+ test("SPARK-46260: withColumnsRenamed should respect the Map ordering") {
+ val df = spark.range(10).toDF()
+ assert(df.withColumnsRenamed(ListMap("id" -> "a", "a" -> "b")).columns ===
Array("b"))
+ assert(df.withColumnsRenamed(ListMap("a" -> "b", "id" -> "a")).columns ===
Array("a"))
+ }
+
test("SPARK-20384: Value class filter") {
val df = spark.sparkContext
.parallelize(Seq(StringWrapper("a"), StringWrapper("b"),
StringWrapper("c")))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]