This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 86f6dde3079 [SPARK-41767][CONNECT][PYTHON] Implement `Column.{withField, dropFields}`
86f6dde3079 is described below

commit 86f6dde30798e69c7a953ee59788a4a9831b37cd
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Dec 29 20:57:01 2022 +0900

    [SPARK-41767][CONNECT][PYTHON] Implement `Column.{withField, dropFields}`
    
    ### What changes were proposed in this pull request?
    Implement `Column.{withField, dropFields}`
    
    ### Why are the changes needed?
    For API coverage
    
    ### Does this PR introduce _any_ user-facing change?
    yes
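    
    `Column.withField` and `Column.dropFields` now work on Spark Connect columns,
    mirroring the existing PySpark `Column` API. An illustrative sketch (assumes
    `spark` is a Spark Connect session; the query and column names are examples only):
    
    ```python
    from pyspark.sql.connect import functions as CF
    
    cdf = spark.sql("SELECT STRUCT(1 AS a, 2 AS b) AS x")
    cdf.select(cdf.x.withField("c", CF.lit(3))).show()  # add field c
    cdf.select(cdf.x.withField("a", CF.lit(0))).show()  # replace field a
    cdf.select(cdf.x.dropFields("b")).show()            # drop field b
    ```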
    
    ### How was this patch tested?
    added UT
    
    Closes #39283 from zhengruifeng/connect_column_field.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto  |  15 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  17 +++
 python/pyspark/sql/column.py                       |   6 +
 python/pyspark/sql/connect/column.py               |  37 +++++-
 python/pyspark/sql/connect/expressions.py          |  53 ++++++++
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  93 +++++++------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  53 ++++++++
 .../sql/tests/connect/test_connect_column.py       | 147 +++++++++++++++++++--
 8 files changed, 366 insertions(+), 55 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index b8ed9eb6f23..fa2836702c6 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -41,6 +41,7 @@ message Expression {
     LambdaFunction lambda_function = 10;
     Window window = 11;
     UnresolvedExtractValue unresolved_extract_value = 12;
+    UpdateFields update_fields = 13;
   }
 
 
@@ -241,6 +242,20 @@ message Expression {
     Expression extraction = 2;
   }
 
+  // Add, replace or drop a field of `StructType` expression by name.
+  message UpdateFields {
+    // (Required) The struct expression.
+    Expression struct_expression = 1;
+
+    // (Required) The field name.
+    string field_name = 2;
+
+    // (Optional) The expression to add or replace.
+    //
+    // When not set, it means this field will be dropped.
+    Expression value_expression = 3;
+  }
+
   message Alias {
     // (Required) The expression that alias will be added on.
     Expression expr = 1;
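
The presence or absence of `value_expression` is what separates add/replace from drop; the planner change below branches on exactly this. A minimal Python sketch of the two wire shapes, using the generated bindings from this change (the empty sub-messages are placeholders, not valid plans):

```python
from pyspark.sql.connect.proto import expressions_pb2 as pb2

# Drop: only struct_expression and field_name are set.
drop = pb2.Expression()
drop.update_fields.struct_expression.CopyFrom(pb2.Expression())  # placeholder
drop.update_fields.field_name = "a"
assert not drop.update_fields.HasField("value_expression")  # server drops "a"

# Add/replace: value_expression is set as well.
add = pb2.Expression()
add.update_fields.struct_expression.CopyFrom(pb2.Expression())  # placeholder
add.update_fields.field_name = "z"
add.update_fields.value_expression.CopyFrom(pb2.Expression())  # placeholder
assert add.update_fields.HasField("value_expression")  # server adds/replaces "z"
```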
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4bb90fc5bc0..d06787e6b14 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -596,6 +596,8 @@ class SparkConnectPlanner(session: SparkSession) {
         transformUnresolvedRegex(exp.getUnresolvedRegex)
       case proto.Expression.ExprTypeCase.UNRESOLVED_EXTRACT_VALUE =>
         transformUnresolvedExtractValue(exp.getUnresolvedExtractValue)
+      case proto.Expression.ExprTypeCase.UPDATE_FIELDS =>
+        transformUpdateFields(exp.getUpdateFields)
      case proto.Expression.ExprTypeCase.SORT_ORDER => transformSortOrder(exp.getSortOrder)
       case proto.Expression.ExprTypeCase.LAMBDA_FUNCTION =>
         transformLambdaFunction(exp.getLambdaFunction)
@@ -860,6 +862,21 @@ class SparkConnectPlanner(session: SparkSession) {
       transformExpression(extract.getExtraction))
   }
 
+  private def transformUpdateFields(update: proto.Expression.UpdateFields): UpdateFields = {
+    if (update.hasValueExpression) {
+      // add or replace a field
+      UpdateFields.apply(
+        col = transformExpression(update.getStructExpression),
+        fieldName = update.getFieldName,
+        expr = transformExpression(update.getValueExpression))
+    } else {
+      // drop a field
+      UpdateFields.apply(
+        col = transformExpression(update.getStructExpression),
+        fieldName = update.getFieldName)
+    }
+  }
+
   private def transformWindowExpression(window: proto.Expression.Window) = {
     if (!window.hasWindowFunction) {
       throw InvalidPlanInput(s"WindowFunction is required in WindowExpression")
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 5a0987b4cfe..cd7b6932c2f 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -522,6 +522,9 @@ class Column:
 
         .. versionadded:: 3.1.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         fieldName : str
@@ -569,6 +572,9 @@ class Column:
 
         .. versionadded:: 3.1.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         fieldNames : str
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 58d86a3d389..2667e795974 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -43,6 +43,8 @@ from pyspark.sql.connect.expressions import (
     SortOrder,
     CastExpression,
     WindowExpression,
+    WithField,
+    DropField,
 )
 
 
@@ -359,11 +361,38 @@ class Column:
 
     getField.__doc__ = PySparkColumn.getField.__doc__
 
-    def withField(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("withField() is not yet implemented.")
+    def withField(self, fieldName: str, col: "Column") -> "Column":
+        if not isinstance(fieldName, str):
+            raise TypeError(
+                f"fieldName should be a string, but got 
{type(fieldName).__name__} {fieldName}"
+            )
+
+        if not isinstance(col, Column):
+            raise TypeError(f"col should be a Column, but got 
{type(col).__name__} {col}")
+
+        return Column(WithField(self._expr, fieldName, col._expr))
+
+    withField.__doc__ = PySparkColumn.withField.__doc__
+
+    def dropFields(self, *fieldNames: str) -> "Column":
+        dropField: Optional[DropField] = None
+        for fieldName in fieldNames:
+            if not isinstance(fieldName, str):
+                raise TypeError(
+                    f"fieldName should be a string, but got 
{type(fieldName).__name__} {fieldName}"
+                )
+
+            if dropField is None:
+                dropField = DropField(self._expr, fieldName)
+            else:
+                dropField = DropField(dropField, fieldName)
+
+        if dropField is None:
+            raise ValueError("dropFields requires at least 1 field")
+
+        return Column(dropField)
 
-    def dropFields(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("dropFields() is not yet implemented.")
+    dropFields.__doc__ = PySparkColumn.dropFields.__doc__
 
     def __getattr__(self, item: Any) -> "Column":
         if item.startswith("__"):
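
For reference, `dropFields` above folds its varargs left-to-right into nested `DropField` expressions, each wrapping the previous one as its struct expression. A hedged sketch of the fold (`ColumnReference` and its constructor here are assumptions for illustration):

```python
from pyspark.sql.connect.expressions import ColumnReference, DropField

x = ColumnReference("x")  # assumed stand-in for the parent column's expression
# cdf.x.dropFields("a", "b", "z") builds:
folded = DropField(DropField(DropField(x, "a"), "b"), "z")
print(folded)  # roughly: DropField(DropField(DropField(<x>, a), b), z)
```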
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index fa0cfd52b1b..27397fc0c13 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -420,6 +420,59 @@ class UnresolvedFunction(Expression):
             return f"{self._name}({', '.join([str(arg) for arg in 
self._args])})"
 
 
+class WithField(Expression):
+    def __init__(
+        self,
+        structExpr: Expression,
+        fieldName: str,
+        valueExpr: Expression,
+    ) -> None:
+        super().__init__()
+
+        assert isinstance(structExpr, Expression)
+        self._structExpr = structExpr
+
+        assert isinstance(fieldName, str)
+        self._fieldName = fieldName
+
+        assert isinstance(valueExpr, Expression)
+        self._valueExpr = valueExpr
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+        expr = proto.Expression()
+        expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
+        expr.update_fields.field_name = self._fieldName
+        expr.update_fields.value_expression.CopyFrom(self._valueExpr.to_plan(session))
+        return expr
+
+    def __repr__(self) -> str:
+        return f"WithField({self._structExpr}, {self._fieldName}, 
{self._valueExpr})"
+
+
+class DropField(Expression):
+    def __init__(
+        self,
+        structExpr: Expression,
+        fieldName: str,
+    ) -> None:
+        super().__init__()
+
+        assert isinstance(structExpr, Expression)
+        self._structExpr = structExpr
+
+        assert isinstance(fieldName, str)
+        self._fieldName = fieldName
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+        expr = proto.Expression()
+        expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
+        expr.update_fields.field_name = self._fieldName
+        return expr
+
+    def __repr__(self) -> str:
+        return f"DropField({self._structExpr}, {self._fieldName})"
+
+
 class UnresolvedExtractValue(Expression):
     def __init__(
         self,
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 849b10cf90e..01c24d1bcd9 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xa5\x1f\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xb2!\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_strin [...]
 )
 
 
@@ -54,6 +54,7 @@ _EXPRESSION_EXPRESSIONSTRING = _EXPRESSION.nested_types_by_name["ExpressionStrin
 _EXPRESSION_UNRESOLVEDSTAR = _EXPRESSION.nested_types_by_name["UnresolvedStar"]
_EXPRESSION_UNRESOLVEDREGEX = _EXPRESSION.nested_types_by_name["UnresolvedRegex"]
_EXPRESSION_UNRESOLVEDEXTRACTVALUE = _EXPRESSION.nested_types_by_name["UnresolvedExtractValue"]
+_EXPRESSION_UPDATEFIELDS = _EXPRESSION.nested_types_by_name["UpdateFields"]
 _EXPRESSION_ALIAS = _EXPRESSION.nested_types_by_name["Alias"]
 _EXPRESSION_LAMBDAFUNCTION = _EXPRESSION.nested_types_by_name["LambdaFunction"]
_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[
@@ -191,6 +192,15 @@ Expression = _reflection.GeneratedProtocolMessageType(
                # @@protoc_insertion_point(class_scope:spark.connect.Expression.UnresolvedExtractValue)
             },
         ),
+        "UpdateFields": _reflection.GeneratedProtocolMessageType(
+            "UpdateFields",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _EXPRESSION_UPDATEFIELDS,
+                "__module__": "spark.connect.expressions_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.Expression.UpdateFields)
+            },
+        ),
         "Alias": _reflection.GeneratedProtocolMessageType(
             "Alias",
             (_message.Message,),
@@ -229,6 +239,7 @@ _sym_db.RegisterMessage(Expression.ExpressionString)
 _sym_db.RegisterMessage(Expression.UnresolvedStar)
 _sym_db.RegisterMessage(Expression.UnresolvedRegex)
 _sym_db.RegisterMessage(Expression.UnresolvedExtractValue)
+_sym_db.RegisterMessage(Expression.UpdateFields)
 _sym_db.RegisterMessage(Expression.Alias)
 _sym_db.RegisterMessage(Expression.LambdaFunction)
 
@@ -237,43 +248,45 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._options = None
    DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 78
-    _EXPRESSION._serialized_end = 4083
-    _EXPRESSION_WINDOW._serialized_start = 1053
-    _EXPRESSION_WINDOW._serialized_end = 1836
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1343
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 1836
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1610
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 1755
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 1757
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 1836
-    _EXPRESSION_SORTORDER._serialized_start = 1839
-    _EXPRESSION_SORTORDER._serialized_end = 2264
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2069
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2177
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2179
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2264
-    _EXPRESSION_CAST._serialized_start = 2267
-    _EXPRESSION_CAST._serialized_end = 2412
-    _EXPRESSION_LITERAL._serialized_start = 2415
-    _EXPRESSION_LITERAL._serialized_end = 3291
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3058
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3175
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3177
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3275
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3293
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3363
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3366
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3570
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3572
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3622
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3624
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3664
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3666
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3710
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 3713
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 3845
-    _EXPRESSION_ALIAS._serialized_start = 3847
-    _EXPRESSION_ALIAS._serialized_end = 3967
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3969
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4070
+    _EXPRESSION._serialized_end = 4352
+    _EXPRESSION_WINDOW._serialized_start = 1132
+    _EXPRESSION_WINDOW._serialized_end = 1915
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1422
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 1915
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1689
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 1834
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 1836
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 1915
+    _EXPRESSION_SORTORDER._serialized_start = 1918
+    _EXPRESSION_SORTORDER._serialized_end = 2343
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2148
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2256
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2258
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2343
+    _EXPRESSION_CAST._serialized_start = 2346
+    _EXPRESSION_CAST._serialized_end = 2491
+    _EXPRESSION_LITERAL._serialized_start = 2494
+    _EXPRESSION_LITERAL._serialized_end = 3370
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3137
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3254
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3256
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3354
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3372
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3442
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3445
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3649
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3651
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3701
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3703
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3743
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3745
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3789
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 3792
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 3924
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 3927
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 4114
+    _EXPRESSION_ALIAS._serialized_start = 4116
+    _EXPRESSION_ALIAS._serialized_end = 4236
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4238
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4339
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 6a248a04767..5e5eab5b5d9 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -766,6 +766,50 @@ class Expression(google.protobuf.message.Message):
             field_name: typing_extensions.Literal["child", b"child", 
"extraction", b"extraction"],
         ) -> None: ...
 
+    class UpdateFields(google.protobuf.message.Message):
+        """Add, replace or drop a field of `StructType` expression by name."""
+
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        STRUCT_EXPRESSION_FIELD_NUMBER: builtins.int
+        FIELD_NAME_FIELD_NUMBER: builtins.int
+        VALUE_EXPRESSION_FIELD_NUMBER: builtins.int
+        @property
+        def struct_expression(self) -> global___Expression:
+            """(Required) The struct expression."""
+        field_name: builtins.str
+        """(Required) The field name."""
+        @property
+        def value_expression(self) -> global___Expression:
+            """(Optional) The expression to add or replace.
+
+            When not set, it means this field will be dropped.
+            """
+        def __init__(
+            self,
+            *,
+            struct_expression: global___Expression | None = ...,
+            field_name: builtins.str = ...,
+            value_expression: global___Expression | None = ...,
+        ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal[
+                "struct_expression", b"struct_expression", "value_expression", 
b"value_expression"
+            ],
+        ) -> builtins.bool: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "field_name",
+                b"field_name",
+                "struct_expression",
+                b"struct_expression",
+                "value_expression",
+                b"value_expression",
+            ],
+        ) -> None: ...
+
     class Alias(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
@@ -853,6 +897,7 @@ class Expression(google.protobuf.message.Message):
     LAMBDA_FUNCTION_FIELD_NUMBER: builtins.int
     WINDOW_FIELD_NUMBER: builtins.int
     UNRESOLVED_EXTRACT_VALUE_FIELD_NUMBER: builtins.int
+    UPDATE_FIELDS_FIELD_NUMBER: builtins.int
     @property
     def literal(self) -> global___Expression.Literal: ...
     @property
@@ -877,6 +922,8 @@ class Expression(google.protobuf.message.Message):
     def window(self) -> global___Expression.Window: ...
     @property
    def unresolved_extract_value(self) -> global___Expression.UnresolvedExtractValue: ...
+    @property
+    def update_fields(self) -> global___Expression.UpdateFields: ...
     def __init__(
         self,
         *,
@@ -892,6 +939,7 @@ class Expression(google.protobuf.message.Message):
         lambda_function: global___Expression.LambdaFunction | None = ...,
         window: global___Expression.Window | None = ...,
        unresolved_extract_value: global___Expression.UnresolvedExtractValue | None = ...,
+        update_fields: global___Expression.UpdateFields | None = ...,
     ) -> None: ...
     def HasField(
         self,
@@ -920,6 +968,8 @@ class Expression(google.protobuf.message.Message):
             b"unresolved_regex",
             "unresolved_star",
             b"unresolved_star",
+            "update_fields",
+            b"update_fields",
             "window",
             b"window",
         ],
@@ -951,6 +1001,8 @@ class Expression(google.protobuf.message.Message):
             b"unresolved_regex",
             "unresolved_star",
             b"unresolved_star",
+            "update_fields",
+            b"update_fields",
             "window",
             b"window",
         ],
@@ -970,6 +1022,7 @@ class Expression(google.protobuf.message.Message):
         "lambda_function",
         "window",
         "unresolved_extract_value",
+        "update_fields",
     ] | None: ...
 
 global___Expression = Expression
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 0be990ebbe1..9d18a1fe9b2 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -54,6 +54,7 @@ from pyspark.sql.types import (
     BooleanType,
 )
 from pyspark.testing.connectutils import should_test_connect
+from pyspark.sql.connect.client import SparkConnectException
 
 if should_test_connect:
     import pandas as pd
@@ -61,6 +62,24 @@ if should_test_connect:
 
 
 class SparkConnectTests(SparkConnectSQLTestCase):
+    def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
+        from pyspark.sql.dataframe import DataFrame as SDF
+        from pyspark.sql.connect.dataframe import DataFrame as CDF
+
+        assert isinstance(df1, (SDF, CDF))
+        if isinstance(df1, SDF):
+            str1 = df1._jdf.showString(n, truncate, False)
+        else:
+            str1 = df1._show_string(n, truncate, False)
+
+        assert isinstance(df2, (SDF, CDF))
+        if isinstance(df2, SDF):
+            str2 = df2._jdf.showString(n, truncate, False)
+        else:
+            str2 = df2._show_string(n, truncate, False)
+
+        self.assertEqual(str1, str2)
+
     def test_column_operator(self):
         # SPARK-41351: Column needs to support !=
         df = self.connect.range(10)
@@ -184,6 +203,13 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         ):
             not (cdf.a > 2)
 
+        with self.assertRaisesRegex(
+            TypeError,
+            "Column is not iterable",
+        ):
+            for x in cdf.a:
+                pass
+
     def test_datetime(self):
         query = """
             SELECT * FROM VALUES
@@ -743,19 +769,118 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             sdf.select(sdf.a ** sdf["b"], sdf.d**2, 2**sdf.c).toPandas(),
         )
 
-    def test_unsupported_functions(self):
-        # SPARK-41225: Disable unsupported functions.
-        c = self.connect.range(1).id
-        for f in (
-            "withField",
-            "dropFields",
+    def test_column_field_ops(self):
+        # SPARK-41767: test withField, dropFields
+
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        query = """
+            SELECT STRUCT(a, b, c, d) AS x, e FROM VALUES
+            (float(1.0), double(1.0), '2022', 1, 0),
+            (float(2.0), double(2.0), '2018', NULL, 2),
+            (float(3.0), double(3.0), NULL, 3, NULL)
+            AS tab(a, b, c, d, e)
+            """
+
+        # +----------------------+----+
+        # |                     x|   e|
+        # +----------------------+----+
+        # |   {1.0, 1.0, 2022, 1}|   0|
+        # |{2.0, 2.0, 2018, null}|   2|
+        # |   {3.0, 3.0, null, 3}|null|
+        # +----------------------+----+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        # add field
+        self.compare_by_show(
+            cdf.select(cdf.x.withField("z", cdf.e)),
+            sdf.select(sdf.x.withField("z", sdf.e)),
+            truncate=100,
+        )
+        self.compare_by_show(
+            cdf.select(cdf.x.withField("z", CF.col("e"))),
+            sdf.select(sdf.x.withField("z", SF.col("e"))),
+            truncate=100,
+        )
+        self.compare_by_show(
+            cdf.select(cdf.x.withField("z", CF.lit("xyz"))),
+            sdf.select(sdf.x.withField("z", SF.lit("xyz"))),
+            truncate=100,
+        )
+
+        # replace field
+        self.compare_by_show(
+            cdf.select(cdf.x.withField("a", cdf.e)),
+            sdf.select(sdf.x.withField("a", sdf.e)),
+            truncate=100,
+        )
+        self.compare_by_show(
+            cdf.select(cdf.x.withField("a", CF.col("e"))),
+            sdf.select(sdf.x.withField("a", SF.col("e"))),
+            truncate=100,
+        )
+        self.compare_by_show(
+            cdf.select(cdf.x.withField("a", CF.lit("xyz"))),
+            sdf.select(sdf.x.withField("a", SF.lit("xyz"))),
+            truncate=100,
+        )
+
+        # drop field
+        self.compare_by_show(
+            cdf.select(cdf.x.dropFields("a")),
+            sdf.select(sdf.x.dropFields("a")),
+            truncate=100,
+        )
+        self.compare_by_show(
+            cdf.select(cdf.x.dropFields("z")),
+            sdf.select(sdf.x.dropFields("z")),
+            truncate=100,
+        )
+        self.compare_by_show(
+            cdf.select(cdf.x.dropFields("a", "b", "z")),
+            sdf.select(sdf.x.dropFields("a", "b", "z")),
+            truncate=100,
+        )
+
+        # check error
+        # invalid column: not a struct column
+        with self.assertRaises(SparkConnectException):
+            cdf.select(cdf.e.withField("a", CF.lit(1))).show()
+
+        # invalid column: not a struct column
+        with self.assertRaises(SparkConnectException):
+            cdf.select(cdf.e.dropFields("a")).show()
+
+        # cannot drop all fields in struct
+        with self.assertRaises(SparkConnectException):
+            cdf.select(cdf.x.dropFields("a", "b", "c", "d")).show()
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "fieldName should be a string",
         ):
-            with self.assertRaises(NotImplementedError):
-                getattr(c, f)()
+            cdf.select(cdf.x.withField(CF.col("a"), cdf.e)).show()
 
-        with self.assertRaises(TypeError):
-            for x in c:
-                pass
+        with self.assertRaisesRegex(
+            TypeError,
+            "col should be a Column",
+        ):
+            cdf.select(cdf.x.withField("a", 2)).show()
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "fieldName should be a string",
+        ):
+            cdf.select(cdf.x.dropFields("a", 1, 2)).show()
+
+        with self.assertRaisesRegex(
+            ValueError,
+            "dropFields requires at least 1 field",
+        ):
+            cdf.select(cdf.x.dropFields()).show()
 
     def test_column_string_ops(self):
         # SPARK-41764: test string ops

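
Since `withField` returns a new struct `Column`, the two methods compose; for example, a struct field can be renamed by adding it under a new name and then dropping the old one. A hedged usage sketch (assumes `spark` is a Spark Connect session):

```python
from pyspark.sql.connect import functions as CF

cdf = spark.sql("SELECT STRUCT(1 AS a, 2 AS b) AS x")
# Rename x.a to x.a2 by combining withField and dropFields.
cdf.select(cdf.x.withField("a2", CF.col("x.a")).dropFields("a")).show()
```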
