This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 55fe6be9ba8 [SPARK-41703][CONNECT][PYTHON] Combine NullType and typed_null in Literal
55fe6be9ba8 is described below
commit 55fe6be9ba808d8b54004dbdb8540b1cda19984d
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Dec 26 16:12:10 2022 +0800
[SPARK-41703][CONNECT][PYTHON] Combine NullType and typed_null in Literal
### What changes were proposed in this pull request?
Combine the `NullType` null and `typed_null` fields in `Literal` into a single `null` field that carries a `DataType`.

### Why are the changes needed?
A typed null already expresses a null of `NullType`, so keeping both a boolean `null` field and a separate `typed_null` field is redundant; one `DataType`-valued `null` field covers both cases.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Updated unit tests.
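As a minimal sketch (assuming the `fun` alias used in the updated tests), a plain `None` literal now looks like this on the wire:

    # Hedged sketch: a None literal carries its DataType in the `null`
    # field of the literal_type oneof, instead of a bare boolean flag.
    from pyspark.sql.connect import functions as fun  # assumed alias, as in the tests

    null_lit_p = fun.lit(None).to_plan(None)
    assert null_lit_p.literal.HasField("null")  # a DataType message, not a bool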
Closes #39209 from zhengruifeng/connect_combine_nulls.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../main/protobuf/spark/connect/expressions.proto | 14 +-----
.../planner/LiteralValueProtoConverter.scala | 11 ++---
python/pyspark/sql/connect/dataframe.py | 6 +++
python/pyspark/sql/connect/expressions.py | 6 +--
python/pyspark/sql/connect/plan.py | 50 +++++++++++-----------
.../pyspark/sql/connect/proto/expressions_pb2.py | 42 +++++++++---------
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 32 ++------------
python/pyspark/sql/connect/types.py | 4 +-
.../sql/tests/connect/test_connect_basic.py | 2 +-
.../connect/test_connect_column_expressions.py | 2 +-
10 files changed, 68 insertions(+), 101 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 65cf9291a2f..90d880939a8 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -139,7 +139,7 @@ message Expression {
message Literal {
oneof literal_type {
- bool null = 1;
+ DataType null = 1;
bytes binary = 2;
bool boolean = 3;
@@ -163,20 +163,8 @@ message Expression {
CalendarInterval calendar_interval = 19;
int32 year_month_interval = 20;
int64 day_time_interval = 21;
-
- DataType typed_null = 22;
}
- // whether the literal type should be treated as a nullable type. Applies to
- // all members of union other than the Typed null (which should directly
- // declare nullability).
- bool nullable = 50;
-
- // optionally points to a type_variation_anchor defined in this plan.
- // Applies to all members of union other than the Typed null (which should
- // directly declare the type variation).
- uint32 type_variation_reference = 51;
-
message Decimal {
// the string representation.
string value = 1;
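To illustrate the unified field, a minimal sketch against the regenerated Python stubs (assumed environment): a typed null, e.g. a string-typed null, is set through the same `null` member of the `literal_type` oneof that `NullType` nulls use:

    # Hedged sketch: typed nulls and NullType nulls now share the single
    # DataType-valued `null` field of the literal_type oneof.
    import pyspark.sql.connect.proto as proto
    from pyspark.sql.connect.types import pyspark_types_to_proto_types
    from pyspark.sql.types import StringType

    lit = proto.Expression.Literal()
    lit.null.CopyFrom(pyspark_types_to_proto_types(StringType()))
    assert lit.WhichOneof("literal_type") == "null"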
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
index 82ffa4f5246..a29a640e7b8 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
@@ -33,7 +33,7 @@ object LiteralValueProtoConverter {
def toCatalystExpression(lit: proto.Expression.Literal): expressions.Literal = {
lit.getLiteralTypeCase match {
case proto.Expression.Literal.LiteralTypeCase.NULL =>
- expressions.Literal(null, NullType)
+ expressions.Literal(null, DataTypeProtoConverter.toCatalystType(lit.getNull))
case proto.Expression.Literal.LiteralTypeCase.BINARY =>
expressions.Literal(lit.getBinary.toByteArray, BinaryType)
@@ -96,9 +96,6 @@ object LiteralValueProtoConverter {
case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType())
- case proto.Expression.Literal.LiteralTypeCase.TYPED_NULL =>
- expressions.Literal(null, DataTypeProtoConverter.toCatalystType(lit.getTypedNull))
-
case _ =>
throw InvalidPlanInput(
s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
@@ -116,7 +113,11 @@ object LiteralValueProtoConverter {
def toConnectProtoValue(value: Any): proto.Expression.Literal = {
value match {
- case null => proto.Expression.Literal.newBuilder().setNull(true).build()
+ case null =>
+ proto.Expression.Literal
+ .newBuilder()
+ .setNull(DataTypeProtoConverter.toConnectProtoType(NullType))
+ .build()
case b: Boolean => proto.Expression.Literal.newBuilder().setBoolean(b).build()
case b: Byte => proto.Expression.Literal.newBuilder().setByte(b).build()
case s: Short => proto.Expression.Literal.newBuilder().setShort(s).build()
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 4c15fbf0f22..69ba9147ccc 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -481,6 +481,12 @@ class DataFrame:
melt = unpivot
def hint(self, name: str, *params: Any) -> "DataFrame":
+ for param in params:
+ if param is not None and not isinstance(param, (int, str)):
+ raise TypeError(
+ f"param should be a int or str, but got {type(param).__name__} {param}"
+ )
+
return DataFrame.withPlan(
plan.Hint(self._plan, name, list(params)),
session=self._session,
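A usage sketch (assuming an existing Spark Connect DataFrame `df`): the new check rejects unsupported hint parameter types eagerly, before any plan is built:

    # Hedged sketch of the eager hint() parameter validation added above.
    df.hint("REPARTITION", 3)        # ok: int parameter
    df.hint("REPARTITION", "id+1")   # ok: str parameter
    try:
        df.hint("REPARTITION", 1.1)  # float is neither int nor str
    except TypeError as e:
        print(e)  # param should be a int or str, but got float 1.1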
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 02d6047bd66..1f63d6b0a10 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -270,10 +270,8 @@ class LiteralExpression(Expression):
expr = proto.Expression()
- if isinstance(self._dataType, NullType):
- expr.literal.null = True
- elif self._value is None:
- expr.literal.typed_null.CopyFrom(pyspark_types_to_proto_types(self._dataType))
+ if self._value is None:
+ expr.literal.null.CopyFrom(pyspark_types_to_proto_types(self._dataType))
elif isinstance(self._dataType, BinaryType):
expr.literal.binary = bytes(self._value)
elif isinstance(self._dataType, BooleanType):
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 1e3c230a81a..fe2105d7d2c 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -23,7 +23,7 @@ from pyspark.sql.types import DataType
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import SortOrder, ColumnReference
+from pyspark.sql.connect.expressions import SortOrder, ColumnReference, LiteralExpression
from pyspark.sql.connect.types import pyspark_types_to_proto_types
if TYPE_CHECKING:
@@ -349,20 +349,15 @@ class Hint(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], name: str, params: List[Any]) -> None:
super().__init__(child)
+
+ assert isinstance(name, str)
+
self.name = name
- self.params = params
- def _convert_value(self, v: Any) -> proto.Expression.Literal:
- value = proto.Expression.Literal()
- if v is None:
- value.null = True
- elif isinstance(v, int):
- value.integer = v
- elif isinstance(v, str):
- value.string = v
- else:
- raise ValueError(f"Could not convert literal for type {type(v)}")
- return value
+ assert isinstance(params, list) and all(
+ p is None or isinstance(p, (int, str)) for p in params
+ )
+ self.params = params
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
@@ -370,7 +365,7 @@ class Hint(LogicalPlan):
plan.hint.input.CopyFrom(self._child.plan(session))
plan.hint.name = self.name
for v in self.params:
- plan.hint.parameters.append(self._convert_value(v))
+ plan.hint.parameters.append(LiteralExpression._from_value(v).to_plan(session).literal)
return plan
def print(self, indent: int = 0) -> str:
@@ -1285,17 +1280,12 @@ class NAReplace(LogicalPlan):
self.cols = cols
self.replacements = replacements
- def _convert_value(self, v: Any) -> proto.Expression.Literal:
- value = proto.Expression.Literal()
- if v is None:
- value.null = True
- elif isinstance(v, bool):
- value.boolean = v
- elif isinstance(v, (int, float)):
- value.double = float(v)
+ def _convert_int_to_float(self, v: Any) -> Any:
+ # a bool is also an int
+ if v is not None and not isinstance(v, bool) and isinstance(v, int):
+ return float(v)
else:
- value.string = v
- return value
+ return v
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
@@ -1306,8 +1296,16 @@ class NAReplace(LogicalPlan):
if len(self.replacements) > 0:
for old_value, new_value in self.replacements.items():
replacement = proto.NAReplace.Replacement()
- replacement.old_value.CopyFrom(self._convert_value(old_value))
- replacement.new_value.CopyFrom(self._convert_value(new_value))
+ replacement.old_value.CopyFrom(
+ LiteralExpression._from_value(self._convert_int_to_float(old_value))
+ .to_plan(session)
+ .literal
+ )
+ replacement.new_value.CopyFrom(
+ LiteralExpression._from_value(self._convert_int_to_float(new_value))
+ .to_plan(session)
+ .literal
+ )
plan.replace.replacements.append(replacement)
return plan
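A standalone sketch of the normalization above (the same logic, copied out of the class for illustration); the `bool` test must come before the `int` test because `bool` is a subclass of `int` in Python:

    # Hedged standalone copy of the _convert_int_to_float logic above.
    def convert_int_to_float(v):
        # bool is a subclass of int, so it must be excluded explicitly
        if v is not None and not isinstance(v, bool) and isinstance(v, int):
            return float(v)
        return v

    assert convert_int_to_float(1) == 1.0      # int widened to float
    assert convert_int_to_float(True) is True  # bool preserved
    assert convert_int_to_float(None) is None  # None passes through
    assert convert_int_to_float("x") == "x"    # str passes through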
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 9e5b887348c..5e4d25b8b94 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xa7\x1e\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xb0\x1d\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
)
@@ -226,7 +226,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 78
- _EXPRESSION._serialized_end = 3957
+ _EXPRESSION._serialized_end = 3838
_EXPRESSION_WINDOW._serialized_start = 943
_EXPRESSION_WINDOW._serialized_end = 1726
_EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1233
@@ -244,23 +244,23 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXPRESSION_CAST._serialized_start = 2157
_EXPRESSION_CAST._serialized_end = 2302
_EXPRESSION_LITERAL._serialized_start = 2305
- _EXPRESSION_LITERAL._serialized_end = 3300
- _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3067
- _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3184
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3186
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3284
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3302
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3372
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3375
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3579
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3581
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3631
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3633
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3673
- _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3675
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3719
- _EXPRESSION_ALIAS._serialized_start = 3721
- _EXPRESSION_ALIAS._serialized_end = 3841
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3843
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 3944
+ _EXPRESSION_LITERAL._serialized_end = 3181
+ _EXPRESSION_LITERAL_DECIMAL._serialized_start = 2948
+ _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3065
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3067
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3165
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3183
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3253
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3256
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3460
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3462
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3512
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3514
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3554
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3556
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3600
+ _EXPRESSION_ALIAS._serialized_start = 3602
+ _EXPRESSION_ALIAS._serialized_end = 3722
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3724
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 3825
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 3540a563726..26002e649d5 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -456,10 +456,8 @@ class Expression(google.protobuf.message.Message):
CALENDAR_INTERVAL_FIELD_NUMBER: builtins.int
YEAR_MONTH_INTERVAL_FIELD_NUMBER: builtins.int
DAY_TIME_INTERVAL_FIELD_NUMBER: builtins.int
- TYPED_NULL_FIELD_NUMBER: builtins.int
- NULLABLE_FIELD_NUMBER: builtins.int
- TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
- null: builtins.bool
+ @property
+ def null(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
binary: builtins.bytes
boolean: builtins.bool
byte: builtins.int
@@ -481,22 +479,10 @@ class Expression(google.protobuf.message.Message):
def calendar_interval(self) -> global___Expression.Literal.CalendarInterval: ...
year_month_interval: builtins.int
day_time_interval: builtins.int
- @property
- def typed_null(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
- nullable: builtins.bool
- """whether the literal type should be treated as a nullable type.
Applies to
- all members of union other than the Typed null (which should directly
- declare nullability).
- """
- type_variation_reference: builtins.int
- """optionally points to a type_variation_anchor defined in this plan.
- Applies to all members of union other than the Typed null (which should
- directly declare the type variation).
- """
def __init__(
self,
*,
- null: builtins.bool = ...,
+ null: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
binary: builtins.bytes = ...,
boolean: builtins.bool = ...,
byte: builtins.int = ...,
@@ -513,9 +499,6 @@ class Expression(google.protobuf.message.Message):
calendar_interval: global___Expression.Literal.CalendarInterval | None = ...,
year_month_interval: builtins.int = ...,
day_time_interval: builtins.int = ...,
- typed_null: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
- nullable: builtins.bool = ...,
- type_variation_reference: builtins.int = ...,
) -> None: ...
def HasField(
self,
@@ -554,8 +537,6 @@ class Expression(google.protobuf.message.Message):
b"timestamp",
"timestamp_ntz",
b"timestamp_ntz",
- "typed_null",
- b"typed_null",
"year_month_interval",
b"year_month_interval",
],
@@ -589,8 +570,6 @@ class Expression(google.protobuf.message.Message):
b"long",
"null",
b"null",
- "nullable",
- b"nullable",
"short",
b"short",
"string",
@@ -599,10 +578,6 @@ class Expression(google.protobuf.message.Message):
b"timestamp",
"timestamp_ntz",
b"timestamp_ntz",
- "type_variation_reference",
- b"type_variation_reference",
- "typed_null",
- b"typed_null",
"year_month_interval",
b"year_month_interval",
],
@@ -627,7 +602,6 @@ class Expression(google.protobuf.message.Message):
"calendar_interval",
"year_month_interval",
"day_time_interval",
- "typed_null",
] | None: ...
class UnresolvedAttribute(google.protobuf.message.Message):
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 4f64aa40597..e6c179de8cb 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -57,7 +57,9 @@ JVM_LONG_MAX: int = (1 << 63) - 1
def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
ret = pb2.DataType()
- if isinstance(data_type, StringType):
+ if isinstance(data_type, NullType):
+ ret.null.CopyFrom(pb2.DataType.NULL())
+ elif isinstance(data_type, StringType):
ret.string.CopyFrom(pb2.DataType.String())
elif isinstance(data_type, BooleanType):
ret.boolean.CopyFrom(pb2.DataType.Boolean())
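A small sketch of the new `NullType` branch in isolation; this mapping is what lets the literal's `null` field describe `NullType` nulls:

    # Hedged sketch: NullType now maps to the proto Null kind.
    from pyspark.sql.connect.types import pyspark_types_to_proto_types
    from pyspark.sql.types import NullType

    ret = pyspark_types_to_proto_types(NullType())
    assert ret.HasField("null")  # a pb2.DataType with the Null kind set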
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 4c9e29326e1..ab05efc5e5d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -932,7 +932,7 @@ class SparkConnectTests(SparkConnectSQLTestCase):
self.connect.read.table(self.tbl_name).hint("REPARTITION", "id+1").toPandas()
# Hint with unsupported parameter types
- with self.assertRaises(ValueError):
+ with self.assertRaises(TypeError):
self.connect.read.table(self.tbl_name).hint("REPARTITION", 1.1).toPandas()
# Hint with wrong combination
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index d1a9bdaddd2..55ecb859805 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -56,7 +56,7 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
def test_null_literal(self):
null_lit = fun.lit(None)
null_lit_p = null_lit.to_plan(None)
- self.assertEqual(null_lit_p.literal.null, True)
+ self.assertEqual(null_lit_p.literal.HasField("null"), True)
def test_binary_literal(self):
val = b"binary\0\0asas"