This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 86f6dde3079 [SPARK-41767][CONNECT][PYTHON] Implement `Column.{withField, dropFields}`
86f6dde3079 is described below

commit 86f6dde30798e69c7a953ee59788a4a9831b37cd
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Dec 29 20:57:01 2022 +0900

    [SPARK-41767][CONNECT][PYTHON] Implement `Column.{withField, dropFields}`

    ### What changes were proposed in this pull request?
    Implement `Column.{withField, dropFields}`

    ### Why are the changes needed?
    For API coverage

    ### Does this PR introduce _any_ user-facing change?
    yes

    ### How was this patch tested?
    added UT

    Closes #39283 from zhengruifeng/connect_column_field.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto  |  15 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  17 +++
 python/pyspark/sql/column.py                       |   6 +
 python/pyspark/sql/connect/column.py               |  37 +++++-
 python/pyspark/sql/connect/expressions.py          |  53 ++++++++
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  93 +++++++------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  53 ++++++++
 .../sql/tests/connect/test_connect_column.py       | 147 +++++++++++++++++++--
 8 files changed, 366 insertions(+), 55 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index b8ed9eb6f23..fa2836702c6 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -41,6 +41,7 @@ message Expression {
     LambdaFunction lambda_function = 10;
     Window window = 11;
     UnresolvedExtractValue unresolved_extract_value = 12;
+    UpdateFields update_fields = 13;
   }

@@ -241,6 +242,20 @@ message Expression {
     Expression extraction = 2;
   }

+  // Add, replace or drop a field of `StructType` expression by name.
+  message UpdateFields {
+    // (Required) The struct expression.
+    Expression struct_expression = 1;
+
+    // (Required) The field name.
+    string field_name = 2;
+
+    // (Optional) The expression to add or replace.
+    //
+    // When not set, it means this field will be dropped.
+    Expression value_expression = 3;
+  }
+
   message Alias {
     // (Required) The expression that alias will be added on.
     Expression expr = 1;
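As a hedged illustration of the wire format this new message defines (field names are taken from the .proto above; `unparsed_identifier` on UnresolvedAttribute is an assumption from the same file), a drop is encoded simply by leaving value_expression unset:

    from pyspark.sql.connect.proto import expressions_pb2 as pb2

    # Reference to the struct column; unparsed_identifier is assumed
    # to be the UnresolvedAttribute field name in this .proto.
    struct_col = pb2.Expression()
    struct_col.unresolved_attribute.unparsed_identifier = "x"

    # Add/replace a field: value_expression is set.
    add = pb2.Expression()
    add.update_fields.struct_expression.CopyFrom(struct_col)
    add.update_fields.field_name = "z"
    add.update_fields.value_expression.unresolved_attribute.unparsed_identifier = "e"

    # Drop a field: value_expression is left unset.
    drop = pb2.Expression()
    drop.update_fields.struct_expression.CopyFrom(struct_col)
    drop.update_fields.field_name = "a"
    print(drop.update_fields.HasField("value_expression"))  # False

The planner change below keys off exactly this presence check to pick between add/replace and drop.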
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4bb90fc5bc0..d06787e6b14 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -596,6 +596,8 @@ class SparkConnectPlanner(session: SparkSession) {
         transformUnresolvedRegex(exp.getUnresolvedRegex)
       case proto.Expression.ExprTypeCase.UNRESOLVED_EXTRACT_VALUE =>
         transformUnresolvedExtractValue(exp.getUnresolvedExtractValue)
+      case proto.Expression.ExprTypeCase.UPDATE_FIELDS =>
+        transformUpdateFields(exp.getUpdateFields)
       case proto.Expression.ExprTypeCase.SORT_ORDER => transformSortOrder(exp.getSortOrder)
       case proto.Expression.ExprTypeCase.LAMBDA_FUNCTION =>
         transformLambdaFunction(exp.getLambdaFunction)
@@ -860,6 +862,21 @@ class SparkConnectPlanner(session: SparkSession) {
       transformExpression(extract.getExtraction))
   }

+  private def transformUpdateFields(update: proto.Expression.UpdateFields): UpdateFields = {
+    if (update.hasValueExpression) {
+      // add or replace a field
+      UpdateFields.apply(
+        col = transformExpression(update.getStructExpression),
+        fieldName = update.getFieldName,
+        expr = transformExpression(update.getValueExpression))
+    } else {
+      // drop a field
+      UpdateFields.apply(
+        col = transformExpression(update.getStructExpression),
+        fieldName = update.getFieldName)
+    }
+  }
+
   private def transformWindowExpression(window: proto.Expression.Window) = {
     if (!window.hasWindowFunction) {
       throw InvalidPlanInput(s"WindowFunction is required in WindowExpression")
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 5a0987b4cfe..cd7b6932c2f 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -522,6 +522,9 @@ class Column:

         .. versionadded:: 3.1.0

+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         fieldName : str
@@ -569,6 +572,9 @@ class Column:

         .. versionadded:: 3.1.0

+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         fieldNames : str
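The docstring-only change above records that the existing API now works over Spark Connect; for reference, a minimal sketch of that behavior (unchanged, available since 3.1.0) against a classic session:

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([((1, 2),)], "x struct<a:int,b:int>")

    # Each call returns a new struct column; alias keeps the column name.
    df.select(df.x.withField("c", F.lit(3)).alias("x")).show()  # {1, 2, 3}
    df.select(df.x.dropFields("b").alias("x")).show()           # {1}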
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 58d86a3d389..2667e795974 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -43,6 +43,8 @@ from pyspark.sql.connect.expressions import (
     SortOrder,
     CastExpression,
     WindowExpression,
+    WithField,
+    DropField,
 )

@@ -359,11 +361,38 @@ class Column:

     getField.__doc__ = PySparkColumn.getField.__doc__

-    def withField(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("withField() is not yet implemented.")
+    def withField(self, fieldName: str, col: "Column") -> "Column":
+        if not isinstance(fieldName, str):
+            raise TypeError(
+                f"fieldName should be a string, but got {type(fieldName).__name__} {fieldName}"
+            )
+
+        if not isinstance(col, Column):
+            raise TypeError(f"col should be a Column, but got {type(col).__name__} {col}")
+
+        return Column(WithField(self._expr, fieldName, col._expr))
+
+    withField.__doc__ = PySparkColumn.withField.__doc__
+
+    def dropFields(self, *fieldNames: str) -> "Column":
+        dropField: Optional[DropField] = None
+        for fieldName in fieldNames:
+            if not isinstance(fieldName, str):
+                raise TypeError(
+                    f"fieldName should be a string, but got {type(fieldName).__name__} {fieldName}"
+                )
+
+            if dropField is None:
+                dropField = DropField(self._expr, fieldName)
+            else:
+                dropField = DropField(dropField, fieldName)
+
+        if dropField is None:
+            raise ValueError("dropFields requires at least 1 field")
+
+        return Column(dropField)

-    def dropFields(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("dropFields() is not yet implemented.")
+    dropFields.__doc__ = PySparkColumn.dropFields.__doc__

     def __getattr__(self, item: Any) -> "Column":
         if item.startswith("__"):
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index fa0cfd52b1b..27397fc0c13 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -420,6 +420,59 @@ class UnresolvedFunction(Expression):
         return f"{self._name}({', '.join([str(arg) for arg in self._args])})"


+class WithField(Expression):
+    def __init__(
+        self,
+        structExpr: Expression,
+        fieldName: str,
+        valueExpr: Expression,
+    ) -> None:
+        super().__init__()
+
+        assert isinstance(structExpr, Expression)
+        self._structExpr = structExpr
+
+        assert isinstance(fieldName, str)
+        self._fieldName = fieldName
+
+        assert isinstance(valueExpr, Expression)
+        self._valueExpr = valueExpr
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+        expr = proto.Expression()
+        expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
+        expr.update_fields.field_name = self._fieldName
+        expr.update_fields.value_expression.CopyFrom(self._valueExpr.to_plan(session))
+        return expr
+
+    def __repr__(self) -> str:
+        return f"WithField({self._structExpr}, {self._fieldName}, {self._valueExpr})"
+
+
+class DropField(Expression):
+    def __init__(
+        self,
+        structExpr: Expression,
+        fieldName: str,
+    ) -> None:
+        super().__init__()
+
+        assert isinstance(structExpr, Expression)
+        self._structExpr = structExpr
+
+        assert isinstance(fieldName, str)
+        self._fieldName = fieldName
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+        expr = proto.Expression()
+        expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
+        expr.update_fields.field_name = self._fieldName
+        return expr
+
+    def __repr__(self) -> str:
+        return f"DropField({self._structExpr}, {self._fieldName})"
+
+
 class UnresolvedExtractValue(Expression):
     def __init__(
         self,
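The dropFields loop above left-nests one DropField per name. A self-contained toy sketch of just that chaining (stand-in classes, not the Connect types):

    class DropField:
        # Toy stand-in for pyspark.sql.connect.expressions.DropField.
        def __init__(self, struct_expr, field_name):
            self.struct_expr = struct_expr
            self.field_name = field_name

        def __repr__(self):
            return f"DropField({self.struct_expr}, {self.field_name})"

    def drop_fields(struct_expr, *field_names):
        # Mirrors Column.dropFields: each name wraps the previous expression.
        expr = None
        for name in field_names:
            expr = DropField(expr if expr is not None else struct_expr, name)
        if expr is None:
            raise ValueError("dropFields requires at least 1 field")
        return expr

    print(drop_fields("x", "a", "b"))  # DropField(DropField(x, a), b)

Each nesting level becomes one UpdateFields message without a value_expression, which the server then folds into Catalyst UpdateFields nodes.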
+ return f"DropField({self._structExpr}, {self._fieldName})" + + class UnresolvedExtractValue(Expression): def __init__( self, diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 849b10cf90e..01c24d1bcd9 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xa5\x1f\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...] + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xb2!\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_strin [...] ) @@ -54,6 +54,7 @@ _EXPRESSION_EXPRESSIONSTRING = _EXPRESSION.nested_types_by_name["ExpressionStrin _EXPRESSION_UNRESOLVEDSTAR = _EXPRESSION.nested_types_by_name["UnresolvedStar"] _EXPRESSION_UNRESOLVEDREGEX = _EXPRESSION.nested_types_by_name["UnresolvedRegex"] _EXPRESSION_UNRESOLVEDEXTRACTVALUE = _EXPRESSION.nested_types_by_name["UnresolvedExtractValue"] +_EXPRESSION_UPDATEFIELDS = _EXPRESSION.nested_types_by_name["UpdateFields"] _EXPRESSION_ALIAS = _EXPRESSION.nested_types_by_name["Alias"] _EXPRESSION_LAMBDAFUNCTION = _EXPRESSION.nested_types_by_name["LambdaFunction"] _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[ @@ -191,6 +192,15 @@ Expression = _reflection.GeneratedProtocolMessageType( # @@protoc_insertion_point(class_scope:spark.connect.Expression.UnresolvedExtractValue) }, ), + "UpdateFields": _reflection.GeneratedProtocolMessageType( + "UpdateFields", + (_message.Message,), + { + "DESCRIPTOR": _EXPRESSION_UPDATEFIELDS, + "__module__": "spark.connect.expressions_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Expression.UpdateFields) + }, + ), "Alias": _reflection.GeneratedProtocolMessageType( "Alias", (_message.Message,), @@ -229,6 +239,7 @@ _sym_db.RegisterMessage(Expression.ExpressionString) _sym_db.RegisterMessage(Expression.UnresolvedStar) _sym_db.RegisterMessage(Expression.UnresolvedRegex) _sym_db.RegisterMessage(Expression.UnresolvedExtractValue) +_sym_db.RegisterMessage(Expression.UpdateFields) _sym_db.RegisterMessage(Expression.Alias) _sym_db.RegisterMessage(Expression.LambdaFunction) @@ -237,43 +248,45 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXPRESSION._serialized_start = 78 - _EXPRESSION._serialized_end = 4083 - _EXPRESSION_WINDOW._serialized_start = 1053 - _EXPRESSION_WINDOW._serialized_end = 1836 - _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1343 - 
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 1836
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1610
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 1755
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 1757
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 1836
-    _EXPRESSION_SORTORDER._serialized_start = 1839
-    _EXPRESSION_SORTORDER._serialized_end = 2264
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2069
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2177
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2179
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2264
-    _EXPRESSION_CAST._serialized_start = 2267
-    _EXPRESSION_CAST._serialized_end = 2412
-    _EXPRESSION_LITERAL._serialized_start = 2415
-    _EXPRESSION_LITERAL._serialized_end = 3291
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3058
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3175
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3177
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3275
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3293
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3363
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3366
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3570
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3572
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3622
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3624
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3664
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3666
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3710
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 3713
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 3845
-    _EXPRESSION_ALIAS._serialized_start = 3847
-    _EXPRESSION_ALIAS._serialized_end = 3967
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3969
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4070
+    _EXPRESSION._serialized_end = 4352
+    _EXPRESSION_WINDOW._serialized_start = 1132
+    _EXPRESSION_WINDOW._serialized_end = 1915
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1422
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 1915
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1689
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 1834
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 1836
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 1915
+    _EXPRESSION_SORTORDER._serialized_start = 1918
+    _EXPRESSION_SORTORDER._serialized_end = 2343
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2148
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2256
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2258
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2343
+    _EXPRESSION_CAST._serialized_start = 2346
+    _EXPRESSION_CAST._serialized_end = 2491
+    _EXPRESSION_LITERAL._serialized_start = 2494
+    _EXPRESSION_LITERAL._serialized_end = 3370
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3137
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3254
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3256
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3354
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3372
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3442
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3445
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3649
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3651
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3701
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3703
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3743
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3745
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3789
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 3792
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 3924
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 3927
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 4114
+    _EXPRESSION_ALIAS._serialized_start = 4116
+    _EXPRESSION_ALIAS._serialized_end = 4236
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4238
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4339
 # @@protoc_insertion_point(module_scope)
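The regenerated stubs below add update_fields to the expression oneof; a hedged sketch of dispatching on it client-side, assuming the oneof group is named expr_type as in expressions.proto:

    from pyspark.sql.connect.proto import expressions_pb2 as pb2

    expr = pb2.Expression()
    expr.update_fields.field_name = "a"

    # Mirrors the server-side match on ExprTypeCase in SparkConnectPlanner.
    if expr.WhichOneof("expr_type") == "update_fields":
        uf = expr.update_fields
        op = "add/replace" if uf.HasField("value_expression") else "drop"
        print(op, uf.field_name)  # drop a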
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 6a248a04767..5e5eab5b5d9 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -766,6 +766,50 @@ class Expression(google.protobuf.message.Message):
             field_name: typing_extensions.Literal["child", b"child", "extraction", b"extraction"],
         ) -> None: ...

+    class UpdateFields(google.protobuf.message.Message):
+        """Add, replace or drop a field of `StructType` expression by name."""
+
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        STRUCT_EXPRESSION_FIELD_NUMBER: builtins.int
+        FIELD_NAME_FIELD_NUMBER: builtins.int
+        VALUE_EXPRESSION_FIELD_NUMBER: builtins.int
+        @property
+        def struct_expression(self) -> global___Expression:
+            """(Required) The struct expression."""
+        field_name: builtins.str
+        """(Required) The field name."""
+        @property
+        def value_expression(self) -> global___Expression:
+            """(Optional) The expression to add or replace.
+
+            When not set, it means this field will be dropped.
+            """
+        def __init__(
+            self,
+            *,
+            struct_expression: global___Expression | None = ...,
+            field_name: builtins.str = ...,
+            value_expression: global___Expression | None = ...,
+        ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal[
+                "struct_expression", b"struct_expression", "value_expression", b"value_expression"
+            ],
+        ) -> builtins.bool: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "field_name",
+                b"field_name",
+                "struct_expression",
+                b"struct_expression",
+                "value_expression",
+                b"value_expression",
+            ],
+        ) -> None: ...
+
     class Alias(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -853,6 +897,7 @@ class Expression(google.protobuf.message.Message):
     LAMBDA_FUNCTION_FIELD_NUMBER: builtins.int
     WINDOW_FIELD_NUMBER: builtins.int
     UNRESOLVED_EXTRACT_VALUE_FIELD_NUMBER: builtins.int
+    UPDATE_FIELDS_FIELD_NUMBER: builtins.int
     @property
     def literal(self) -> global___Expression.Literal: ...
     @property
@@ -877,6 +922,8 @@ class Expression(google.protobuf.message.Message):
     def window(self) -> global___Expression.Window: ...
     @property
    def unresolved_extract_value(self) -> global___Expression.UnresolvedExtractValue: ...
+    @property
+    def update_fields(self) -> global___Expression.UpdateFields: ...
    def __init__(
        self,
        *,
@@ -892,6 +939,7 @@ class Expression(google.protobuf.message.Message):
         lambda_function: global___Expression.LambdaFunction | None = ...,
         window: global___Expression.Window | None = ...,
         unresolved_extract_value: global___Expression.UnresolvedExtractValue | None = ...,
+        update_fields: global___Expression.UpdateFields | None = ...,
     ) -> None: ...
     def HasField(
         self,
@@ -920,6 +968,8 @@ class Expression(google.protobuf.message.Message):
             b"unresolved_regex",
             "unresolved_star",
             b"unresolved_star",
+            "update_fields",
+            b"update_fields",
             "window",
             b"window",
         ],
@@ -951,6 +1001,8 @@ class Expression(google.protobuf.message.Message):
             b"unresolved_regex",
             "unresolved_star",
             b"unresolved_star",
+            "update_fields",
+            b"update_fields",
             "window",
             b"window",
         ],
@@ -970,6 +1022,7 @@ class Expression(google.protobuf.message.Message):
         "lambda_function",
         "window",
         "unresolved_extract_value",
+        "update_fields",
     ]
     | None: ...

 global___Expression = Expression
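The tests below exercise the full stack by comparing show() output between a classic and a Connect DataFrame. As a standalone sketch, the same calls over a running Connect server might look like this (the endpoint and the availability of builder.remote on a Spark 3.4+ client are assumptions):

    from pyspark.sql import SparkSession
    from pyspark.sql.connect import functions as CF

    # Hypothetical endpoint; point this at a real Spark Connect server.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    df = spark.sql("SELECT STRUCT(1 AS a, 2 AS b) AS x")
    df.select(df.x.withField("c", CF.lit(3))).show()
    df.select(df.x.dropFields("b")).show()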
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 0be990ebbe1..9d18a1fe9b2 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -54,6 +54,7 @@ from pyspark.sql.types import (
     BooleanType,
 )
 from pyspark.testing.connectutils import should_test_connect
+from pyspark.sql.connect.client import SparkConnectException

 if should_test_connect:
     import pandas as pd
@@ -61,6 +62,24 @@ if should_test_connect:


 class SparkConnectTests(SparkConnectSQLTestCase):
+    def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
+        from pyspark.sql.dataframe import DataFrame as SDF
+        from pyspark.sql.connect.dataframe import DataFrame as CDF
+
+        assert isinstance(df1, (SDF, CDF))
+        if isinstance(df1, SDF):
+            str1 = df1._jdf.showString(n, truncate, False)
+        else:
+            str1 = df1._show_string(n, truncate, False)
+
+        assert isinstance(df2, (SDF, CDF))
+        if isinstance(df2, SDF):
+            str2 = df2._jdf.showString(n, truncate, False)
+        else:
+            str2 = df2._show_string(n, truncate, False)
+
+        self.assertEqual(str1, str2)
+
     def test_column_operator(self):
         # SPARK-41351: Column needs to support !=
         df = self.connect.range(10)
@@ -184,6 +203,13 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         ):
             not (cdf.a > 2)

+        with self.assertRaisesRegex(
+            TypeError,
+            "Column is not iterable",
+        ):
+            for x in cdf.a:
+                pass
+
     def test_datetime(self):
         query = """
             SELECT * FROM VALUES
@@ -743,19 +769,118 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             sdf.select(sdf.a ** sdf["b"], sdf.d**2, 2**sdf.c).toPandas(),
         )

-    def test_unsupported_functions(self):
-        # SPARK-41225: Disable unsupported functions.
-        c = self.connect.range(1).id
-        for f in (
-            "withField",
-            "dropFields",
+    def test_column_field_ops(self):
+        # SPARK-41767: test withField, dropFields
+
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        query = """
+            SELECT STRUCT(a, b, c, d) AS x, e FROM VALUES
+            (float(1.0), double(1.0), '2022', 1, 0),
+            (float(2.0), double(2.0), '2018', NULL, 2),
+            (float(3.0), double(3.0), NULL, 3, NULL)
+            AS tab(a, b, c, d, e)
+            """
+
+        # +----------------------+----+
+        # |                     x|   e|
+        # +----------------------+----+
+        # |   {1.0, 1.0, 2022, 1}|   0|
+        # |{2.0, 2.0, 2018, null}|   2|
+        # |   {3.0, 3.0, null, 3}|null|
+        # +----------------------+----+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        # add field
+        self.compare_by_show(
+            cdf.select(cdf.x.withField("z", cdf.e)),
+            sdf.select(sdf.x.withField("z", sdf.e)),
+            truncate=100,
+        )
+        self.compare_by_show(
+            cdf.select(cdf.x.withField("z", CF.col("e"))),
+            sdf.select(sdf.x.withField("z", SF.col("e"))),
+            truncate=100,
+        )
+        self.compare_by_show(
+            cdf.select(cdf.x.withField("z", CF.lit("xyz"))),
+            sdf.select(sdf.x.withField("z", SF.lit("xyz"))),
+            truncate=100,
+        )
+
+        # replace field
+        self.compare_by_show(
+            cdf.select(cdf.x.withField("a", cdf.e)),
+            sdf.select(sdf.x.withField("a", sdf.e)),
+            truncate=100,
+        )
+        self.compare_by_show(
+            cdf.select(cdf.x.withField("a", CF.col("e"))),
+            sdf.select(sdf.x.withField("a", SF.col("e"))),
+            truncate=100,
+        )
+        self.compare_by_show(
+            cdf.select(cdf.x.withField("a", CF.lit("xyz"))),
+            sdf.select(sdf.x.withField("a", SF.lit("xyz"))),
+            truncate=100,
+        )
+
+        # drop field
+        self.compare_by_show(
+            cdf.select(cdf.x.dropFields("a")),
+            sdf.select(sdf.x.dropFields("a")),
+            truncate=100,
+        )
+        self.compare_by_show(
+            cdf.select(cdf.x.dropFields("z")),
+            sdf.select(sdf.x.dropFields("z")),
+            truncate=100,
+        )
+        self.compare_by_show(
+            cdf.select(cdf.x.dropFields("a", "b", "z")),
+            sdf.select(sdf.x.dropFields("a", "b", "z")),
+            truncate=100,
+        )
+
+        # check error
+        # invalid column: not a struct column
+        with self.assertRaises(SparkConnectException):
+            cdf.select(cdf.e.withField("a", CF.lit(1))).show()
+
+        # invalid column: not a struct column
+        with self.assertRaises(SparkConnectException):
+            cdf.select(cdf.e.dropFields("a")).show()
+
+        # cannot drop all fields in struct
+        with self.assertRaises(SparkConnectException):
+            cdf.select(cdf.x.dropFields("a", "b", "c", "d")).show()
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "fieldName should be a string",
         ):
-            with self.assertRaises(NotImplementedError):
-                getattr(c, f)()
+            cdf.select(cdf.x.withField(CF.col("a"), cdf.e)).show()

-        with self.assertRaises(TypeError):
-            for x in c:
-                pass
+        with self.assertRaisesRegex(
+            TypeError,
+            "col should be a Column",
+        ):
+            cdf.select(cdf.x.withField("a", 2)).show()
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "fieldName should be a string",
+        ):
+            cdf.select(cdf.x.dropFields("a", 1, 2)).show()
+
+        with self.assertRaisesRegex(
+            ValueError,
+            "dropFields requires at least 1 field",
+        ):
+            cdf.select(cdf.x.dropFields()).show()

     def test_column_string_ops(self):
         # SPARK-41764: test string ops

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org