This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 57052f56db8 [SPARK-41731][CONNECT][PYTHON] Implement the column accessor
57052f56db8 is described below
commit 57052f56db85c87ead455e1172237fef840d9a68
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Dec 27 21:51:25 2022 +0900
[SPARK-41731][CONNECT][PYTHON] Implement the column accessor
### What changes were proposed in this pull request?
Implement the column accessors (a brief usage sketch follows this list):
1. `getItem`
2. `getField`
3. `__getattr__`
4. `__getitem__`
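For context, a minimal usage sketch of the new accessors, mirroring the tests added below; `df` is assumed to be a DataFrame with a struct column `x`, a map column `y`, an array column `z`, and a string column `c`:

    df.select(
        df.x.getField("a"),  # struct field by name
        df.y.getItem("a"),   # map value by key
        df.z[0],             # array element via __getitem__
        df.x.a,              # struct field via __getattr__
        df.c[0:2],           # string slice, rewritten to df.c.substr(0, 2)
    )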
### Why are the changes needed?
For API coverage: these accessors previously raised NotImplementedError on Spark Connect Columns.
### Does this PR introduce _any_ user-facing change?
Yes; `getItem`, `getField`, `__getattr__`, and `__getitem__` now work on Spark Connect Columns instead of raising NotImplementedError.
### How was this patch tested?
Added a unit test (`test_column_accessor`) and updated `test_unsupported_functions`.
Closes #39241 from zhengruifeng/column_get_item.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/expressions.proto | 12 +++
.../sql/connect/planner/SparkConnectPlanner.scala | 11 ++-
python/pyspark/sql/column.py | 6 ++
python/pyspark/sql/connect/column.py | 44 +++++++++--
python/pyspark/sql/connect/expressions.py | 24 ++++++
.../pyspark/sql/connect/proto/expressions_pb2.py | 89 +++++++++++++---------
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 41 ++++++++++
.../sql/tests/connect/test_connect_column.py | 76 ++++++++++++++++--
8 files changed, 252 insertions(+), 51 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 90d880939a8..b8ed9eb6f23 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -40,6 +40,7 @@ message Expression {
SortOrder sort_order = 9;
LambdaFunction lambda_function = 10;
Window window = 11;
+ UnresolvedExtractValue unresolved_extract_value = 12;
}
@@ -229,6 +230,17 @@ message Expression {
string col_name = 1;
}
+ // Extracts a value or values from an Expression
+ message UnresolvedExtractValue {
+ // (Required) The expression to extract value from, can be
+ // Map, Array, Struct or array of Structs.
+ Expression child = 1;
+
+ // (Required) The expression to describe the extraction, can be
+ // key of Map, index of Array, field name of Struct.
+ Expression extraction = 2;
+ }
+
message Alias {
// (Required) The expression that alias will be added on.
Expression expr = 1;
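As a rough illustration of how a client populates this new message, a Python sketch using the regenerated pb2 classes (the child and extraction submessages are left as empty Expression messages purely for brevity; in practice child is e.g. an UnresolvedAttribute and extraction a Literal):

    import pyspark.sql.connect.proto as proto

    expr = proto.Expression()
    # Populate the new oneof variant added by this commit.
    expr.unresolved_extract_value.child.CopyFrom(proto.Expression())
    expr.unresolved_extract_value.extraction.CopyFrom(proto.Expression())
    assert expr.WhichOneof("expr_type") == "unresolved_extract_value"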
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 1645eb2c381..c1e96b9d991 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -27,7 +27,7 @@ import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
@@ -556,6 +556,8 @@ class SparkConnectPlanner(session: SparkSession) {
case proto.Expression.ExprTypeCase.CAST => transformCast(exp.getCast)
case proto.Expression.ExprTypeCase.UNRESOLVED_REGEX =>
transformUnresolvedRegex(exp.getUnresolvedRegex)
+ case proto.Expression.ExprTypeCase.UNRESOLVED_EXTRACT_VALUE =>
+ transformUnresolvedExtractValue(exp.getUnresolvedExtractValue)
case proto.Expression.ExprTypeCase.SORT_ORDER =>
transformSortOrder(exp.getSortOrder)
case proto.Expression.ExprTypeCase.LAMBDA_FUNCTION =>
transformLambdaFunction(exp.getLambdaFunction)
@@ -813,6 +815,13 @@ class SparkConnectPlanner(session: SparkSession) {
}
}
+ private def transformUnresolvedExtractValue(
+ extract: proto.Expression.UnresolvedExtractValue): UnresolvedExtractValue = {
+ UnresolvedExtractValue(
+ transformExpression(extract.getChild),
+ transformExpression(extract.getExtraction))
+ }
+
private def transformWindowExpression(window: proto.Expression.Window) = {
if (!window.hasWindowFunction) {
throw InvalidPlanInput(s"WindowFunction is required in WindowExpression")
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index bb43b0af57c..96b4333e604 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -430,6 +430,9 @@ class Column:
.. versionadded:: 1.3.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Parameters
----------
key
@@ -469,6 +472,9 @@ class Column:
.. versionadded:: 1.3.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Parameters
----------
name
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 918d6cd2adc..c9bc434fec3 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -17,6 +17,7 @@
import datetime
import decimal
+import warnings
from typing import (
TYPE_CHECKING,
@@ -34,6 +35,7 @@ import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.expressions import (
Expression,
UnresolvedFunction,
+ UnresolvedExtractValue,
SQLExpression,
LiteralExpression,
CaseWhen,
@@ -351,9 +353,6 @@ class Column:
isin.__doc__ = PySparkColumn.isin.__doc__
- def getItem(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("getItem() is not yet implemented.")
-
def between(
self,
lowerBound: Union["Column", "LiteralType", "DateTimeLiteral",
"DecimalLiteral"],
@@ -363,8 +362,29 @@ class Column:
between.__doc__ = PySparkColumn.between.__doc__
- def getField(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("getField() is not yet implemented.")
+ def getItem(self, key: Any) -> "Column":
+ if isinstance(key, Column):
+ warnings.warn(
+ "A column as 'key' in getItem is deprecated as of Spark 3.0,
and will not "
+ "be supported in the future release. Use `column[key]` or
`column.key` syntax "
+ "instead.",
+ FutureWarning,
+ )
+ return self[key]
+
+ getItem.__doc__ = PySparkColumn.getItem.__doc__
+
+ def getField(self, name: Any) -> "Column":
+ if isinstance(name, Column):
+ warnings.warn(
+ "A column as 'name' in getField is deprecated as of Spark 3.0,
and will not "
+ "be supported in the future release. Use `column[name]` or
`column.name` syntax "
+ "instead.",
+ FutureWarning,
+ )
+ return self[name]
+
+ getField.__doc__ = PySparkColumn.getField.__doc__
def withField(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("withField() is not yet implemented.")
@@ -372,8 +392,18 @@ class Column:
def dropFields(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("dropFields() is not yet implemented.")
- def __getitem__(self, k: Any) -> None:
- raise NotImplementedError("apply() - __getitem__ is not yet
implemented.")
+ def __getattr__(self, item: Any) -> "Column":
+ if item.startswith("__"):
+ raise AttributeError(item)
+ return self[item]
+
+ def __getitem__(self, k: Any) -> "Column":
+ if isinstance(k, slice):
+ if k.step is not None:
+ raise ValueError("slice with step is not supported.")
+ return self.substr(k.start, k.stop)
+ else:
+ return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k)))
def __iter__(self) -> None:
raise TypeError("Column is not iterable")
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 1f63d6b0a10..fa0cfd52b1b 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -420,6 +420,30 @@ class UnresolvedFunction(Expression):
return f"{self._name}({', '.join([str(arg) for arg in
self._args])})"
+class UnresolvedExtractValue(Expression):
+ def __init__(
+ self,
+ child: Expression,
+ extraction: Expression,
+ ) -> None:
+ super().__init__()
+
+ assert isinstance(child, Expression)
+ self._child = child
+
+ assert isinstance(extraction, Expression)
+ self._extraction = extraction
+
+ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+ expr = proto.Expression()
+ expr.unresolved_extract_value.child.CopyFrom(self._child.to_plan(session))
+ expr.unresolved_extract_value.extraction.CopyFrom(self._extraction.to_plan(session))
+ return expr
+
+ def __repr__(self) -> str:
+ return f"UnresolvedExtractValue({str(self._child)},
{str(self._extraction)})"
+
+
class UnresolvedRegex(Expression):
def __init__(
self,
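A minimal sketch of building the new expression node directly; `ColumnReference` is an assumption here (a name-based column-reference expression presumed to live in this module; its exact constructor is not shown in this diff), while `LiteralExpression._from_value` does appear above:

    from pyspark.sql.connect.expressions import (
        ColumnReference,      # assumed helper; constructor signature is a guess
        LiteralExpression,
        UnresolvedExtractValue,
    )

    extract = UnresolvedExtractValue(
        ColumnReference("y"),                # child: the column to extract from
        LiteralExpression._from_value("a"),  # extraction: key/index/field name
    )
    print(extract)  # repr combining the child and extraction reprs
    # extract.to_plan(session) then fills the unresolved_extract_value oneof.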
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 5e4d25b8b94..849b10cf90e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xb0\x1d\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xa5\x1f\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
)
@@ -53,6 +53,7 @@ _EXPRESSION_UNRESOLVEDFUNCTION = _EXPRESSION.nested_types_by_name["UnresolvedFun
_EXPRESSION_EXPRESSIONSTRING = _EXPRESSION.nested_types_by_name["ExpressionString"]
_EXPRESSION_UNRESOLVEDSTAR = _EXPRESSION.nested_types_by_name["UnresolvedStar"]
_EXPRESSION_UNRESOLVEDREGEX = _EXPRESSION.nested_types_by_name["UnresolvedRegex"]
+_EXPRESSION_UNRESOLVEDEXTRACTVALUE = _EXPRESSION.nested_types_by_name["UnresolvedExtractValue"]
_EXPRESSION_ALIAS = _EXPRESSION.nested_types_by_name["Alias"]
_EXPRESSION_LAMBDAFUNCTION = _EXPRESSION.nested_types_by_name["LambdaFunction"]
_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[
@@ -181,6 +182,15 @@ Expression = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.Expression.UnresolvedRegex)
},
),
+ "UnresolvedExtractValue": _reflection.GeneratedProtocolMessageType(
+ "UnresolvedExtractValue",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _EXPRESSION_UNRESOLVEDEXTRACTVALUE,
+ "__module__": "spark.connect.expressions_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.Expression.UnresolvedExtractValue)
+ },
+ ),
"Alias": _reflection.GeneratedProtocolMessageType(
"Alias",
(_message.Message,),
@@ -218,6 +228,7 @@ _sym_db.RegisterMessage(Expression.UnresolvedFunction)
_sym_db.RegisterMessage(Expression.ExpressionString)
_sym_db.RegisterMessage(Expression.UnresolvedStar)
_sym_db.RegisterMessage(Expression.UnresolvedRegex)
+_sym_db.RegisterMessage(Expression.UnresolvedExtractValue)
_sym_db.RegisterMessage(Expression.Alias)
_sym_db.RegisterMessage(Expression.LambdaFunction)
@@ -226,41 +237,43 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 78
- _EXPRESSION._serialized_end = 3838
- _EXPRESSION_WINDOW._serialized_start = 943
- _EXPRESSION_WINDOW._serialized_end = 1726
- _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1233
- _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 1726
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1500
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 1645
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 1647
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 1726
- _EXPRESSION_SORTORDER._serialized_start = 1729
- _EXPRESSION_SORTORDER._serialized_end = 2154
- _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 1959
- _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2067
- _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2069
- _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2154
- _EXPRESSION_CAST._serialized_start = 2157
- _EXPRESSION_CAST._serialized_end = 2302
- _EXPRESSION_LITERAL._serialized_start = 2305
- _EXPRESSION_LITERAL._serialized_end = 3181
- _EXPRESSION_LITERAL_DECIMAL._serialized_start = 2948
- _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3065
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3067
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3165
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3183
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3253
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3256
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3460
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3462
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3512
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3514
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3554
- _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3556
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3600
- _EXPRESSION_ALIAS._serialized_start = 3602
- _EXPRESSION_ALIAS._serialized_end = 3722
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3724
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 3825
+ _EXPRESSION._serialized_end = 4083
+ _EXPRESSION_WINDOW._serialized_start = 1053
+ _EXPRESSION_WINDOW._serialized_end = 1836
+ _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1343
+ _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 1836
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1610
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 1755
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 1757
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 1836
+ _EXPRESSION_SORTORDER._serialized_start = 1839
+ _EXPRESSION_SORTORDER._serialized_end = 2264
+ _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2069
+ _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2177
+ _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2179
+ _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2264
+ _EXPRESSION_CAST._serialized_start = 2267
+ _EXPRESSION_CAST._serialized_end = 2412
+ _EXPRESSION_LITERAL._serialized_start = 2415
+ _EXPRESSION_LITERAL._serialized_end = 3291
+ _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3058
+ _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3175
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3177
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3275
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3293
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3363
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3366
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3570
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3572
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3622
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3624
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3664
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3666
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3710
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 3713
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 3845
+ _EXPRESSION_ALIAS._serialized_start = 3847
+ _EXPRESSION_ALIAS._serialized_end = 3967
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3969
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4070
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 26002e649d5..6a248a04767 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -734,6 +734,38 @@ class Expression(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["col_name", b"col_name"]
) -> None: ...
+ class UnresolvedExtractValue(google.protobuf.message.Message):
+ """Extracts a value or values from an Expression"""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ CHILD_FIELD_NUMBER: builtins.int
+ EXTRACTION_FIELD_NUMBER: builtins.int
+ @property
+ def child(self) -> global___Expression:
+ """(Required) The expression to extract value from, can be
+ Map, Array, Struct or array of Structs.
+ """
+ @property
+ def extraction(self) -> global___Expression:
+ """(Required) The expression to describe the extraction, can be
+ key of Map, index of Array, field name of Struct.
+ """
+ def __init__(
+ self,
+ *,
+ child: global___Expression | None = ...,
+ extraction: global___Expression | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal["child", b"child", "extraction", b"extraction"],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal["child", b"child", "extraction", b"extraction"],
+ ) -> None: ...
+
class Alias(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -820,6 +852,7 @@ class Expression(google.protobuf.message.Message):
SORT_ORDER_FIELD_NUMBER: builtins.int
LAMBDA_FUNCTION_FIELD_NUMBER: builtins.int
WINDOW_FIELD_NUMBER: builtins.int
+ UNRESOLVED_EXTRACT_VALUE_FIELD_NUMBER: builtins.int
@property
def literal(self) -> global___Expression.Literal: ...
@property
@@ -842,6 +875,8 @@ class Expression(google.protobuf.message.Message):
def lambda_function(self) -> global___Expression.LambdaFunction: ...
@property
def window(self) -> global___Expression.Window: ...
+ @property
+ def unresolved_extract_value(self) -> global___Expression.UnresolvedExtractValue: ...
def __init__(
self,
*,
@@ -856,6 +891,7 @@ class Expression(google.protobuf.message.Message):
sort_order: global___Expression.SortOrder | None = ...,
lambda_function: global___Expression.LambdaFunction | None = ...,
window: global___Expression.Window | None = ...,
+ unresolved_extract_value: global___Expression.UnresolvedExtractValue | None = ...,
) -> None: ...
def HasField(
self,
@@ -876,6 +912,8 @@ class Expression(google.protobuf.message.Message):
b"sort_order",
"unresolved_attribute",
b"unresolved_attribute",
+ "unresolved_extract_value",
+ b"unresolved_extract_value",
"unresolved_function",
b"unresolved_function",
"unresolved_regex",
@@ -905,6 +943,8 @@ class Expression(google.protobuf.message.Message):
b"sort_order",
"unresolved_attribute",
b"unresolved_attribute",
+ "unresolved_extract_value",
+ b"unresolved_extract_value",
"unresolved_function",
b"unresolved_function",
"unresolved_regex",
@@ -929,6 +969,7 @@ class Expression(google.protobuf.message.Message):
"sort_order",
"lambda_function",
"window",
+ "unresolved_extract_value",
] | None: ...
global___Expression = Expression
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index e0c883a7f76..9f5587ccce5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -538,21 +538,87 @@ class SparkConnectTests(SparkConnectSQLTestCase):
).toPandas(),
)
+ def test_column_accessor(self):
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ query = """
+ SELECT STRUCT(a, b, c) AS x, y, z, c FROM VALUES
+ (float(1.0), double(1.0), '2022', MAP('b', '123', 'a', 'kk'), ARRAY(1, 2, 3)),
+ (float(2.0), double(2.0), '2018', MAP('a', 'xy'), ARRAY(-1, -2, -3)),
+ (float(3.0), double(3.0), NULL, MAP('a', 'ab'), ARRAY(-1, 0, 1))
+ AS tab(a, b, c, y, z)
+ """
+
+ # +----------------+-------------------+------------+----+
+ # | x| y| z| c|
+ # +----------------+-------------------+------------+----+
+ # |{1.0, 1.0, 2022}|{b -> 123, a -> kk}| [1, 2, 3]|2022|
+ # |{2.0, 2.0, 2018}| {a -> xy}|[-1, -2, -3]|2018|
+ # |{3.0, 3.0, null}| {a -> ab}| [-1, 0, 1]|null|
+ # +----------------+-------------------+------------+----+
+
+ cdf = self.connect.sql(query)
+ sdf = self.spark.sql(query)
+
+ # test struct
+ self.assert_eq(
+ cdf.select(cdf.x.a, cdf.x["b"], cdf["x"].c).toPandas(),
+ sdf.select(sdf.x.a, sdf.x["b"], sdf["x"].c).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.col("x").a, cdf.x.b, CF.col("x")["c"]).toPandas(),
+ sdf.select(SF.col("x").a, sdf.x.b, SF.col("x")["c"]).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(cdf.x.getItem("a"), cdf.x.getItem("b"), cdf["x"].getField("c")).toPandas(),
+ sdf.select(sdf.x.getItem("a"), sdf.x.getItem("b"), sdf["x"].getField("c")).toPandas(),
+ )
+
+ # test map
+ self.assert_eq(
+ cdf.select(cdf.y.a, cdf.y["b"], cdf["y"].c).toPandas(),
+ sdf.select(sdf.y.a, sdf.y["b"], sdf["y"].c).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.col("y").a, cdf.y.b, CF.col("y")["c"]).toPandas(),
+ sdf.select(SF.col("y").a, sdf.y.b, SF.col("y")["c"]).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(cdf.y.getItem("a"), cdf.y.getItem("b"), cdf["y"].getField("c")).toPandas(),
+ sdf.select(sdf.y.getItem("a"), sdf.y.getItem("b"), sdf["y"].getField("c")).toPandas(),
+ )
+
+ # test array
+ self.assert_eq(
+ cdf.select(cdf.z[0], cdf.z[1], cdf["z"][2]).toPandas(),
+ sdf.select(sdf.z[0], sdf.z[1], sdf["z"][2]).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.col("z")[0], cdf.z[10], CF.col("z")[-10]).toPandas(),
+ sdf.select(SF.col("z")[0], sdf.z[10], SF.col("z")[-10]).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(cdf.z.getItem(0), cdf.z.getItem(1), cdf["z"].getField(2)).toPandas(),
+ sdf.select(sdf.z.getItem(0), sdf.z.getItem(1), sdf["z"].getField(2)).toPandas(),
+ )
+
+ # test string with slice
+ self.assert_eq(
+ cdf.select(cdf.c[0:1], cdf["c"][2:10]).toPandas(),
+ sdf.select(sdf.c[0:1], sdf["c"][2:10]).toPandas(),
+ )
+
def test_unsupported_functions(self):
# SPARK-41225: Disable unsupported functions.
c = self.connect.range(1).id
for f in (
- "getItem",
- "getField",
"withField",
"dropFields",
):
with self.assertRaises(NotImplementedError):
getattr(c, f)()
- with self.assertRaises(NotImplementedError):
- c["a"]
-
with self.assertRaises(TypeError):
for x in c:
pass
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]