This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 55fe6be9ba8 [SPARK-41703][CONNECT][PYTHON] Combine NullType and typed_null in Literal
55fe6be9ba8 is described below

commit 55fe6be9ba808d8b54004dbdb8540b1cda19984d
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Dec 26 16:12:10 2022 +0800

    [SPARK-41703][CONNECT][PYTHON] Combine NullType and typed_null in Literal
    
    ### What changes were proposed in this pull request?
    Combine NullType and typed_null in Literal
    
    ### Why are the changes needed?
    We can use typed_null to express a NullType-d null, so the two fields can be combined.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    Updated unit tests.
    
    Closes #39209 from zhengruifeng/connect_combine_nulls.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../main/protobuf/spark/connect/expressions.proto  | 14 +-----
 .../planner/LiteralValueProtoConverter.scala       | 11 ++---
 python/pyspark/sql/connect/dataframe.py            |  6 +++
 python/pyspark/sql/connect/expressions.py          |  6 +--
 python/pyspark/sql/connect/plan.py                 | 50 +++++++++++-----------
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 42 +++++++++---------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  | 32 ++------------
 python/pyspark/sql/connect/types.py                |  4 +-
 .../sql/tests/connect/test_connect_basic.py        |  2 +-
 .../connect/test_connect_column_expressions.py     |  2 +-
 10 files changed, 68 insertions(+), 101 deletions(-)
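
Under the new schema, a null literal carries its data type directly in the
"null" field of the Literal oneof, replacing the old pair of a bare bool null
and a separate typed_null field. A minimal sketch of what a client now builds
(assuming a PySpark build that includes this patch; the proto module path and
the re-export of DataType are taken from the diffs below):

    import pyspark.sql.connect.proto as proto

    # A NullType-d null: the DataType message itself is the payload.
    dt = proto.DataType()
    dt.null.CopyFrom(proto.DataType.NULL())

    lit = proto.Expression.Literal()
    lit.null.CopyFrom(dt)

    assert lit.HasField("null")
    assert lit.WhichOneof("literal_type") == "null"
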

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 65cf9291a2f..90d880939a8 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -139,7 +139,7 @@ message Expression {
 
   message Literal {
     oneof literal_type {
-      bool null = 1;
+      DataType null = 1;
       bytes binary = 2;
       bool boolean = 3;
 
@@ -163,20 +163,8 @@ message Expression {
       CalendarInterval calendar_interval = 19;
       int32 year_month_interval = 20;
       int64 day_time_interval = 21;
-
-      DataType typed_null = 22;
     }
 
-    // whether the literal type should be treated as a nullable type. Applies to
-    // all members of union other than the Typed null (which should directly
-    // declare nullability).
-    bool nullable = 50;
-
-    // optionally points to a type_variation_anchor defined in this plan.
-    // Applies to all members of union other than the Typed null (which should
-    // directly declare the type variation).
-    uint32 type_variation_reference = 51;
-
     message Decimal {
       // the string representation.
       string value = 1;
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
index 82ffa4f5246..a29a640e7b8 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
@@ -33,7 +33,7 @@ object LiteralValueProtoConverter {
  def toCatalystExpression(lit: proto.Expression.Literal): expressions.Literal = {
     lit.getLiteralTypeCase match {
       case proto.Expression.Literal.LiteralTypeCase.NULL =>
-        expressions.Literal(null, NullType)
+        expressions.Literal(null, DataTypeProtoConverter.toCatalystType(lit.getNull))
 
       case proto.Expression.Literal.LiteralTypeCase.BINARY =>
         expressions.Literal(lit.getBinary.toByteArray, BinaryType)
@@ -96,9 +96,6 @@ object LiteralValueProtoConverter {
       case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
         expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType())
 
-      case proto.Expression.Literal.LiteralTypeCase.TYPED_NULL =>
-        expressions.Literal(null, DataTypeProtoConverter.toCatalystType(lit.getTypedNull))
-
       case _ =>
         throw InvalidPlanInput(
           s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
@@ -116,7 +113,11 @@ object LiteralValueProtoConverter {
 
   def toConnectProtoValue(value: Any): proto.Expression.Literal = {
     value match {
-      case null => proto.Expression.Literal.newBuilder().setNull(true).build()
+      case null =>
+        proto.Expression.Literal
+          .newBuilder()
+          .setNull(DataTypeProtoConverter.toConnectProtoType(NullType))
+          .build()
      case b: Boolean => proto.Expression.Literal.newBuilder().setBoolean(b).build()
      case b: Byte => proto.Expression.Literal.newBuilder().setByte(b).build()
      case s: Short => proto.Expression.Literal.newBuilder().setShort(s).build()
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 4c15fbf0f22..69ba9147ccc 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -481,6 +481,12 @@ class DataFrame:
     melt = unpivot
 
     def hint(self, name: str, *params: Any) -> "DataFrame":
+        for param in params:
+            if param is not None and not isinstance(param, (int, str)):
+                raise TypeError(
+                    f"param should be a int or str, but got 
{type(param).__name__} {param}"
+                )
+
         return DataFrame.withPlan(
             plan.Hint(self._plan, name, list(params)),
             session=self._session,
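
The hint() change above validates parameters on the client and fails fast with
a TypeError before any plan is sent to the server. A standalone sketch of the
same check (the helper name here is hypothetical; the patched code inlines
this loop in DataFrame.hint):

    from typing import Any, List

    def validate_hint_params(params: List[Any]) -> None:
        # Mirrors the patched loop: only None, int, and str are accepted.
        for param in params:
            if param is not None and not isinstance(param, (int, str)):
                raise TypeError(
                    f"param should be a int or str, but got {type(param).__name__} {param}"
                )

    validate_hint_params(["id", 10, None])  # accepted
    try:
        validate_hint_params([1.1])         # rejected: float is neither int nor str
    except TypeError as e:
        print(e)
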
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 02d6047bd66..1f63d6b0a10 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -270,10 +270,8 @@ class LiteralExpression(Expression):
 
         expr = proto.Expression()
 
-        if isinstance(self._dataType, NullType):
-            expr.literal.null = True
-        elif self._value is None:
-            expr.literal.typed_null.CopyFrom(pyspark_types_to_proto_types(self._dataType))
+        if self._value is None:
+            expr.literal.null.CopyFrom(pyspark_types_to_proto_types(self._dataType))
         elif isinstance(self._dataType, BinaryType):
             expr.literal.binary = bytes(self._value)
         elif isinstance(self._dataType, BooleanType):
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 1e3c230a81a..fe2105d7d2c 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -23,7 +23,7 @@ from pyspark.sql.types import DataType
 
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import SortOrder, ColumnReference
+from pyspark.sql.connect.expressions import SortOrder, ColumnReference, LiteralExpression
 from pyspark.sql.connect.types import pyspark_types_to_proto_types
 
 if TYPE_CHECKING:
@@ -349,20 +349,15 @@ class Hint(LogicalPlan):
 
     def __init__(self, child: Optional["LogicalPlan"], name: str, params: List[Any]) -> None:
         super().__init__(child)
+
+        assert isinstance(name, str)
+
         self.name = name
-        self.params = params
 
-    def _convert_value(self, v: Any) -> proto.Expression.Literal:
-        value = proto.Expression.Literal()
-        if v is None:
-            value.null = True
-        elif isinstance(v, int):
-            value.integer = v
-        elif isinstance(v, str):
-            value.string = v
-        else:
-            raise ValueError(f"Could not convert literal for type {type(v)}")
-        return value
+        assert isinstance(params, list) and all(
+            p is None or isinstance(p, (int, str)) for p in params
+        )
+        self.params = params
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
@@ -370,7 +365,7 @@ class Hint(LogicalPlan):
         plan.hint.input.CopyFrom(self._child.plan(session))
         plan.hint.name = self.name
         for v in self.params:
-            plan.hint.parameters.append(self._convert_value(v))
+            plan.hint.parameters.append(LiteralExpression._from_value(v).to_plan(session).literal)
         return plan
 
     def print(self, indent: int = 0) -> str:
@@ -1285,17 +1280,12 @@ class NAReplace(LogicalPlan):
         self.cols = cols
         self.replacements = replacements
 
-    def _convert_value(self, v: Any) -> proto.Expression.Literal:
-        value = proto.Expression.Literal()
-        if v is None:
-            value.null = True
-        elif isinstance(v, bool):
-            value.boolean = v
-        elif isinstance(v, (int, float)):
-            value.double = float(v)
+    def _convert_int_to_float(self, v: Any) -> Any:
+        # a bool is also an int
+        if v is not None and not isinstance(v, bool) and isinstance(v, int):
+            return float(v)
         else:
-            value.string = v
-        return value
+            return v
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
@@ -1306,8 +1296,16 @@ class NAReplace(LogicalPlan):
         if len(self.replacements) > 0:
             for old_value, new_value in self.replacements.items():
                 replacement = proto.NAReplace.Replacement()
-                replacement.old_value.CopyFrom(self._convert_value(old_value))
-                replacement.new_value.CopyFrom(self._convert_value(new_value))
+                replacement.old_value.CopyFrom(
+                    LiteralExpression._from_value(self._convert_int_to_float(old_value))
+                    .to_plan(session)
+                    .literal
+                )
+                replacement.new_value.CopyFrom(
+                    LiteralExpression._from_value(self._convert_int_to_float(new_value))
+                    .to_plan(session)
+                    .literal
+                )
                 plan.replace.replacements.append(replacement)
         return plan
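
NAReplace now delegates literal construction to LiteralExpression and keeps
only a small numeric coercion. The bool check matters because bool is a
subclass of int in Python; without it, True and False would be widened to 1.0
and 0.0. A standalone sketch of that coercion (logic copied from the diff,
renamed for illustration):

    from typing import Any

    def convert_int_to_float(v: Any) -> Any:
        # a bool is also an int, so exclude it before widening
        if v is not None and not isinstance(v, bool) and isinstance(v, int):
            return float(v)
        return v

    assert convert_int_to_float(3) == 3.0
    assert convert_int_to_float(True) is True  # bools pass through unchanged
    assert convert_int_to_float("x") == "x"
    assert convert_int_to_float(None) is None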
 
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 9e5b887348c..5e4d25b8b94 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
 from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xa7\x1e\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xb0\x1d\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
 )
 
 
@@ -226,7 +226,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._options = None
    DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 78
-    _EXPRESSION._serialized_end = 3957
+    _EXPRESSION._serialized_end = 3838
     _EXPRESSION_WINDOW._serialized_start = 943
     _EXPRESSION_WINDOW._serialized_end = 1726
     _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1233
@@ -244,23 +244,23 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_CAST._serialized_start = 2157
     _EXPRESSION_CAST._serialized_end = 2302
     _EXPRESSION_LITERAL._serialized_start = 2305
-    _EXPRESSION_LITERAL._serialized_end = 3300
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3067
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3184
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3186
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3284
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3302
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3372
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3375
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3579
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3581
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3631
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3633
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3673
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3675
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3719
-    _EXPRESSION_ALIAS._serialized_start = 3721
-    _EXPRESSION_ALIAS._serialized_end = 3841
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3843
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 3944
+    _EXPRESSION_LITERAL._serialized_end = 3181
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 2948
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3065
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3067
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3165
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3183
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3253
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3256
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3460
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3462
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3512
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3514
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3554
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3556
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 3600
+    _EXPRESSION_ALIAS._serialized_start = 3602
+    _EXPRESSION_ALIAS._serialized_end = 3722
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 3724
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 3825
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 3540a563726..26002e649d5 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -456,10 +456,8 @@ class Expression(google.protobuf.message.Message):
         CALENDAR_INTERVAL_FIELD_NUMBER: builtins.int
         YEAR_MONTH_INTERVAL_FIELD_NUMBER: builtins.int
         DAY_TIME_INTERVAL_FIELD_NUMBER: builtins.int
-        TYPED_NULL_FIELD_NUMBER: builtins.int
-        NULLABLE_FIELD_NUMBER: builtins.int
-        TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
-        null: builtins.bool
+        @property
+        def null(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
         binary: builtins.bytes
         boolean: builtins.bool
         byte: builtins.int
@@ -481,22 +479,10 @@ class Expression(google.protobuf.message.Message):
         def calendar_interval(self) -> global___Expression.Literal.CalendarInterval: ...
         year_month_interval: builtins.int
         day_time_interval: builtins.int
-        @property
-        def typed_null(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
-        nullable: builtins.bool
-        """whether the literal type should be treated as a nullable type. 
Applies to
-        all members of union other than the Typed null (which should directly
-        declare nullability).
-        """
-        type_variation_reference: builtins.int
-        """optionally points to a type_variation_anchor defined in this plan.
-        Applies to all members of union other than the Typed null (which should
-        directly declare the type variation).
-        """
         def __init__(
             self,
             *,
-            null: builtins.bool = ...,
+            null: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
             binary: builtins.bytes = ...,
             boolean: builtins.bool = ...,
             byte: builtins.int = ...,
@@ -513,9 +499,6 @@ class Expression(google.protobuf.message.Message):
             calendar_interval: global___Expression.Literal.CalendarInterval | None = ...,
             year_month_interval: builtins.int = ...,
             day_time_interval: builtins.int = ...,
-            typed_null: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
-            nullable: builtins.bool = ...,
-            type_variation_reference: builtins.int = ...,
         ) -> None: ...
         def HasField(
             self,
@@ -554,8 +537,6 @@ class Expression(google.protobuf.message.Message):
                 b"timestamp",
                 "timestamp_ntz",
                 b"timestamp_ntz",
-                "typed_null",
-                b"typed_null",
                 "year_month_interval",
                 b"year_month_interval",
             ],
@@ -589,8 +570,6 @@ class Expression(google.protobuf.message.Message):
                 b"long",
                 "null",
                 b"null",
-                "nullable",
-                b"nullable",
                 "short",
                 b"short",
                 "string",
@@ -599,10 +578,6 @@ class Expression(google.protobuf.message.Message):
                 b"timestamp",
                 "timestamp_ntz",
                 b"timestamp_ntz",
-                "type_variation_reference",
-                b"type_variation_reference",
-                "typed_null",
-                b"typed_null",
                 "year_month_interval",
                 b"year_month_interval",
             ],
@@ -627,7 +602,6 @@ class Expression(google.protobuf.message.Message):
             "calendar_interval",
             "year_month_interval",
             "day_time_interval",
-            "typed_null",
         ] | None: ...
 
     class UnresolvedAttribute(google.protobuf.message.Message):
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 4f64aa40597..e6c179de8cb 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -57,7 +57,9 @@ JVM_LONG_MAX: int = (1 << 63) - 1
 
 def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
     ret = pb2.DataType()
-    if isinstance(data_type, StringType):
+    if isinstance(data_type, NullType):
+        ret.null.CopyFrom(pb2.DataType.NULL())
+    elif isinstance(data_type, StringType):
         ret.string.CopyFrom(pb2.DataType.String())
     elif isinstance(data_type, BooleanType):
         ret.boolean.CopyFrom(pb2.DataType.Boolean())
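
With the NullType branch added above, a null literal's data type round-trips
through the same converter as every other type. A quick check (assuming a
patched PySpark on the path; "null" is a message field of the DataType proto,
so presence is reported by HasField):

    from pyspark.sql.types import NullType
    from pyspark.sql.connect.types import pyspark_types_to_proto_types

    dt = pyspark_types_to_proto_types(NullType())
    assert dt.HasField("null")
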
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 4c9e29326e1..ab05efc5e5d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -932,7 +932,7 @@ class SparkConnectTests(SparkConnectSQLTestCase):
            self.connect.read.table(self.tbl_name).hint("REPARTITION", "id+1").toPandas()
 
         # Hint with unsupported parameter types
-        with self.assertRaises(ValueError):
+        with self.assertRaises(TypeError):
            self.connect.read.table(self.tbl_name).hint("REPARTITION", 1.1).toPandas()
 
         # Hint with wrong combination
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index d1a9bdaddd2..55ecb859805 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -56,7 +56,7 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
     def test_null_literal(self):
         null_lit = fun.lit(None)
         null_lit_p = null_lit.to_plan(None)
-        self.assertEqual(null_lit_p.literal.null, True)
+        self.assertEqual(null_lit_p.literal.HasField("null"), True)
 
     def test_binary_literal(self):
         val = b"binary\0\0asas"


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
