This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 57052f56db8  [SPARK-41731][CONNECT][PYTHON] Implement the column accessor
57052f56db8 is described below

commit 57052f56db85c87ead455e1172237fef840d9a68
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Dec 27 21:51:25 2022 +0900

    [SPARK-41731][CONNECT][PYTHON] Implement the column accessor

    ### What changes were proposed in this pull request?
    Implement the column accessor:
    1. `getItem`
    2. `getField`
    3. `__getattr__`
    4. `__getitem__`

    ### Why are the changes needed?
    For API coverage

    ### Does this PR introduce _any_ user-facing change?
    yes

    ### How was this patch tested?
    added UT

    Closes #39241 from zhengruifeng/column_get_item.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto | 12 +++
 .../sql/connect/planner/SparkConnectPlanner.scala | 11 ++-
 python/pyspark/sql/column.py                      |  6 ++
 python/pyspark/sql/connect/column.py              | 44 +++++++++--
 python/pyspark/sql/connect/expressions.py         | 24 ++++++
 .../pyspark/sql/connect/proto/expressions_pb2.py  | 89 +++++++++++++---------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi | 41 ++++++++++
 .../sql/tests/connect/test_connect_column.py      | 76 ++++++++++++++++--
 8 files changed, 252 insertions(+), 51 deletions(-)
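For context, a rough sketch of the user-facing behavior this commit adds (not part of the patch; it assumes `spark` is a Spark Connect session and the schema is illustrative, mirroring the shapes used in the new test further below):

```python
# Illustrative only: the four accessors implemented by this commit.
df = spark.sql(
    "SELECT STRUCT(1 AS f1, 'a' AS f2) AS s, MAP('k', 'v') AS m, ARRAY(10, 20) AS arr"
)

df.select(
    df.s.f1,              # __getattr__: struct field by attribute
    df["s"]["f2"],        # __getitem__: struct field by name
    df.m["k"],            # __getitem__: map value by key
    df.arr[0],            # __getitem__: array element by index
    df.s.getField("f1"),  # getField delegates to __getitem__
    df.arr.getItem(0),    # getItem delegates to __getitem__
).toPandas()
```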
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 90d880939a8..b8ed9eb6f23 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -40,6 +40,7 @@ message Expression {
     SortOrder sort_order = 9;
     LambdaFunction lambda_function = 10;
     Window window = 11;
+    UnresolvedExtractValue unresolved_extract_value = 12;
   }
@@ -229,6 +230,17 @@ message Expression {
     string col_name = 1;
   }

+  // Extracts a value or values from an Expression
+  message UnresolvedExtractValue {
+    // (Required) The expression to extract value from, can be
+    // Map, Array, Struct or array of Structs.
+    Expression child = 1;
+
+    // (Required) The expression to describe the extraction, can be
+    // key of Map, index of Array, field name of Struct.
+    Expression extraction = 2;
+  }
+
   message Alias {
     // (Required) The expression that alias will be added on.
     Expression expr = 1;
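To make the message shape concrete, a minimal sketch built with the regenerated Python bindings (not from the patch; the oneof name `expr_type` is inferred from the `ExprTypeCase` enum used in the planner below, and the two empty child messages stand in for a real attribute and literal):

```python
# Sketch: composing UnresolvedExtractValue by hand with the generated classes.
import pyspark.sql.connect.proto as proto

child = proto.Expression()       # would carry e.g. an unresolved attribute "x"
extraction = proto.Expression()  # would carry e.g. the string literal "a"

expr = proto.Expression()
expr.unresolved_extract_value.child.CopyFrom(child)
expr.unresolved_extract_value.extraction.CopyFrom(extraction)

# The Expression oneof (assumed to be named "expr_type") now selects the new case.
assert expr.WhichOneof("expr_type") == "unresolved_extract_value"
```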
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 1645eb2c381..c1e96b9d991 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -27,7 +27,7 @@ import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
@@ -556,6 +556,8 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Expression.ExprTypeCase.CAST => transformCast(exp.getCast)
       case proto.Expression.ExprTypeCase.UNRESOLVED_REGEX =>
         transformUnresolvedRegex(exp.getUnresolvedRegex)
+      case proto.Expression.ExprTypeCase.UNRESOLVED_EXTRACT_VALUE =>
+        transformUnresolvedExtractValue(exp.getUnresolvedExtractValue)
       case proto.Expression.ExprTypeCase.SORT_ORDER => transformSortOrder(exp.getSortOrder)
       case proto.Expression.ExprTypeCase.LAMBDA_FUNCTION =>
         transformLambdaFunction(exp.getLambdaFunction)
@@ -813,6 +815,13 @@ class SparkConnectPlanner(session: SparkSession) {
     }
   }

+  private def transformUnresolvedExtractValue(
+      extract: proto.Expression.UnresolvedExtractValue): UnresolvedExtractValue = {
+    UnresolvedExtractValue(
+      transformExpression(extract.getChild),
+      transformExpression(extract.getExtraction))
+  }
+
   private def transformWindowExpression(window: proto.Expression.Window) = {
     if (!window.hasWindowFunction) {
       throw InvalidPlanInput(s"WindowFunction is required in WindowExpression")
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index bb43b0af57c..96b4333e604 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -430,6 +430,9 @@ class Column:

         .. versionadded:: 1.3.0

+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         key
@@ -469,6 +472,9 @@ class Column:

         .. versionadded:: 1.3.0

+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         name
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 918d6cd2adc..c9bc434fec3 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -17,6 +17,7 @@

 import datetime
 import decimal
+import warnings

 from typing import (
     TYPE_CHECKING,
@@ -34,6 +35,7 @@ import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.expressions import (
     Expression,
     UnresolvedFunction,
+    UnresolvedExtractValue,
     SQLExpression,
     LiteralExpression,
     CaseWhen,
@@ -351,9 +353,6 @@ class Column:

     isin.__doc__ = PySparkColumn.isin.__doc__

-    def getItem(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("getItem() is not yet implemented.")
-
     def between(
         self,
         lowerBound: Union["Column", "LiteralType", "DateTimeLiteral", "DecimalLiteral"],
@@ -363,8 +362,29 @@ class Column:

     between.__doc__ = PySparkColumn.between.__doc__

-    def getField(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("getField() is not yet implemented.")
+    def getItem(self, key: Any) -> "Column":
+        if isinstance(key, Column):
+            warnings.warn(
+                "A column as 'key' in getItem is deprecated as of Spark 3.0, and will not "
+                "be supported in the future release. Use `column[key]` or `column.key` syntax "
+                "instead.",
+                FutureWarning,
+            )
+        return self[key]
+
+    getItem.__doc__ = PySparkColumn.getItem.__doc__
+
+    def getField(self, name: Any) -> "Column":
+        if isinstance(name, Column):
+            warnings.warn(
+                "A column as 'name' in getField is deprecated as of Spark 3.0, and will not "
+                "be supported in the future release. Use `column[name]` or `column.name` syntax "
+                "instead.",
+                FutureWarning,
+            )
+        return self[name]
+
+    getField.__doc__ = PySparkColumn.getField.__doc__

     def withField(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("withField() is not yet implemented.")
@@ -372,8 +392,18 @@ class Column:
     def dropFields(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("dropFields() is not yet implemented.")

-    def __getitem__(self, k: Any) -> None:
-        raise NotImplementedError("apply() - __getitem__ is not yet implemented.")
+    def __getattr__(self, item: Any) -> "Column":
+        if item.startswith("__"):
+            raise AttributeError(item)
+        return self[item]
+
+    def __getitem__(self, k: Any) -> "Column":
+        if isinstance(k, slice):
+            if k.step is not None:
+                raise ValueError("slice with step is not supported.")
+            return self.substr(k.start, k.stop)
+        else:
+            return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k)))

     def __iter__(self) -> None:
         raise TypeError("Column is not iterable")
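Note the two branches in `__getitem__` above: a slice becomes `substr`, everything else becomes an `UnresolvedExtractValue`. A hedged sketch of the observable behavior (assuming `df` is a Connect DataFrame with a string column `c` and an array column `arr`):

```python
# Sketch of the __getitem__ dispatch, not part of the patch.
df.select(df.c[0:1])   # slice: rewritten to df.c.substr(0, 1); the slice stop
                       # is passed through as substr's length argument, matching
                       # the classic PySpark Column.__getitem__ behavior
df.select(df.arr[0])   # non-slice: UnresolvedExtractValue(arr, literal 0)

try:
    df.c[0:10:2]       # slices with a step are rejected on the client
except ValueError as e:
    print(e)           # "slice with step is not supported."
```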
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 1f63d6b0a10..fa0cfd52b1b 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -420,6 +420,30 @@ class UnresolvedFunction(Expression):
         return f"{self._name}({', '.join([str(arg) for arg in self._args])})"


+class UnresolvedExtractValue(Expression):
+    def __init__(
+        self,
+        child: Expression,
+        extraction: Expression,
+    ) -> None:
+        super().__init__()
+
+        assert isinstance(child, Expression)
+        self._child = child
+
+        assert isinstance(extraction, Expression)
+        self._extraction = extraction
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+        expr = proto.Expression()
+        expr.unresolved_extract_value.child.CopyFrom(self._child.to_plan(session))
+        expr.unresolved_extract_value.extraction.CopyFrom(self._extraction.to_plan(session))
+        return expr
+
+    def __repr__(self) -> str:
+        return f"UnresolvedExtractValue({str(self._child)}, {str(self._extraction)})"
+
+
 class UnresolvedRegex(Expression):
     def __init__(
         self,
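A quick way to see the client-side tree this class produces (hedged sketch; `_expr` is an internal field shown only for illustration, and the exact repr text depends on the child expressions' own `__repr__`s):

```python
# Sketch: inspecting the unresolved expression built for a struct access.
col = cdf.x["a"]   # cdf: a Connect DataFrame with a struct column `x`
print(col._expr)   # roughly: UnresolvedExtractValue(x, a), per __repr__ above
```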
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 5e4d25b8b94..849b10cf90e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__


 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xb0\x1d\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xa5\x1f\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
 )


@@ -53,6 +53,7 @@ _EXPRESSION_UNRESOLVEDFUNCTION = _EXPRESSION.nested_types_by_name["UnresolvedFun
 _EXPRESSION_EXPRESSIONSTRING = _EXPRESSION.nested_types_by_name["ExpressionString"]
 _EXPRESSION_UNRESOLVEDSTAR = _EXPRESSION.nested_types_by_name["UnresolvedStar"]
 _EXPRESSION_UNRESOLVEDREGEX = _EXPRESSION.nested_types_by_name["UnresolvedRegex"]
+_EXPRESSION_UNRESOLVEDEXTRACTVALUE = _EXPRESSION.nested_types_by_name["UnresolvedExtractValue"]
 _EXPRESSION_ALIAS = _EXPRESSION.nested_types_by_name["Alias"]
 _EXPRESSION_LAMBDAFUNCTION = _EXPRESSION.nested_types_by_name["LambdaFunction"]
 _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[
     "FrameType"
 ]
@@ -181,6 +182,15 @@ Expression = _reflection.GeneratedProtocolMessageType(
                 # @@protoc_insertion_point(class_scope:spark.connect.Expression.UnresolvedRegex)
             },
         ),
+        "UnresolvedExtractValue": _reflection.GeneratedProtocolMessageType(
+            "UnresolvedExtractValue",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _EXPRESSION_UNRESOLVEDEXTRACTVALUE,
+                "__module__": "spark.connect.expressions_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.Expression.UnresolvedExtractValue)
+            },
+        ),
         "Alias": _reflection.GeneratedProtocolMessageType(
             "Alias",
             (_message.Message,),
@@ -218,6 +228,7 @@ _sym_db.RegisterMessage(Expression.UnresolvedFunction)
 _sym_db.RegisterMessage(Expression.ExpressionString)
 _sym_db.RegisterMessage(Expression.UnresolvedStar)
 _sym_db.RegisterMessage(Expression.UnresolvedRegex)
+_sym_db.RegisterMessage(Expression.UnresolvedExtractValue)
 _sym_db.RegisterMessage(Expression.Alias)
 _sym_db.RegisterMessage(Expression.LambdaFunction)

@@ -226,41 +237,43 @@ if _descriptor._USE_C_DESCRIPTORS == False:

     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 78
-    _EXPRESSION._serialized_end = 3838
-    _EXPRESSION_WINDOW._serialized_start = 943
-    _EXPRESSION_WINDOW._serialized_end = 1726
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1233
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 1726
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1500
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 1645
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 1647
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 1726
-    _EXPRESSION_SORTORDER._serialized_start = 1729
-    _EXPRESSION_SORTORDER._serialized_end = 2154
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 1959
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2067
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2069
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2154
-    _EXPRESSION_CAST._serialized_start = 2157
-    _EXPRESSION_CAST._serialized_end = 2302
-    _EXPRESSION_LITERAL._serialized_start = 2305
-    _EXPRESSION_LITERAL._serialized_end = 3181
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 2948
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3065
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3067
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3165
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3183
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3253
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3256
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3460
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3462
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3512
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3514
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3554
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3556
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3600
-    _EXPRESSION_ALIAS._serialized_start = 3602
-    _EXPRESSION_ALIAS._serialized_end = 3722
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3724
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 3825
+    _EXPRESSION._serialized_end = 4083
+    _EXPRESSION_WINDOW._serialized_start = 1053
+    _EXPRESSION_WINDOW._serialized_end = 1836
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1343
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 1836
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1610
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 1755
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 1757
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 1836
+    _EXPRESSION_SORTORDER._serialized_start = 1839
+    _EXPRESSION_SORTORDER._serialized_end = 2264
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2069
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2177
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2179
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2264
+    _EXPRESSION_CAST._serialized_start = 2267
+    _EXPRESSION_CAST._serialized_end = 2412
+    _EXPRESSION_LITERAL._serialized_start = 2415
+    _EXPRESSION_LITERAL._serialized_end = 3291
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3058
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3175
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3177
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3275
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3293
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3363
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3366
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3570
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3572
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3622
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3624
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3664
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3666
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3710
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 3713
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 3845
+    _EXPRESSION_ALIAS._serialized_start = 3847
+    _EXPRESSION_ALIAS._serialized_end = 3967
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3969
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4070
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 26002e649d5..6a248a04767 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -734,6 +734,38 @@ class Expression(google.protobuf.message.Message):
             self, field_name: typing_extensions.Literal["col_name", b"col_name"]
         ) -> None: ...

+    class UnresolvedExtractValue(google.protobuf.message.Message):
+        """Extracts a value or values from an Expression"""
+
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        CHILD_FIELD_NUMBER: builtins.int
+        EXTRACTION_FIELD_NUMBER: builtins.int
+        @property
+        def child(self) -> global___Expression:
+            """(Required) The expression to extract value from, can be
+            Map, Array, Struct or array of Structs.
+            """
+        @property
+        def extraction(self) -> global___Expression:
+            """(Required) The expression to describe the extraction, can be
+            key of Map, index of Array, field name of Struct.
+            """
+        def __init__(
+            self,
+            *,
+            child: global___Expression | None = ...,
+            extraction: global___Expression | None = ...,
+        ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal["child", b"child", "extraction", b"extraction"],
+        ) -> builtins.bool: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal["child", b"child", "extraction", b"extraction"],
+        ) -> None: ...
+
     class Alias(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -820,6 +852,7 @@ class Expression(google.protobuf.message.Message):
     SORT_ORDER_FIELD_NUMBER: builtins.int
     LAMBDA_FUNCTION_FIELD_NUMBER: builtins.int
     WINDOW_FIELD_NUMBER: builtins.int
+    UNRESOLVED_EXTRACT_VALUE_FIELD_NUMBER: builtins.int
     @property
     def literal(self) -> global___Expression.Literal: ...
     @property
@@ -842,6 +875,8 @@ class Expression(google.protobuf.message.Message):
     def lambda_function(self) -> global___Expression.LambdaFunction: ...
     @property
    def window(self) -> global___Expression.Window: ...
+    @property
+    def unresolved_extract_value(self) -> global___Expression.UnresolvedExtractValue: ...
     def __init__(
         self,
         *,
@@ -856,6 +891,7 @@ class Expression(google.protobuf.message.Message):
         sort_order: global___Expression.SortOrder | None = ...,
         lambda_function: global___Expression.LambdaFunction | None = ...,
         window: global___Expression.Window | None = ...,
+        unresolved_extract_value: global___Expression.UnresolvedExtractValue | None = ...,
     ) -> None: ...
    def HasField(
        self,
@@ -876,6 +912,8 @@ class Expression(google.protobuf.message.Message):
            b"sort_order",
            "unresolved_attribute",
            b"unresolved_attribute",
+            "unresolved_extract_value",
+            b"unresolved_extract_value",
            "unresolved_function",
            b"unresolved_function",
            "unresolved_regex",
@@ -905,6 +943,8 @@ class Expression(google.protobuf.message.Message):
            b"sort_order",
            "unresolved_attribute",
            b"unresolved_attribute",
+            "unresolved_extract_value",
+            b"unresolved_extract_value",
            "unresolved_function",
            b"unresolved_function",
            "unresolved_regex",
@@ -929,6 +969,7 @@ class Expression(google.protobuf.message.Message):
        "sort_order",
        "lambda_function",
        "window",
+        "unresolved_extract_value",
    ] | None: ...

 global___Expression = Expression
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index e0c883a7f76..9f5587ccce5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -538,21 +538,87 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             ).toPandas(),
         )

+    def test_column_accessor(self):
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        query = """
+            SELECT STRUCT(a, b, c) AS x, y, z, c FROM VALUES
+            (float(1.0), double(1.0), '2022', MAP('b', '123', 'a', 'kk'), ARRAY(1, 2, 3)),
+            (float(2.0), double(2.0), '2018', MAP('a', 'xy'), ARRAY(-1, -2, -3)),
+            (float(3.0), double(3.0), NULL, MAP('a', 'ab'), ARRAY(-1, 0, 1))
+            AS tab(a, b, c, y, z)
+            """
+
+        # +----------------+-------------------+------------+----+
+        # |               x|                  y|           z|   c|
+        # +----------------+-------------------+------------+----+
+        # |{1.0, 1.0, 2022}|{b -> 123, a -> kk}|   [1, 2, 3]|2022|
+        # |{2.0, 2.0, 2018}|          {a -> xy}|[-1, -2, -3]|2018|
+        # |{3.0, 3.0, null}|          {a -> ab}|  [-1, 0, 1]|null|
+        # +----------------+-------------------+------------+----+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        # test struct
+        self.assert_eq(
+            cdf.select(cdf.x.a, cdf.x["b"], cdf["x"].c).toPandas(),
+            sdf.select(sdf.x.a, sdf.x["b"], sdf["x"].c).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.col("x").a, cdf.x.b, CF.col("x")["c"]).toPandas(),
+            sdf.select(SF.col("x").a, sdf.x.b, SF.col("x")["c"]).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(cdf.x.getItem("a"), cdf.x.getItem("b"), cdf["x"].getField("c")).toPandas(),
+            sdf.select(sdf.x.getItem("a"), sdf.x.getItem("b"), sdf["x"].getField("c")).toPandas(),
+        )
+
+        # test map
+        self.assert_eq(
+            cdf.select(cdf.y.a, cdf.y["b"], cdf["y"].c).toPandas(),
+            sdf.select(sdf.y.a, sdf.y["b"], sdf["y"].c).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.col("y").a, cdf.y.b, CF.col("y")["c"]).toPandas(),
+            sdf.select(SF.col("y").a, sdf.y.b, SF.col("y")["c"]).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(cdf.y.getItem("a"), cdf.y.getItem("b"), cdf["y"].getField("c")).toPandas(),
+            sdf.select(sdf.y.getItem("a"), sdf.y.getItem("b"), sdf["y"].getField("c")).toPandas(),
+        )
+
+        # test array
+        self.assert_eq(
+            cdf.select(cdf.z[0], cdf.z[1], cdf["z"][2]).toPandas(),
+            sdf.select(sdf.z[0], sdf.z[1], sdf["z"][2]).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.col("z")[0], cdf.z[10], CF.col("z")[-10]).toPandas(),
+            sdf.select(SF.col("z")[0], sdf.z[10], SF.col("z")[-10]).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(cdf.z.getItem(0), cdf.z.getItem(1), cdf["z"].getField(2)).toPandas(),
+            sdf.select(sdf.z.getItem(0), sdf.z.getItem(1), sdf["z"].getField(2)).toPandas(),
+        )
+
+        # test string with slice
+        self.assert_eq(
+            cdf.select(cdf.c[0:1], cdf["c"][2:10]).toPandas(),
+            sdf.select(sdf.c[0:1], sdf["c"][2:10]).toPandas(),
+        )
+
     def test_unsupported_functions(self):
         # SPARK-41225: Disable unsupported functions.
         c = self.connect.range(1).id
         for f in (
-            "getItem",
-            "getField",
             "withField",
             "dropFields",
         ):
             with self.assertRaises(NotImplementedError):
                 getattr(c, f)()

-        with self.assertRaises(NotImplementedError):
-            c["a"]
-
         with self.assertRaises(TypeError):
             for x in c:
                 pass

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org