This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3c57180038f [SPARK-41772][CONNECT][PYTHON] Fix incorrect column name 
in `withField`'s doctest
3c57180038f is described below

commit 3c57180038f8ddfcc184a74a5e387a3637e01cc2
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sun Jan 22 17:53:43 2023 +0800

    [SPARK-41772][CONNECT][PYTHON] Fix incorrect column name in `withField`'s 
doctest
    
    ### What changes were proposed in this pull request?
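    Fix the incorrect column name in `withField`'s doctest. The Connect planner wrapped every projected expression in `UnresolvedAlias`, even expressions that are already named (such as the attribute `a.b`). As a result, the output column was named `a.b` instead of `b`. Projections now go through the shared `toNamedExpression` helper, which aliases only unnamed expressions: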
    
    ```
    pyspark.sql.connect.column.Column.withField
    Failed example:
        df.withColumn('a', df['a'].withField('b', lit(3))).select('a.b').show()
    Expected:
        +---+
        |  b|
        +---+
        |  3|
        +---+
    Got:
        +---+
        |a.b|
        +---+
        |  3|
        +---+
        <BLANKLINE>
    ```
    
    ### Why are the changes needed?
    For parity with vanilla (non-Connect) PySpark, where selecting `a.b` yields a column named `b`.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
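    Added a unit test (`test_with_field_column_name`) that compares the Connect and vanilla PySpark schemas and results, and re-enabled the `withField` doctest.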
    
    Closes #39699 from zhengruifeng/connect_fix_41772.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 19 ++++++++------
 python/pyspark/sql/connect/column.py               |  3 ---
 .../sql/tests/connect/test_connect_column.py       | 29 +++++++++++++++-------
 3 files changed, 31 insertions(+), 20 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f65fc2c8d0f..f95f065c5b3 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -691,9 +691,12 @@ class SparkConnectPlanner(val session: SparkSession) {
     } else {
       logical.OneRowRelation()
     }
-    val projection =
-      
rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
-    logical.Project(projectList = projection.toSeq, child = baseRel)
+
+    val projection = rel.getExpressionsList.asScala.toSeq
+      .map(transformExpression)
+      .map(toNamedExpression)
+
+    logical.Project(projectList = projection, child = baseRel)
   }
 
   private def transformUnresolvedExpression(exp: proto.Expression): 
UnresolvedAttribute = {
@@ -745,6 +748,11 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
+  private def toNamedExpression(expr: Expression): NamedExpression = expr 
match {
+    case named: NamedExpression => named
+    case expr => UnresolvedAlias(expr)
+  }
+
   private def transformExpressionPlugin(extension: ProtoAny): Expression = {
     SparkConnectPluginRegistry.expressionRegistry
       // Lazily traverse the collection.
@@ -1245,11 +1253,6 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
     val input = transformRelation(rel.getInput)
 
-    def toNamedExpression(expr: Expression): NamedExpression = expr match {
-      case named: NamedExpression => named
-      case expr => UnresolvedAlias(expr)
-    }
-
     val groupingExprs = 
rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
     val aggExprs = 
rel.getAggregateExpressionsList.asScala.toSeq.map(transformExpression)
     val aliasedAgg = (groupingExprs ++ aggExprs).map(toNamedExpression)
diff --git a/python/pyspark/sql/connect/column.py 
b/python/pyspark/sql/connect/column.py
index d2c334ae67f..44200e21495 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -439,9 +439,6 @@ def _test() -> None:
             .getOrCreate()
         )
 
-        # TODO(SPARK-41772): Enable 
pyspark.sql.connect.column.Column.withField doctest
-        del pyspark.sql.connect.column.Column.withField.__doc__
-
         (failure_count, test_count) = doctest.testmod(
             pyspark.sql.connect.column,
             globs=globs,
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py 
b/python/pyspark/sql/tests/connect/test_connect_column.py
index ffee64706d5..1e0609c480c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -33,6 +33,7 @@ from pyspark.sql.connect.types import (
 )
 
 from pyspark.sql.types import (
+    Row,
     StructField,
     StructType,
     ArrayType,
@@ -58,7 +59,8 @@ from pyspark.sql.connect.client import SparkConnectException
 
 if should_test_connect:
     import pandas as pd
-    from pyspark.sql.connect.functions import lit
+    from pyspark.sql import functions as SF
+    from pyspark.sql.connect import functions as CF
 
 
 class SparkConnectColumnTests(SparkConnectSQLTestCase):
@@ -83,7 +85,7 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
     def test_column_operator(self):
         # SPARK-41351: Column needs to support !=
         df = self.connect.range(10)
-        self.assertEqual(9, len(df.filter(df.id != lit(1)).collect()))
+        self.assertEqual(9, len(df.filter(df.id != CF.lit(1)).collect()))
 
     def test_columns(self):
         # SPARK-41036: test `columns` API for python client.
@@ -133,8 +135,6 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
 
     def test_column_with_null(self):
         # SPARK-41751: test isNull, isNotNull, eqNullSafe
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
 
         query = """
             SELECT * FROM VALUES
@@ -313,9 +313,6 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
     def test_none(self):
         # SPARK-41783: test none
 
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         query = """
             SELECT * FROM VALUES
             (1, 1, NULL), (2, NULL, 1), (NULL, 3, 4)
@@ -348,8 +345,10 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
 
     def test_simple_binary_expressions(self):
         """Test complex expression"""
-        df = self.connect.read.table(self.tbl_name)
-        pdf = df.select(df.id).where(df.id % lit(30) == 
lit(0)).sort(df.id.asc()).toPandas()
+        cdf = self.connect.read.table(self.tbl_name)
+        pdf = (
+            cdf.select(cdf.id).where(cdf.id % CF.lit(30) == 
CF.lit(0)).sort(cdf.id.asc()).toPandas()
+        )
         self.assertEqual(len(pdf.index), 4)
 
         res = pd.DataFrame(data={"id": [0, 30, 60, 90]})
@@ -964,6 +963,18 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
             ).toPandas(),
         )
 
+    def test_with_field_column_name(self):
+        data = [Row(a=Row(b=1, c=2))]
+
+        cdf = self.connect.createDataFrame(data)
+        cdf1 = cdf.withColumn("a", cdf["a"].withField("b", 
CF.lit(3))).select("a.b")
+
+        sdf = self.spark.createDataFrame(data)
+        sdf1 = sdf.withColumn("a", sdf["a"].withField("b", 
SF.lit(3))).select("a.b")
+
+        self.assertEqual(cdf1.schema, sdf1.schema)
+        self.assertEqual(cdf1.collect(), sdf1.collect())
+
 
 if __name__ == "__main__":
     import os


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to