This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 86f6dde3079 [SPARK-41767][CONNECT][PYTHON] Implement `Column.{withField, dropFields}`
86f6dde3079 is described below
commit 86f6dde30798e69c7a953ee59788a4a9831b37cd
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Dec 29 20:57:01 2022 +0900
[SPARK-41767][CONNECT][PYTHON] Implement `Column.{withField, dropFields}`
### What changes were proposed in this pull request?
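Implement `Column.withField` and `Column.dropFields` in the Spark Connect Python client, backed by a new `UpdateFields` expression in the Connect protocol.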
### Why are the changes needed?
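For API coverage: in Spark Connect, `Column.withField` and `Column.dropFields` previously raised `NotImplementedError`.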
### Does this PR introduce _any_ user-facing change?
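Yes. `Column.withField` and `Column.dropFields` now work with Spark Connect instead of raising `NotImplementedError`.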
### How was this patch tested?
Added unit tests (`test_column_field_ops` in `python/pyspark/sql/tests/connect/test_connect_column.py`).
Closes #39283 from zhengruifeng/connect_column_field.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/expressions.proto | 15 +++
.../sql/connect/planner/SparkConnectPlanner.scala | 17 +++
python/pyspark/sql/column.py | 6 +
python/pyspark/sql/connect/column.py | 37 +++++-
python/pyspark/sql/connect/expressions.py | 53 ++++++++
.../pyspark/sql/connect/proto/expressions_pb2.py | 93 +++++++------
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 53 ++++++++
.../sql/tests/connect/test_connect_column.py | 147 +++++++++++++++++++--
8 files changed, 366 insertions(+), 55 deletions(-)
diff --git
a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index b8ed9eb6f23..fa2836702c6 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -41,6 +41,7 @@ message Expression {
LambdaFunction lambda_function = 10;
Window window = 11;
UnresolvedExtractValue unresolved_extract_value = 12;
+ UpdateFields update_fields = 13;
}
@@ -241,6 +242,20 @@ message Expression {
Expression extraction = 2;
}
+ // Add, replace or drop a field of `StructType` expression by name.
+ message UpdateFields {
+ // (Required) The struct expression.
+ Expression struct_expression = 1;
+
+ // (Required) The field name.
+ string field_name = 2;
+
+ // (Optional) The expression to add or replace.
+ //
+ // When not set, it means this field will be dropped.
+ Expression value_expression = 3;
+ }
+
message Alias {
// (Required) The expression that alias will be added on.
Expression expr = 1;
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4bb90fc5bc0..d06787e6b14 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -596,6 +596,8 @@ class SparkConnectPlanner(session: SparkSession) {
transformUnresolvedRegex(exp.getUnresolvedRegex)
case proto.Expression.ExprTypeCase.UNRESOLVED_EXTRACT_VALUE =>
transformUnresolvedExtractValue(exp.getUnresolvedExtractValue)
+ case proto.Expression.ExprTypeCase.UPDATE_FIELDS =>
+ transformUpdateFields(exp.getUpdateFields)
case proto.Expression.ExprTypeCase.SORT_ORDER =>
transformSortOrder(exp.getSortOrder)
case proto.Expression.ExprTypeCase.LAMBDA_FUNCTION =>
transformLambdaFunction(exp.getLambdaFunction)
@@ -860,6 +862,21 @@ class SparkConnectPlanner(session: SparkSession) {
transformExpression(extract.getExtraction))
}
+ private def transformUpdateFields(update: proto.Expression.UpdateFields):
UpdateFields = {
+ if (update.hasValueExpression) {
+ // add or replace a field
+ UpdateFields.apply(
+ col = transformExpression(update.getStructExpression),
+ fieldName = update.getFieldName,
+ expr = transformExpression(update.getValueExpression))
+ } else {
+ // drop a field
+ UpdateFields.apply(
+ col = transformExpression(update.getStructExpression),
+ fieldName = update.getFieldName)
+ }
+ }
+
private def transformWindowExpression(window: proto.Expression.Window) = {
if (!window.hasWindowFunction) {
throw InvalidPlanInput(s"WindowFunction is required in WindowExpression")
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 5a0987b4cfe..cd7b6932c2f 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -522,6 +522,9 @@ class Column:
.. versionadded:: 3.1.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Parameters
----------
fieldName : str
@@ -569,6 +572,9 @@ class Column:
.. versionadded:: 3.1.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Parameters
----------
fieldNames : str
diff --git a/python/pyspark/sql/connect/column.py
b/python/pyspark/sql/connect/column.py
index 58d86a3d389..2667e795974 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -43,6 +43,8 @@ from pyspark.sql.connect.expressions import (
SortOrder,
CastExpression,
WindowExpression,
+ WithField,
+ DropField,
)
@@ -359,11 +361,38 @@ class Column:
getField.__doc__ = PySparkColumn.getField.__doc__
- def withField(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("withField() is not yet implemented.")
+ def withField(self, fieldName: str, col: "Column") -> "Column":
+ if not isinstance(fieldName, str):
+ raise TypeError(
+ f"fieldName should be a string, but got
{type(fieldName).__name__} {fieldName}"
+ )
+
+ if not isinstance(col, Column):
+ raise TypeError(f"col should be a Column, but got
{type(col).__name__} {col}")
+
+ return Column(WithField(self._expr, fieldName, col._expr))
+
+ withField.__doc__ = PySparkColumn.withField.__doc__
+
+ def dropFields(self, *fieldNames: str) -> "Column":
+ dropField: Optional[DropField] = None
+ for fieldName in fieldNames:
+ if not isinstance(fieldName, str):
+ raise TypeError(
+ f"fieldName should be a string, but got
{type(fieldName).__name__} {fieldName}"
+ )
+
+ if dropField is None:
+ dropField = DropField(self._expr, fieldName)
+ else:
+ dropField = DropField(dropField, fieldName)
+
+ if dropField is None:
+ raise ValueError("dropFields requires at least 1 field")
+
+ return Column(dropField)
- def dropFields(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("dropFields() is not yet implemented.")
+ dropFields.__doc__ = PySparkColumn.dropFields.__doc__
def __getattr__(self, item: Any) -> "Column":
if item.startswith("__"):
diff --git a/python/pyspark/sql/connect/expressions.py
b/python/pyspark/sql/connect/expressions.py
index fa0cfd52b1b..27397fc0c13 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -420,6 +420,59 @@ class UnresolvedFunction(Expression):
return f"{self._name}({', '.join([str(arg) for arg in
self._args])})"
+class WithField(Expression):
+ def __init__(
+ self,
+ structExpr: Expression,
+ fieldName: str,
+ valueExpr: Expression,
+ ) -> None:
+ super().__init__()
+
+ assert isinstance(structExpr, Expression)
+ self._structExpr = structExpr
+
+ assert isinstance(fieldName, str)
+ self._fieldName = fieldName
+
+ assert isinstance(valueExpr, Expression)
+ self._valueExpr = valueExpr
+
+ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+ expr = proto.Expression()
+
expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
+ expr.update_fields.field_name = self._fieldName
+
expr.update_fields.value_expression.CopyFrom(self._valueExpr.to_plan(session))
+ return expr
+
+ def __repr__(self) -> str:
+ return f"WithField({self._structExpr}, {self._fieldName},
{self._valueExpr})"
+
+
+class DropField(Expression):
+ def __init__(
+ self,
+ structExpr: Expression,
+ fieldName: str,
+ ) -> None:
+ super().__init__()
+
+ assert isinstance(structExpr, Expression)
+ self._structExpr = structExpr
+
+ assert isinstance(fieldName, str)
+ self._fieldName = fieldName
+
+ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+ expr = proto.Expression()
+
expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
+ expr.update_fields.field_name = self._fieldName
+ return expr
+
+ def __repr__(self) -> str:
+ return f"DropField({self._structExpr}, {self._fieldName})"
+
+
class UnresolvedExtractValue(Expression):
def __init__(
self,
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 849b10cf90e..01c24d1bcd9 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as
spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xa5\x1f\n\nExpression\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st
[...]
+
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xb2!\n\nExpression\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_strin
[...]
)
@@ -54,6 +54,7 @@ _EXPRESSION_EXPRESSIONSTRING =
_EXPRESSION.nested_types_by_name["ExpressionStrin
_EXPRESSION_UNRESOLVEDSTAR = _EXPRESSION.nested_types_by_name["UnresolvedStar"]
_EXPRESSION_UNRESOLVEDREGEX =
_EXPRESSION.nested_types_by_name["UnresolvedRegex"]
_EXPRESSION_UNRESOLVEDEXTRACTVALUE =
_EXPRESSION.nested_types_by_name["UnresolvedExtractValue"]
+_EXPRESSION_UPDATEFIELDS = _EXPRESSION.nested_types_by_name["UpdateFields"]
_EXPRESSION_ALIAS = _EXPRESSION.nested_types_by_name["Alias"]
_EXPRESSION_LAMBDAFUNCTION = _EXPRESSION.nested_types_by_name["LambdaFunction"]
_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE =
_EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[
@@ -191,6 +192,15 @@ Expression = _reflection.GeneratedProtocolMessageType(
#
@@protoc_insertion_point(class_scope:spark.connect.Expression.UnresolvedExtractValue)
},
),
+ "UpdateFields": _reflection.GeneratedProtocolMessageType(
+ "UpdateFields",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _EXPRESSION_UPDATEFIELDS,
+ "__module__": "spark.connect.expressions_pb2"
+ #
@@protoc_insertion_point(class_scope:spark.connect.Expression.UpdateFields)
+ },
+ ),
"Alias": _reflection.GeneratedProtocolMessageType(
"Alias",
(_message.Message,),
@@ -229,6 +239,7 @@ _sym_db.RegisterMessage(Expression.ExpressionString)
_sym_db.RegisterMessage(Expression.UnresolvedStar)
_sym_db.RegisterMessage(Expression.UnresolvedRegex)
_sym_db.RegisterMessage(Expression.UnresolvedExtractValue)
+_sym_db.RegisterMessage(Expression.UpdateFields)
_sym_db.RegisterMessage(Expression.Alias)
_sym_db.RegisterMessage(Expression.LambdaFunction)
@@ -237,43 +248,45 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options =
b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 78
- _EXPRESSION._serialized_end = 4083
- _EXPRESSION_WINDOW._serialized_start = 1053
- _EXPRESSION_WINDOW._serialized_end = 1836
- _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1343
- _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 1836
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1610
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 1755
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 1757
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 1836
- _EXPRESSION_SORTORDER._serialized_start = 1839
- _EXPRESSION_SORTORDER._serialized_end = 2264
- _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2069
- _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2177
- _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2179
- _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2264
- _EXPRESSION_CAST._serialized_start = 2267
- _EXPRESSION_CAST._serialized_end = 2412
- _EXPRESSION_LITERAL._serialized_start = 2415
- _EXPRESSION_LITERAL._serialized_end = 3291
- _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3058
- _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3175
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3177
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3275
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3293
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3363
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3366
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3570
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3572
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3622
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3624
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3664
- _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3666
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3710
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 3713
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 3845
- _EXPRESSION_ALIAS._serialized_start = 3847
- _EXPRESSION_ALIAS._serialized_end = 3967
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3969
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4070
+ _EXPRESSION._serialized_end = 4352
+ _EXPRESSION_WINDOW._serialized_start = 1132
+ _EXPRESSION_WINDOW._serialized_end = 1915
+ _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1422
+ _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 1915
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1689
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 1834
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 1836
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 1915
+ _EXPRESSION_SORTORDER._serialized_start = 1918
+ _EXPRESSION_SORTORDER._serialized_end = 2343
+ _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2148
+ _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2256
+ _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2258
+ _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2343
+ _EXPRESSION_CAST._serialized_start = 2346
+ _EXPRESSION_CAST._serialized_end = 2491
+ _EXPRESSION_LITERAL._serialized_start = 2494
+ _EXPRESSION_LITERAL._serialized_end = 3370
+ _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3137
+ _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3254
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3256
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3354
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3372
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3442
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3445
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3649
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3651
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3701
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3703
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3743
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3745
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3789
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 3792
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 3924
+ _EXPRESSION_UPDATEFIELDS._serialized_start = 3927
+ _EXPRESSION_UPDATEFIELDS._serialized_end = 4114
+ _EXPRESSION_ALIAS._serialized_start = 4116
+ _EXPRESSION_ALIAS._serialized_end = 4236
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4238
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4339
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 6a248a04767..5e5eab5b5d9 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -766,6 +766,50 @@ class Expression(google.protobuf.message.Message):
field_name: typing_extensions.Literal["child", b"child",
"extraction", b"extraction"],
) -> None: ...
+ class UpdateFields(google.protobuf.message.Message):
+ """Add, replace or drop a field of `StructType` expression by name."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ STRUCT_EXPRESSION_FIELD_NUMBER: builtins.int
+ FIELD_NAME_FIELD_NUMBER: builtins.int
+ VALUE_EXPRESSION_FIELD_NUMBER: builtins.int
+ @property
+ def struct_expression(self) -> global___Expression:
+ """(Required) The struct expression."""
+ field_name: builtins.str
+ """(Required) The field name."""
+ @property
+ def value_expression(self) -> global___Expression:
+ """(Optional) The expression to add or replace.
+
+ When not set, it means this field will be dropped.
+ """
+ def __init__(
+ self,
+ *,
+ struct_expression: global___Expression | None = ...,
+ field_name: builtins.str = ...,
+ value_expression: global___Expression | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "struct_expression", b"struct_expression", "value_expression",
b"value_expression"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "field_name",
+ b"field_name",
+ "struct_expression",
+ b"struct_expression",
+ "value_expression",
+ b"value_expression",
+ ],
+ ) -> None: ...
+
class Alias(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -853,6 +897,7 @@ class Expression(google.protobuf.message.Message):
LAMBDA_FUNCTION_FIELD_NUMBER: builtins.int
WINDOW_FIELD_NUMBER: builtins.int
UNRESOLVED_EXTRACT_VALUE_FIELD_NUMBER: builtins.int
+ UPDATE_FIELDS_FIELD_NUMBER: builtins.int
@property
def literal(self) -> global___Expression.Literal: ...
@property
@@ -877,6 +922,8 @@ class Expression(google.protobuf.message.Message):
def window(self) -> global___Expression.Window: ...
@property
def unresolved_extract_value(self) ->
global___Expression.UnresolvedExtractValue: ...
+ @property
+ def update_fields(self) -> global___Expression.UpdateFields: ...
def __init__(
self,
*,
@@ -892,6 +939,7 @@ class Expression(google.protobuf.message.Message):
lambda_function: global___Expression.LambdaFunction | None = ...,
window: global___Expression.Window | None = ...,
unresolved_extract_value: global___Expression.UnresolvedExtractValue |
None = ...,
+ update_fields: global___Expression.UpdateFields | None = ...,
) -> None: ...
def HasField(
self,
@@ -920,6 +968,8 @@ class Expression(google.protobuf.message.Message):
b"unresolved_regex",
"unresolved_star",
b"unresolved_star",
+ "update_fields",
+ b"update_fields",
"window",
b"window",
],
@@ -951,6 +1001,8 @@ class Expression(google.protobuf.message.Message):
b"unresolved_regex",
"unresolved_star",
b"unresolved_star",
+ "update_fields",
+ b"update_fields",
"window",
b"window",
],
@@ -970,6 +1022,7 @@ class Expression(google.protobuf.message.Message):
"lambda_function",
"window",
"unresolved_extract_value",
+ "update_fields",
] | None: ...
global___Expression = Expression
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py
b/python/pyspark/sql/tests/connect/test_connect_column.py
index 0be990ebbe1..9d18a1fe9b2 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -54,6 +54,7 @@ from pyspark.sql.types import (
BooleanType,
)
from pyspark.testing.connectutils import should_test_connect
+from pyspark.sql.connect.client import SparkConnectException
if should_test_connect:
import pandas as pd
@@ -61,6 +62,24 @@ if should_test_connect:
class SparkConnectTests(SparkConnectSQLTestCase):
+ def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
+ from pyspark.sql.dataframe import DataFrame as SDF
+ from pyspark.sql.connect.dataframe import DataFrame as CDF
+
+ assert isinstance(df1, (SDF, CDF))
+ if isinstance(df1, SDF):
+ str1 = df1._jdf.showString(n, truncate, False)
+ else:
+ str1 = df1._show_string(n, truncate, False)
+
+ assert isinstance(df2, (SDF, CDF))
+ if isinstance(df2, SDF):
+ str2 = df2._jdf.showString(n, truncate, False)
+ else:
+ str2 = df2._show_string(n, truncate, False)
+
+ self.assertEqual(str1, str2)
+
def test_column_operator(self):
# SPARK-41351: Column needs to support !=
df = self.connect.range(10)
@@ -184,6 +203,13 @@ class SparkConnectTests(SparkConnectSQLTestCase):
):
not (cdf.a > 2)
+ with self.assertRaisesRegex(
+ TypeError,
+ "Column is not iterable",
+ ):
+ for x in cdf.a:
+ pass
+
def test_datetime(self):
query = """
SELECT * FROM VALUES
@@ -743,19 +769,118 @@ class SparkConnectTests(SparkConnectSQLTestCase):
sdf.select(sdf.a ** sdf["b"], sdf.d**2, 2**sdf.c).toPandas(),
)
- def test_unsupported_functions(self):
- # SPARK-41225: Disable unsupported functions.
- c = self.connect.range(1).id
- for f in (
- "withField",
- "dropFields",
+ def test_column_field_ops(self):
+ # SPARK-41767: test withField, dropFields
+
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ query = """
+ SELECT STRUCT(a, b, c, d) AS x, e FROM VALUES
+ (float(1.0), double(1.0), '2022', 1, 0),
+ (float(2.0), double(2.0), '2018', NULL, 2),
+ (float(3.0), double(3.0), NULL, 3, NULL)
+ AS tab(a, b, c, d, e)
+ """
+
+ # +----------------------+----+
+ # | x| e|
+ # +----------------------+----+
+ # | {1.0, 1.0, 2022, 1}| 0|
+ # |{2.0, 2.0, 2018, null}| 2|
+ # | {3.0, 3.0, null, 3}|null|
+ # +----------------------+----+
+
+ cdf = self.connect.sql(query)
+ sdf = self.spark.sql(query)
+
+ # add field
+ self.compare_by_show(
+ cdf.select(cdf.x.withField("z", cdf.e)),
+ sdf.select(sdf.x.withField("z", sdf.e)),
+ truncate=100,
+ )
+ self.compare_by_show(
+ cdf.select(cdf.x.withField("z", CF.col("e"))),
+ sdf.select(sdf.x.withField("z", SF.col("e"))),
+ truncate=100,
+ )
+ self.compare_by_show(
+ cdf.select(cdf.x.withField("z", CF.lit("xyz"))),
+ sdf.select(sdf.x.withField("z", SF.lit("xyz"))),
+ truncate=100,
+ )
+
+ # replace field
+ self.compare_by_show(
+ cdf.select(cdf.x.withField("a", cdf.e)),
+ sdf.select(sdf.x.withField("a", sdf.e)),
+ truncate=100,
+ )
+ self.compare_by_show(
+ cdf.select(cdf.x.withField("a", CF.col("e"))),
+ sdf.select(sdf.x.withField("a", SF.col("e"))),
+ truncate=100,
+ )
+ self.compare_by_show(
+ cdf.select(cdf.x.withField("a", CF.lit("xyz"))),
+ sdf.select(sdf.x.withField("a", SF.lit("xyz"))),
+ truncate=100,
+ )
+
+ # drop field
+ self.compare_by_show(
+ cdf.select(cdf.x.dropFields("a")),
+ sdf.select(sdf.x.dropFields("a")),
+ truncate=100,
+ )
+ self.compare_by_show(
+ cdf.select(cdf.x.dropFields("z")),
+ sdf.select(sdf.x.dropFields("z")),
+ truncate=100,
+ )
+ self.compare_by_show(
+ cdf.select(cdf.x.dropFields("a", "b", "z")),
+ sdf.select(sdf.x.dropFields("a", "b", "z")),
+ truncate=100,
+ )
+
+ # check error
+ # invalid column: not a struct column
+ with self.assertRaises(SparkConnectException):
+ cdf.select(cdf.e.withField("a", CF.lit(1))).show()
+
+ # invalid column: not a struct column
+ with self.assertRaises(SparkConnectException):
+ cdf.select(cdf.e.dropFields("a")).show()
+
+ # cannot drop all fields in struct
+ with self.assertRaises(SparkConnectException):
+ cdf.select(cdf.x.dropFields("a", "b", "c", "d")).show()
+
+ with self.assertRaisesRegex(
+ TypeError,
+ "fieldName should be a string",
):
- with self.assertRaises(NotImplementedError):
- getattr(c, f)()
+ cdf.select(cdf.x.withField(CF.col("a"), cdf.e)).show()
- with self.assertRaises(TypeError):
- for x in c:
- pass
+ with self.assertRaisesRegex(
+ TypeError,
+ "col should be a Column",
+ ):
+ cdf.select(cdf.x.withField("a", 2)).show()
+
+ with self.assertRaisesRegex(
+ TypeError,
+ "fieldName should be a string",
+ ):
+ cdf.select(cdf.x.dropFields("a", 1, 2)).show()
+
+ with self.assertRaisesRegex(
+ ValueError,
+ "dropFields requires at least 1 field",
+ ):
+ cdf.select(cdf.x.dropFields()).show()
def test_column_string_ops(self):
# SPARK-41764: test string ops
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]