This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 57052f56db8 [SPARK-41731][CONNECT][PYTHON] Implement the column accessor
57052f56db8 is described below

commit 57052f56db85c87ead455e1172237fef840d9a68
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Dec 27 21:51:25 2022 +0900

    [SPARK-41731][CONNECT][PYTHON] Implement the column accessor
    
    ### What changes were proposed in this pull request?
    Implement the column accessors (see the usage sketch after this list):
    1. `getItem`
    2. `getField`
    3. `__getattr__`
    4. `__getitem__`
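    
    A minimal usage sketch of the four accessors (the session and column names here are illustrative, not part of this patch):
    
    ```python
    df = spark.sql("SELECT map('a', 'x') AS m, array(1, 2) AS arr, struct(1 AS f) AS s")
    
    df.m["a"]           # __getitem__ with a map key
    df.arr[0]           # __getitem__ with an array index
    df.s.f              # __getattr__ with a struct field name
    df.m.getItem("a")   # equivalent to df.m["a"]
    df.s.getField("f")  # equivalent to df.s["f"]
    ```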
    
    ### Why are the changes needed?
    For API coverage
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, these accessors become available in the Spark Connect Python client.
    
    ### How was this patch tested?
    added UT
    
    Closes #39241 from zhengruifeng/column_get_item.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto  | 12 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 11 ++-
 python/pyspark/sql/column.py                       |  6 ++
 python/pyspark/sql/connect/column.py               | 44 +++++++++--
 python/pyspark/sql/connect/expressions.py          | 24 ++++++
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 89 +++++++++++++---------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  | 41 ++++++++++
 .../sql/tests/connect/test_connect_column.py       | 76 ++++++++++++++++--
 8 files changed, 252 insertions(+), 51 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 90d880939a8..b8ed9eb6f23 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -40,6 +40,7 @@ message Expression {
     SortOrder sort_order = 9;
     LambdaFunction lambda_function = 10;
     Window window = 11;
+    UnresolvedExtractValue unresolved_extract_value = 12;
   }
 
 
@@ -229,6 +230,17 @@ message Expression {
     string col_name = 1;
   }
 
+  // Extracts a value or values from an Expression
+  message UnresolvedExtractValue {
+    // (Required) The expression to extract value from, can be
+    // Map, Array, Struct or array of Structs.
+    Expression child = 1;
+
+    // (Required) The expression to describe the extraction, can be
+    // key of Map, index of Array, field name of Struct.
+    Expression extraction = 2;
+  }
+
   message Alias {
     // (Required) The expression that alias will be added on.
     Expression expr = 1;
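
A quick shape-check of the new message through the generated Python bindings; a hedged sketch, since the `unparsed_identifier` and `literal.string` field names belong to sibling messages that are assumptions here, not defined in this hunk:

```python
import pyspark.sql.connect.proto as proto

expr = proto.Expression()
# child: the expression to extract from (an unresolved column "m"; field name assumed)
expr.unresolved_extract_value.child.unresolved_attribute.unparsed_identifier = "m"
# extraction: the key/index/field to pull out (a string-literal map key; field name assumed)
expr.unresolved_extract_value.extraction.literal.string = "a"
```
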
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 1645eb2c381..c1e96b9d991 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -27,7 +27,7 @@ import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
@@ -556,6 +556,8 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Expression.ExprTypeCase.CAST => transformCast(exp.getCast)
       case proto.Expression.ExprTypeCase.UNRESOLVED_REGEX =>
         transformUnresolvedRegex(exp.getUnresolvedRegex)
+      case proto.Expression.ExprTypeCase.UNRESOLVED_EXTRACT_VALUE =>
+        transformUnresolvedExtractValue(exp.getUnresolvedExtractValue)
      case proto.Expression.ExprTypeCase.SORT_ORDER => transformSortOrder(exp.getSortOrder)
       case proto.Expression.ExprTypeCase.LAMBDA_FUNCTION =>
         transformLambdaFunction(exp.getLambdaFunction)
@@ -813,6 +815,13 @@ class SparkConnectPlanner(session: SparkSession) {
     }
   }
 
+  private def transformUnresolvedExtractValue(
+      extract: proto.Expression.UnresolvedExtractValue): UnresolvedExtractValue = {
+    UnresolvedExtractValue(
+      transformExpression(extract.getChild),
+      transformExpression(extract.getExtraction))
+  }
+
   private def transformWindowExpression(window: proto.Expression.Window) = {
     if (!window.hasWindowFunction) {
       throw InvalidPlanInput(s"WindowFunction is required in WindowExpression")
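
With the planner case above, extraction expressions sent by the Python client are mapped to Catalyst's `UnresolvedExtractValue` and analyzed on the server. A hedged end-to-end sketch (assuming a Connect session `spark`):

```python
df = spark.sql("SELECT map('a', 'x') AS m, array(1, 2) AS arr")

# Both extractions now resolve server-side instead of raising NotImplementedError.
df.select(df.m["a"], df.arr[0]).toPandas()
#   m[a]  arr[0]
# 0    x       1
```
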
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index bb43b0af57c..96b4333e604 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -430,6 +430,9 @@ class Column:
 
         .. versionadded:: 1.3.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         key
@@ -469,6 +472,9 @@ class Column:
 
         .. versionadded:: 1.3.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         name
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 918d6cd2adc..c9bc434fec3 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -17,6 +17,7 @@
 
 import datetime
 import decimal
+import warnings
 
 from typing import (
     TYPE_CHECKING,
@@ -34,6 +35,7 @@ import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.expressions import (
     Expression,
     UnresolvedFunction,
+    UnresolvedExtractValue,
     SQLExpression,
     LiteralExpression,
     CaseWhen,
@@ -351,9 +353,6 @@ class Column:
 
     isin.__doc__ = PySparkColumn.isin.__doc__
 
-    def getItem(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("getItem() is not yet implemented.")
-
     def between(
         self,
         lowerBound: Union["Column", "LiteralType", "DateTimeLiteral", 
"DecimalLiteral"],
@@ -363,8 +362,29 @@ class Column:
 
     between.__doc__ = PySparkColumn.between.__doc__
 
-    def getField(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("getField() is not yet implemented.")
+    def getItem(self, key: Any) -> "Column":
+        if isinstance(key, Column):
+            warnings.warn(
+                "A column as 'key' in getItem is deprecated as of Spark 3.0, and will not "
+                "be supported in the future release. Use `column[key]` or `column.key` syntax "
+                "instead.",
+                FutureWarning,
+            )
+        return self[key]
+
+    getItem.__doc__ = PySparkColumn.getItem.__doc__
+
+    def getField(self, name: Any) -> "Column":
+        if isinstance(name, Column):
+            warnings.warn(
+                "A column as 'name' in getField is deprecated as of Spark 3.0, and will not "
+                "be supported in the future release. Use `column[name]` or `column.name` syntax "
+                "instead.",
+                FutureWarning,
+            )
+        return self[name]
+
+    getField.__doc__ = PySparkColumn.getField.__doc__
 
     def withField(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("withField() is not yet implemented.")
@@ -372,8 +392,18 @@ class Column:
     def dropFields(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("dropFields() is not yet implemented.")
 
-    def __getitem__(self, k: Any) -> None:
-        raise NotImplementedError("apply() - __getitem__ is not yet implemented.")
+    def __getattr__(self, item: Any) -> "Column":
+        if item.startswith("__"):
+            raise AttributeError(item)
+        return self[item]
+
+    def __getitem__(self, k: Any) -> "Column":
+        if isinstance(k, slice):
+            if k.step is not None:
+                raise ValueError("slice with step is not supported.")
+            return self.substr(k.start, k.stop)
+        else:
+            return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k)))
 
     def __iter__(self) -> None:
         raise TypeError("Column is not iterable")
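
The slice branch above delegates to `substr`, and `__getattr__` rejects dunder names so protocol lookups fall through to `AttributeError` instead of producing bogus extractions. A small behavioral sketch (session name assumed):

```python
df = spark.sql("SELECT 'hello' AS s, struct(1 AS f) AS x")

df.select(df.s[1:3])  # slice without step -> df.s.substr(1, 3)
df.select(df.x.f)     # __getattr__ -> UnresolvedExtractValue on field 'f'
# df.s[1:5:2]         # raises ValueError("slice with step is not supported.")
# df.x.__getstate__   # missing dunders raise AttributeError, not an extraction
```
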
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 1f63d6b0a10..fa0cfd52b1b 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -420,6 +420,30 @@ class UnresolvedFunction(Expression):
             return f"{self._name}({', '.join([str(arg) for arg in 
self._args])})"
 
 
+class UnresolvedExtractValue(Expression):
+    def __init__(
+        self,
+        child: Expression,
+        extraction: Expression,
+    ) -> None:
+        super().__init__()
+
+        assert isinstance(child, Expression)
+        self._child = child
+
+        assert isinstance(extraction, Expression)
+        self._extraction = extraction
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+        expr = proto.Expression()
+        expr.unresolved_extract_value.child.CopyFrom(self._child.to_plan(session))
+        expr.unresolved_extract_value.extraction.CopyFrom(self._extraction.to_plan(session))
+        return expr
+
+    def __repr__(self) -> str:
+        return f"UnresolvedExtractValue({str(self._child)}, 
{str(self._extraction)})"
+
+
 class UnresolvedRegex(Expression):
     def __init__(
         self,
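
For reference, the new client-side expression composes like any other `Expression`. A minimal sketch; it assumes `ColumnReference` is this module's named-column expression (only `UnresolvedExtractValue` and `LiteralExpression._from_value` appear in the diff above):

```python
from pyspark.sql.connect.expressions import (
    ColumnReference,  # assumed helper for a named column
    LiteralExpression,
    UnresolvedExtractValue,
)

expr = UnresolvedExtractValue(
    ColumnReference("m"),
    LiteralExpression._from_value("a"),
)
print(expr)  # "UnresolvedExtractValue(..., ...)" per the __repr__ above
```
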
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 5e4d25b8b94..849b10cf90e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xb0\x1d\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xa5\x1f\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
 )
 
 
@@ -53,6 +53,7 @@ _EXPRESSION_UNRESOLVEDFUNCTION = _EXPRESSION.nested_types_by_name["UnresolvedFun
 _EXPRESSION_EXPRESSIONSTRING = _EXPRESSION.nested_types_by_name["ExpressionString"]
 _EXPRESSION_UNRESOLVEDSTAR = _EXPRESSION.nested_types_by_name["UnresolvedStar"]
 _EXPRESSION_UNRESOLVEDREGEX = _EXPRESSION.nested_types_by_name["UnresolvedRegex"]
+_EXPRESSION_UNRESOLVEDEXTRACTVALUE = _EXPRESSION.nested_types_by_name["UnresolvedExtractValue"]
 _EXPRESSION_ALIAS = _EXPRESSION.nested_types_by_name["Alias"]
 _EXPRESSION_LAMBDAFUNCTION = _EXPRESSION.nested_types_by_name["LambdaFunction"]
 _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[
@@ -181,6 +182,15 @@ Expression = _reflection.GeneratedProtocolMessageType(
                # @@protoc_insertion_point(class_scope:spark.connect.Expression.UnresolvedRegex)
             },
         ),
+        "UnresolvedExtractValue": _reflection.GeneratedProtocolMessageType(
+            "UnresolvedExtractValue",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _EXPRESSION_UNRESOLVEDEXTRACTVALUE,
+                "__module__": "spark.connect.expressions_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.Expression.UnresolvedExtractValue)
+            },
+        ),
         "Alias": _reflection.GeneratedProtocolMessageType(
             "Alias",
             (_message.Message,),
@@ -218,6 +228,7 @@ _sym_db.RegisterMessage(Expression.UnresolvedFunction)
 _sym_db.RegisterMessage(Expression.ExpressionString)
 _sym_db.RegisterMessage(Expression.UnresolvedStar)
 _sym_db.RegisterMessage(Expression.UnresolvedRegex)
+_sym_db.RegisterMessage(Expression.UnresolvedExtractValue)
 _sym_db.RegisterMessage(Expression.Alias)
 _sym_db.RegisterMessage(Expression.LambdaFunction)
 
@@ -226,41 +237,43 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._options = None
    DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 78
-    _EXPRESSION._serialized_end = 3838
-    _EXPRESSION_WINDOW._serialized_start = 943
-    _EXPRESSION_WINDOW._serialized_end = 1726
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1233
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 1726
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1500
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 1645
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 1647
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 1726
-    _EXPRESSION_SORTORDER._serialized_start = 1729
-    _EXPRESSION_SORTORDER._serialized_end = 2154
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 1959
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2067
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2069
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2154
-    _EXPRESSION_CAST._serialized_start = 2157
-    _EXPRESSION_CAST._serialized_end = 2302
-    _EXPRESSION_LITERAL._serialized_start = 2305
-    _EXPRESSION_LITERAL._serialized_end = 3181
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 2948
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3065
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3067
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3165
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3183
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3253
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3256
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3460
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3462
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3512
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3514
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3554
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3556
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3600
-    _EXPRESSION_ALIAS._serialized_start = 3602
-    _EXPRESSION_ALIAS._serialized_end = 3722
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3724
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 3825
+    _EXPRESSION._serialized_end = 4083
+    _EXPRESSION_WINDOW._serialized_start = 1053
+    _EXPRESSION_WINDOW._serialized_end = 1836
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1343
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 1836
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1610
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 1755
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 1757
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 1836
+    _EXPRESSION_SORTORDER._serialized_start = 1839
+    _EXPRESSION_SORTORDER._serialized_end = 2264
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2069
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2177
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2179
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2264
+    _EXPRESSION_CAST._serialized_start = 2267
+    _EXPRESSION_CAST._serialized_end = 2412
+    _EXPRESSION_LITERAL._serialized_start = 2415
+    _EXPRESSION_LITERAL._serialized_end = 3291
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3058
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3175
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3177
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3275
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3293
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3363
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3366
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3570
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3572
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3622
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3624
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3664
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3666
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3710
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 3713
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 3845
+    _EXPRESSION_ALIAS._serialized_start = 3847
+    _EXPRESSION_ALIAS._serialized_end = 3967
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3969
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4070
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 26002e649d5..6a248a04767 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -734,6 +734,38 @@ class Expression(google.protobuf.message.Message):
             self, field_name: typing_extensions.Literal["col_name", 
b"col_name"]
         ) -> None: ...
 
+    class UnresolvedExtractValue(google.protobuf.message.Message):
+        """Extracts a value or values from an Expression"""
+
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        CHILD_FIELD_NUMBER: builtins.int
+        EXTRACTION_FIELD_NUMBER: builtins.int
+        @property
+        def child(self) -> global___Expression:
+            """(Required) The expression to extract value from, can be
+            Map, Array, Struct or array of Structs.
+            """
+        @property
+        def extraction(self) -> global___Expression:
+            """(Required) The expression to describe the extraction, can be
+            key of Map, index of Array, field name of Struct.
+            """
+        def __init__(
+            self,
+            *,
+            child: global___Expression | None = ...,
+            extraction: global___Expression | None = ...,
+        ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal["child", b"child", "extraction", b"extraction"],
+        ) -> builtins.bool: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal["child", b"child", "extraction", b"extraction"],
+        ) -> None: ...
+
     class Alias(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
@@ -820,6 +852,7 @@ class Expression(google.protobuf.message.Message):
     SORT_ORDER_FIELD_NUMBER: builtins.int
     LAMBDA_FUNCTION_FIELD_NUMBER: builtins.int
     WINDOW_FIELD_NUMBER: builtins.int
+    UNRESOLVED_EXTRACT_VALUE_FIELD_NUMBER: builtins.int
     @property
     def literal(self) -> global___Expression.Literal: ...
     @property
@@ -842,6 +875,8 @@ class Expression(google.protobuf.message.Message):
     def lambda_function(self) -> global___Expression.LambdaFunction: ...
     @property
     def window(self) -> global___Expression.Window: ...
+    @property
+    def unresolved_extract_value(self) -> global___Expression.UnresolvedExtractValue: ...
     def __init__(
         self,
         *,
@@ -856,6 +891,7 @@ class Expression(google.protobuf.message.Message):
         sort_order: global___Expression.SortOrder | None = ...,
         lambda_function: global___Expression.LambdaFunction | None = ...,
         window: global___Expression.Window | None = ...,
+        unresolved_extract_value: global___Expression.UnresolvedExtractValue | None = ...,
     ) -> None: ...
     def HasField(
         self,
@@ -876,6 +912,8 @@ class Expression(google.protobuf.message.Message):
             b"sort_order",
             "unresolved_attribute",
             b"unresolved_attribute",
+            "unresolved_extract_value",
+            b"unresolved_extract_value",
             "unresolved_function",
             b"unresolved_function",
             "unresolved_regex",
@@ -905,6 +943,8 @@ class Expression(google.protobuf.message.Message):
             b"sort_order",
             "unresolved_attribute",
             b"unresolved_attribute",
+            "unresolved_extract_value",
+            b"unresolved_extract_value",
             "unresolved_function",
             b"unresolved_function",
             "unresolved_regex",
@@ -929,6 +969,7 @@ class Expression(google.protobuf.message.Message):
         "sort_order",
         "lambda_function",
         "window",
+        "unresolved_extract_value",
     ] | None: ...
 
 global___Expression = Expression
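
The regenerated stubs also extend the `expr_type` oneof bookkeeping. A small sketch of inspecting the variant through standard generated-protobuf calls (the oneof name `expr_type` comes from the proto file):

```python
import pyspark.sql.connect.proto as proto

expr = proto.Expression()
expr.unresolved_extract_value.SetInParent()  # mark the variant as present
assert expr.WhichOneof("expr_type") == "unresolved_extract_value"
assert expr.HasField("unresolved_extract_value")
```
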
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index e0c883a7f76..9f5587ccce5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -538,21 +538,87 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             ).toPandas(),
         )
 
+    def test_column_accessor(self):
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        query = """
+            SELECT STRUCT(a, b, c) AS x, y, z, c FROM VALUES
+            (float(1.0), double(1.0), '2022', MAP('b', '123', 'a', 'kk'), ARRAY(1, 2, 3)),
+            (float(2.0), double(2.0), '2018', MAP('a', 'xy'), ARRAY(-1, -2, -3)),
+            (float(3.0), double(3.0), NULL, MAP('a', 'ab'), ARRAY(-1, 0, 1))
+            AS tab(a, b, c, y, z)
+            """
+
+        # +----------------+-------------------+------------+----+
+        # |               x|                  y|           z|   c|
+        # +----------------+-------------------+------------+----+
+        # |{1.0, 1.0, 2022}|{b -> 123, a -> kk}|   [1, 2, 3]|2022|
+        # |{2.0, 2.0, 2018}|          {a -> xy}|[-1, -2, -3]|2018|
+        # |{3.0, 3.0, null}|          {a -> ab}|  [-1, 0, 1]|null|
+        # +----------------+-------------------+------------+----+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        # test struct
+        self.assert_eq(
+            cdf.select(cdf.x.a, cdf.x["b"], cdf["x"].c).toPandas(),
+            sdf.select(sdf.x.a, sdf.x["b"], sdf["x"].c).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.col("x").a, cdf.x.b, CF.col("x")["c"]).toPandas(),
+            sdf.select(SF.col("x").a, sdf.x.b, SF.col("x")["c"]).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(cdf.x.getItem("a"), cdf.x.getItem("b"), 
cdf["x"].getField("c")).toPandas(),
+            sdf.select(sdf.x.getItem("a"), sdf.x.getItem("b"), 
sdf["x"].getField("c")).toPandas(),
+        )
+
+        # test map
+        self.assert_eq(
+            cdf.select(cdf.y.a, cdf.y["b"], cdf["y"].c).toPandas(),
+            sdf.select(sdf.y.a, sdf.y["b"], sdf["y"].c).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.col("y").a, cdf.y.b, CF.col("y")["c"]).toPandas(),
+            sdf.select(SF.col("y").a, sdf.y.b, SF.col("y")["c"]).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(cdf.y.getItem("a"), cdf.y.getItem("b"), 
cdf["y"].getField("c")).toPandas(),
+            sdf.select(sdf.y.getItem("a"), sdf.y.getItem("b"), 
sdf["y"].getField("c")).toPandas(),
+        )
+
+        # test array
+        self.assert_eq(
+            cdf.select(cdf.z[0], cdf.z[1], cdf["z"][2]).toPandas(),
+            sdf.select(sdf.z[0], sdf.z[1], sdf["z"][2]).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.col("z")[0], cdf.z[10], CF.col("z")[-10]).toPandas(),
+            sdf.select(SF.col("z")[0], sdf.z[10], SF.col("z")[-10]).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(cdf.z.getItem(0), cdf.z.getItem(1), cdf["z"].getField(2)).toPandas(),
+            sdf.select(sdf.z.getItem(0), sdf.z.getItem(1), sdf["z"].getField(2)).toPandas(),
+        )
+
+        # test string with slice
+        self.assert_eq(
+            cdf.select(cdf.c[0:1], cdf["c"][2:10]).toPandas(),
+            sdf.select(sdf.c[0:1], sdf["c"][2:10]).toPandas(),
+        )
+
     def test_unsupported_functions(self):
         # SPARK-41225: Disable unsupported functions.
         c = self.connect.range(1).id
         for f in (
-            "getItem",
-            "getField",
             "withField",
             "dropFields",
         ):
             with self.assertRaises(NotImplementedError):
                 getattr(c, f)()
 
-        with self.assertRaises(NotImplementedError):
-            c["a"]
-
         with self.assertRaises(TypeError):
             for x in c:
                 pass


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
